diff --git a/friendli/auth.py b/friendli/auth.py index 928a1b54..c1cd0170 100644 --- a/friendli/auth.py +++ b/friendli/auth.py @@ -15,6 +15,7 @@ import friendli from friendli.di.injector import get_injector from friendli.errors import APIError, AuthorizationError, AuthTokenNotFoundError +from friendli.logging import logger from friendli.utils.fs import get_friendli_directory from friendli.utils.request import DEFAULT_REQ_TIMEOUT, decode_http_err from friendli.utils.url import URLProvider @@ -52,16 +53,30 @@ def get_auth_header( """ token_: Optional[str] + token_from_cfg = get_token(TokenType.ACCESS) + token_from_env = friendli.token + if token is not None: token_ = token - elif friendli.token: - token_ = friendli.token + elif token_from_env: + if token_from_cfg: + logger.warning( + "You've entered your login information in two places - through the " + "'FRIENDLI_TOKEN' environment variable and the 'friendli login' CLI " + "command. We will use the access token from the 'FRIENDLI_TOKEN' " + "environment variable and ignore the login session details. This might " + "lead to unexpected authorization errors. If you prefer to use the " + "login session instead, unset the 'FRIENDLI_TOKEN' environment " + "variable. If you don't want to see this warning again, run " + "'friendli logout' to remove the login session." + ) + token_ = token_from_env else: - token_ = get_token(TokenType.ACCESS) + token_ = token_from_cfg if token_ is None: raise AuthTokenNotFoundError( - "Should set FRIENDLI_TOKEN environment variable or sign in with 'friendli login'." + "Should set 'FRIENDLI_TOKEN' environment variable or sign in with 'friendli login'." ) headers = {"Authorization": f"Bearer {token_}"} diff --git a/friendli/cli/api/chat_completions.py b/friendli/cli/api/chat_completions.py index 5496af42..da840d2a 100644 --- a/friendli/cli/api/chat_completions.py +++ b/friendli/cli/api/chat_completions.py @@ -58,6 +58,16 @@ def create( min=1, help="The maximum number of tokens to generate.", ), + stop: Optional[List[str]] = typer.Option( + None, + "--stop", + "-S", + help=( + "When one of the stop phrases appears in the generation result, the API " + "will stop generation. The stop phrases are excluded from the result. " + "Repeat this option to use multiple stop phrases." + ), + ), temperature: Optional[float] = typer.Option( None, "--temperature", @@ -120,6 +130,7 @@ def create( presence_penalty=presence_penalty, max_tokens=max_tokens, n=n, + stop=stop, temperature=temperature, top_p=top_p, ) @@ -137,6 +148,7 @@ def create( presence_penalty=presence_penalty, max_tokens=max_tokens, n=n, + stop=stop, temperature=temperature, top_p=top_p, ) diff --git a/friendli/cli/api/completions.py b/friendli/cli/api/completions.py index d2df461c..230509ba 100644 --- a/friendli/cli/api/completions.py +++ b/friendli/cli/api/completions.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Optional +from typing import List, Optional import typer @@ -53,6 +53,16 @@ def create( min=1, help="The maximum number of tokens to generate.", ), + stop: Optional[List[str]] = typer.Option( + None, + "--stop", + "-S", + help=( + "When one of the stop phrases appears in the generation result, the API " + "will stop generation. The stop phrases are excluded from the result. " + "Repeat this option to use multiple stop phrases." + ), + ), temperature: Optional[float] = typer.Option( None, "--temperature", @@ -113,6 +123,7 @@ def create( presence_penalty=presence_penalty, max_tokens=max_tokens, n=n, + stop=stop, temperature=temperature, top_p=top_p, ) @@ -130,6 +141,7 @@ def create( presence_penalty=presence_penalty, max_tokens=max_tokens, n=n, + stop=stop, temperature=temperature, top_p=top_p, ) diff --git a/friendli/cli/login.py b/friendli/cli/login.py new file mode 100644 index 00000000..1c226ca8 --- /dev/null +++ b/friendli/cli/login.py @@ -0,0 +1,127 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""CLI command to sign in Friendli.""" + +from __future__ import annotations + +import threading +import time +import webbrowser +from contextlib import contextmanager +from typing import Iterator, Tuple + +import typer +import uvicorn +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import HTMLResponse + +from friendli.client.login import LoginClient +from friendli.di.injector import get_injector +from friendli.utils.url import URLProvider + +server_app = FastAPI() + + +@contextmanager +def run_server(port: int) -> Iterator[None]: + """Run temporary local server to handle SSO redirection.""" + config = uvicorn.Config( + app=server_app, host="127.0.0.1", port=port, log_level="error" + ) + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run) + thread.start() + try: + yield + finally: + server.should_exit = True + thread.join() + + +def oauth2_login() -> Tuple[str, str]: + """Login with SSO.""" + injector = get_injector() + url_provider = injector.get(URLProvider) + authorization_url = url_provider.get_suite_uri("/login/cli") + + access_token = None + refresh_token = None + + @server_app.get("/sso") + async def callback(request: Request) -> HTMLResponse: + nonlocal access_token + nonlocal refresh_token + + access_token = request.query_params.get("access_token") + refresh_token = request.query_params.get("refresh_token") + + if not access_token: + raise HTTPException( + status_code=400, detail="Access token not found in cookies" + ) + + success_page = r""" + + + + + SSO Login Success + + + +
+

Authentication was successful

+

You can now close this window and return to CLI.

+

Redirecting to Friendli Documentation in 10 seconds.

+
+ + + +""" + return HTMLResponse(content=success_page, status_code=200) + + typer.secho( + f"Opening browser for authentication: {authorization_url}", fg=typer.colors.BLUE + ) + + webbrowser.open(authorization_url) + + with run_server(33333): + while access_token is None or refresh_token is None: + time.sleep(1) + + return access_token, refresh_token + + +def pwd_login(email: str, pwd: str) -> Tuple[str, str]: + """Login with email and password.""" + client = LoginClient() + return client.login(email, pwd) diff --git a/friendli/cli/main.py b/friendli/cli/main.py index b52e23c9..949b3a9a 100644 --- a/friendli/cli/main.py +++ b/friendli/cli/main.py @@ -6,26 +6,18 @@ from __future__ import annotations -import requests import typer -from requests import HTTPError, Response -from friendli.auth import TokenType, clear_tokens, get_token, update_token +import friendli +from friendli.auth import TokenType, clear_tokens, update_token from friendli.cli import api, checkpoint -from friendli.client.project import ProjectClient -from friendli.client.user import UserClient, UserGroupClient, UserMFAClient -from friendli.context import ( - get_current_project_id, - project_context_path, - set_current_group_id, -) -from friendli.di.injector import get_injector +from friendli.cli.login import oauth2_login, pwd_login +from friendli.client.user import UserClient from friendli.errors import AuthTokenNotFoundError from friendli.formatter import PanelFormatter +from friendli.graphql.user import UserGqlClient from friendli.utils.decorator import check_api from friendli.utils.format import secho_error_and_exit -from friendli.utils.request import DEFAULT_REQ_TIMEOUT -from friendli.utils.url import URLProvider from friendli.utils.version import get_installed_version app = typer.Typer( @@ -60,7 +52,7 @@ def whoami(): """Show my user info.""" try: - client = UserClient() + client = UserGqlClient() info = client.get_current_user_info() except AuthTokenNotFoundError as exc: secho_error_and_exit(str(exc)) @@ -71,55 +63,27 @@ def whoami(): # @app.command() @check_api def login( - email: str = typer.Option(..., prompt="Enter your email"), - password: str = typer.Option(..., prompt="Enter your password", hide_input=True), + use_sso: bool = typer.Option(False, "--sso", help="Use SSO login."), ): - """Sign in.""" - injector = get_injector() - url_provider = injector.get(URLProvider) - r = requests.post( - url_provider.get_web_backend_uri("/api/auth/cli/access_token"), - json={"username": email, "password": password}, - timeout=DEFAULT_REQ_TIMEOUT, - ) - try: - resp = r.json() - except requests.exceptions.JSONDecodeError: - if r.status_code != 200: - secho_error_and_exit(r.content.decode()) - secho_error_and_exit("Invalid response format.") - - if "code" in resp and resp["code"] == "mfa_required": - mfa_token = resp["mfaToken"] - client = UserMFAClient() - # TODO: MFA type currently defaults to totp, need changes when new options are added - client.initiate_mfa(mfa_type="totp", mfa_token=mfa_token) - update_token(token_type=TokenType.MFA, token=mfa_token) - typer.run(_mfa_verify) + """Sign in Friendli.""" + if friendli.token: + typer.secho( + "You've already set the 'FRIENDLI_TOKEN' environment variable for " + "authentication, which takes precedence over the login session. Using both " + "methods of authentication simultaneously could lead to unexpected issues. " + "We suggest removing the 'FRIENDLI_TOKEN' environment variable if you " + "prefer to log in through the standard login session.", + fg=typer.colors.RED, + ) + + if use_sso: + access_token, refresh_token = oauth2_login() else: - _handle_login_response(r, False) + email = typer.prompt("Enter your email") + pwd = typer.prompt("Enter your password", hide_input=True) + access_token, refresh_token = pwd_login(email, pwd) - # Save user's organiztion context - project_client = ProjectClient() - user_group_client = UserGroupClient() - - try: - org = user_group_client.get_group_info() - except IndexError: - secho_error_and_exit("You are not included in any organization.") - org_id = org["id"] - - project_id = get_current_project_id() - if project_id is not None: - if project_client.check_project_membership(pf_project_id=project_id): - project_org_id = project_client.get_project(pf_project_id=project_id)[ - "pf_group_id" - ] - if project_org_id != org_id: - project_context_path.unlink(missing_ok=True) - else: - project_context_path.unlink(missing_ok=True) - set_current_group_id(org_id) + _display_login_success(access_token, refresh_token) # @app.command() @@ -163,42 +127,15 @@ def version(): typer.echo(installed_version) -def _mfa_verify(_, code: str = typer.Option(..., prompt="Enter MFA Code")): - injector = get_injector() - url_provider = injector.get(URLProvider) - - mfa_token = get_token(TokenType.MFA) - # TODO: MFA type currently defaults to totp, need changes when new options are added - mfa_type = "totp" - username = f"mfa://{mfa_type}/{mfa_token}" - r = requests.post( - url_provider.get_web_backend_uri("/api/auth/cli/access_token"), - json={"username": username, "password": code}, - timeout=DEFAULT_REQ_TIMEOUT, - ) - _handle_login_response(r, True) - - -def _handle_login_response(r: Response, mfa: bool): - try: - r.raise_for_status() - update_token(token_type=TokenType.ACCESS, token=r.json()["accessToken"]) - update_token(token_type=TokenType.REFRESH, token=r.json()["refreshToken"]) +def _display_login_success(access_token: str, refresh_token: str): + update_token(token_type=TokenType.ACCESS, token=access_token) + update_token(token_type=TokenType.REFRESH, token=refresh_token) - typer.echo("\n\nLogin success!") - typer.echo("Welcome back to...") - typography = r""" + typography = r""" _____ _ _ | ___|_ _(_) ___ _ __ _| || |(_) | |__ | '__| |/ _ \| '__ \/ _ || || | | __|| | | | __/| | | | (_) || || | |_| |_| |_|\___||_| |_|\___/ |_||_| """ - typer.secho(typography, fg=typer.colors.BLUE) - except HTTPError: - if mfa: - secho_error_and_exit("Login failed... Invalid MFA Code.") - else: - secho_error_and_exit( - "Login failed... Please check your email and password." - ) + typer.secho(f"\nLOGIN SUCCESS!\n{typography}", fg=typer.colors.BLUE) diff --git a/friendli/client/base.py b/friendli/client/base.py index e3d1c267..4b639f44 100644 --- a/friendli/client/base.py +++ b/friendli/client/base.py @@ -151,14 +151,14 @@ class Client(ABC, Generic[T], RequestInterface): def __init__(self, **kwargs): """Initialize client.""" - injector = get_injector() - self.url_provider = injector.get(URLProvider) - self.url_template = URLTemplate(self.url_path) + self.injector = get_injector() + self.url_provider = self.injector.get(URLProvider) + self.url_template = URLTemplate(Template(self.url_path)) self.url_kwargs = kwargs @property @abstractmethod - def url_path(self) -> Template: + def url_path(self) -> str: """URL path template to render.""" @property diff --git a/friendli/client/catalog.py b/friendli/client/catalog.py index 05939870..918b0dd1 100644 --- a/friendli/client/catalog.py +++ b/friendli/client/catalog.py @@ -4,7 +4,6 @@ from __future__ import annotations -from string import Template from typing import Any, Dict, List, Optional from uuid import UUID @@ -30,9 +29,9 @@ def __init__(self, **kwargs): super().__init__(**kwargs) @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_mr_uri("catalogs/")) + return self.url_provider.get_mr_uri("catalogs/") def get_catalog(self, catalog_id: UUID) -> Dict[str, Any]: """Get a public checkpoint in catalog.""" diff --git a/friendli/client/checkpoint.py b/friendli/client/checkpoint.py index 143ae026..20d41abd 100644 --- a/friendli/client/checkpoint.py +++ b/friendli/client/checkpoint.py @@ -6,7 +6,6 @@ from __future__ import annotations -from string import Template from typing import Any, Dict, List, Optional from uuid import UUID @@ -24,9 +23,9 @@ class CheckpointClient(Client[UUID]): """Checkpoint client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_mr_uri("models/")) + return self.url_provider.get_mr_uri("models/") def get_checkpoint(self, checkpoint_id: UUID) -> Dict[str, Any]: """Get a checkpoint info.""" @@ -57,9 +56,9 @@ class CheckpointFormClient(UploadableClient[UUID]): """Checkpoint form client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_mr_uri("model_forms/")) + return self.url_provider.get_mr_uri("model_forms/") def update_checkpoint_files( self, @@ -89,11 +88,9 @@ def __init__(self, **kwargs): super().__init__(group_id=self.group_id, project_id=self.project_id, **kwargs) @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_mr_uri("orgs/$group_id/prjs/$project_id/models/") - ) + return self.url_provider.get_mr_uri("orgs/$group_id/prjs/$project_id/models/") def list_checkpoints( self, category: Optional[CheckpointCategory], limit: int, deleted: bool diff --git a/friendli/client/credential.py b/friendli/client/credential.py index cc08257e..ea6e7f66 100644 --- a/friendli/client/credential.py +++ b/friendli/client/credential.py @@ -5,7 +5,6 @@ from __future__ import annotations -from string import Template from typing import Any, Dict, Optional from uuid import UUID @@ -18,9 +17,9 @@ class CredentialClient(Client[UUID]): """Credential client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_auth_uri("credential")) + return self.url_provider.get_auth_uri("credential") def get_credential(self, credential_id: UUID) -> Dict[str, Any]: """Get a credential info.""" @@ -62,11 +61,9 @@ class CredentialTypeClient(Client): """Credential type client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_training_uri("credential_type/") - ) # TODO: move this out of the training API + return self.url_provider.get_training_uri("credential_type/") def get_schema_by_type(self, cred_type: CredType) -> Optional[Dict[str, Any]]: """Get a credential JSON schema.""" diff --git a/friendli/client/deployment.py b/friendli/client/deployment.py index 70d8c351..211b30d3 100644 --- a/friendli/client/deployment.py +++ b/friendli/client/deployment.py @@ -5,7 +5,6 @@ from __future__ import annotations from datetime import datetime -from string import Template from typing import Any, Dict, List, Optional from friendli.client.base import Client, ProjectRequestMixin @@ -15,9 +14,9 @@ class DeploymentClient(Client[str]): """Deployment client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_serving_uri("deployment/")) + return self.url_provider.get_serving_uri("deployment/") def get_deployment(self, deployment_id: str) -> Dict[str, Any]: """Get a deployment info.""" @@ -78,11 +77,9 @@ class DeploymentLogClient(Client[str]): """Deployment log client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_serving_uri("deployment/$deployment_id/log/") - ) + return self.url_provider.get_serving_uri("deployment/$deployment_id/log/") def get_deployment_logs(self, replica_index: int) -> List[Dict[str, Any]]: """Get logs from a deployment.""" @@ -97,11 +94,9 @@ class DeploymentMetricsClient(Client): """Deployment metrics client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_serving_uri("deployment/$deployment_id/metrics/") - ) + return self.url_provider.get_serving_uri("deployment/$deployment_id/metrics/") def get_metrics( self, start: datetime, end: datetime, time_window: int @@ -123,11 +118,9 @@ class DeploymentEventClient(Client): """Deployment event client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_serving_uri("deployment/$deployment_id/event/") - ) + return self.url_provider.get_serving_uri("deployment/$deployment_id/event/") def get_events(self) -> List[Dict[str, Any]]: """Get deployment events.""" @@ -141,12 +134,10 @@ class DeploymentReqRespClient(Client): """Deployment request-response client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_serving_uri( - "deployment/$deployment_id/req_resp/download/" - ) + return self.url_provider.get_serving_uri( + "deployment/$deployment_id/req_resp/download/" ) def get_download_urls(self, start: datetime, end: datetime) -> List[Dict[str, str]]: @@ -171,11 +162,9 @@ def __init__(self, **kwargs): super().__init__(project_id=self.project_id, **kwargs) @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_serving_uri("usage/project/$project_id/duration") - ) + return self.url_provider.get_serving_uri("usage/project/$project_id/duration") def get_project_deployment_durations( self, @@ -198,9 +187,9 @@ class PFSVMClient(Client): """VM client for serving.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_serving_uri("vm/")) + return self.url_provider.get_serving_uri("vm/") def list_vms(self) -> List[Dict[str, Any]]: """List all VM info.""" diff --git a/friendli/client/file.py b/friendli/client/file.py index eb2e5ae3..11b1273a 100644 --- a/friendli/client/file.py +++ b/friendli/client/file.py @@ -4,7 +4,6 @@ from __future__ import annotations -from string import Template from typing import Any, Dict from uuid import UUID @@ -20,9 +19,9 @@ class FileClient(Client[UUID]): """File client service.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_mr_uri("files/")) + return self.url_provider.get_mr_uri("files/") def get_misc_file_upload_url(self, misc_file_id: UUID) -> str: """Get an URL to upload file. @@ -77,11 +76,9 @@ def __init__(self, **kwargs): super().__init__(group_id=self.group_id, project_id=self.project_id, **kwargs) @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_mr_uri("orgs/$group_id/prjs/$project_id/files/") - ) + return self.url_provider.get_mr_uri("orgs/$group_id/prjs/$project_id/files/") def create_misc_file(self, file_info: Dict[str, Any]) -> Dict[str, Any]: """Request to create a misc file. diff --git a/friendli/client/group.py b/friendli/client/group.py index 844baf82..5f0517b7 100644 --- a/friendli/client/group.py +++ b/friendli/client/group.py @@ -8,7 +8,6 @@ import json import uuid -from string import Template from typing import Any, Dict, List from friendli.client.base import Client, GroupRequestMixin @@ -18,9 +17,9 @@ class GroupClient(Client): """Organization client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_auth_uri("pf_group")) + return self.url_provider.get_auth_uri("pf_group") def create_group(self, name: str) -> Dict[str, Any]: """Create a new organization.""" @@ -62,11 +61,9 @@ def __init__(self, **kwargs): super().__init__(pf_group_id=self.group_id, **kwargs) @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_auth_uri("pf_group/$pf_group_id/pf_project") - ) + return self.url_provider.get_auth_uri("pf_group/$pf_group_id/pf_project") def create_project(self, name: str) -> Dict[str, Any]: """Create a new project in the organization.""" diff --git a/friendli/client/login.py b/friendli/client/login.py new file mode 100644 index 00000000..ab40ea4f --- /dev/null +++ b/friendli/client/login.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Login Client.""" + +from __future__ import annotations + +from typing import Tuple + +from friendli.client.base import Client +from friendli.settings import Settings + + +class LoginClient(Client): + """Login client.""" + + @property + def url_path(self) -> str: + """Get an URL path.""" + return self.url_provider.get_web_backend_uri("/api/auth/login") + + def login(self, email: str, pwd: str) -> Tuple[str, str]: + """Send request to sign in with email and password.""" + settings = self.injector.get(Settings) + payload = { + "email": email, + "password": pwd, + } + headers = {"Accept": "application/json"} + resp = self.bare_post(json=payload, headers=headers) + cookies = resp.cookies + access_token = cookies[settings.access_token_cookie_key] + refresh_token = cookies[settings.refresh_token_cookie_key] + return access_token, refresh_token diff --git a/friendli/client/project.py b/friendli/client/project.py index 4d585248..e7e5e0ed 100644 --- a/friendli/client/project.py +++ b/friendli/client/project.py @@ -6,7 +6,6 @@ from __future__ import annotations -from string import Template from typing import Any, Dict, List, Optional from uuid import UUID @@ -30,9 +29,9 @@ class ProjectClient(Client[UUID]): """Project client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_auth_uri("pf_project")) + return self.url_provider.get_auth_uri("pf_project") def get_project(self, pf_project_id: UUID) -> Dict[str, Any]: """Get project info.""" @@ -69,11 +68,9 @@ def __init__(self, **kwargs): super().__init__(project_id=self.project_id, **kwargs) @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_auth_uri("pf_project/$project_id/credential") - ) + return self.url_provider.get_auth_uri("pf_project/$project_id/credential") def list_credentials( self, cred_type: Optional[CredType] = None diff --git a/friendli/client/user.py b/friendli/client/user.py index 818fc81f..bd96b11b 100644 --- a/friendli/client/user.py +++ b/friendli/client/user.py @@ -4,7 +4,6 @@ from __future__ import annotations -from string import Template from typing import Any, Dict, List from uuid import UUID @@ -16,9 +15,9 @@ class UserMFAClient(Client): """User MFA client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_auth_uri("mfa")) + return self.url_provider.get_auth_uri("mfa") def initiate_mfa(self, mfa_type: str, mfa_token: str) -> None: """Authenticate by MFA token.""" @@ -29,9 +28,9 @@ class UserSignUpClient(Client): """User sign-up client.""" @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_auth_uri("pf_user/self_signup")) + return self.url_provider.get_auth_uri("pf_user/self_signup") def verify(self, token: str, key: str) -> None: """Verify the email account with the token to sign up.""" @@ -47,9 +46,9 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_auth_uri("pf_user")) + return self.url_provider.get_auth_uri("pf_user") def change_password(self, old_password: str, new_password: str) -> None: """Change password.""" @@ -114,9 +113,9 @@ def __init__(self, **kwargs): super().__init__(pf_user_id=self.user_id, **kwargs) @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_auth_uri("pf_user/$pf_user_id/pf_group")) + return self.url_provider.get_auth_uri("pf_user/$pf_user_id/pf_group") def get_group_info(self) -> Dict[str, Any]: """Get organization info where user belongs to.""" @@ -134,12 +133,10 @@ def __init__(self, **kwargs): super().__init__(pf_user_id=self.user_id, pf_group_id=self.group_id, **kwargs) @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template( - self.url_provider.get_auth_uri( - "pf_user/$pf_user_id/pf_group/$pf_group_id/pf_project" - ) + return self.url_provider.get_auth_uri( + "pf_user/$pf_user_id/pf_group/$pf_group_id/pf_project" ) def list_projects(self) -> List[Dict[str, Any]]: @@ -157,9 +154,9 @@ def __init__(self, **kwargs): super().__init__(pf_user_id=self.user_id, **kwargs) @property - def url_path(self) -> Template: + def url_path(self) -> str: """Get an URL path.""" - return Template(self.url_provider.get_auth_uri("pf_user")) + return self.url_provider.get_auth_uri("pf_user") def create_access_key(self, name: str) -> Dict[str, Any]: """Create a new access key.""" diff --git a/friendli/di/modules.py b/friendli/di/modules.py index 7a045a4e..14c57866 100644 --- a/friendli/di/modules.py +++ b/friendli/di/modules.py @@ -6,15 +6,17 @@ from injector import Binder, Module +from friendli import settings from friendli.utils import url -class URLModule(Module): +class SettingsModule(Module): """Friendli client module.""" def configure(self, binder: Binder) -> None: """Configures bindings for clients.""" binder.bind(url.URLProvider, to=url.ProductionURLProvider) # type: ignore + binder.bind(settings.Settings, to=settings.ProductionSettings) # type: ignore -default_modules = [URLModule] +default_modules = [SettingsModule] diff --git a/friendli/graphql/__init__.py b/friendli/graphql/__init__.py new file mode 100644 index 00000000..1b0661fd --- /dev/null +++ b/friendli/graphql/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved. + +"""Friendli graphql clients to interact with Friendli system.""" diff --git a/friendli/graphql/base.py b/friendli/graphql/base.py new file mode 100644 index 00000000..52abbc55 --- /dev/null +++ b/friendli/graphql/base.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved. + +"""Friendli GQL Client Service.""" + +from __future__ import annotations + +from typing import Any, Dict, Optional + +from friendli.client.base import Client + + +class GqlClient(Client): + """Base interface of graphql client to Friendli system.""" + + @property + def url_path(self) -> str: + """URL path template to render.""" + return self.url_provider.get_web_backend_uri("api/graphql") + + def run( + self, query: str, variables: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """Run graphql.""" + return self.post( + json={ + "query": query, + "variables": variables, + } + )["data"] diff --git a/friendli/graphql/user.py b/friendli/graphql/user.py new file mode 100644 index 00000000..f23bcf18 --- /dev/null +++ b/friendli/graphql/user.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved. + +"""Friendli User GQL Clients.""" + +from __future__ import annotations + +from typing import Any, Dict + +from friendli.graphql.base import GqlClient + +CurrUserInfoGql = """ +query GetclientSession { + clientSession { + user { + id + name + email + } + } +} +""" + + +class UserGqlClient(GqlClient): + """User gql client.""" + + def get_current_user_info(self) -> Dict[str, Any]: + """Get current user info.""" + response = self.run(query=CurrUserInfoGql) + return response["clientSession"]["user"] diff --git a/friendli/logging.py b/friendli/logging.py index e1359bb9..b1d27e65 100644 --- a/friendli/logging.py +++ b/friendli/logging.py @@ -7,10 +7,42 @@ import logging import os -_formatter = logging.Formatter( - fmt="%(asctime)s.%(msecs)05d: %(name)s %(levelname)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) +_formatter = logging.Formatter() + + +class ColorFormatter(logging.Formatter): + """Customized formatter with ANSI color.""" + + grey = "\x1b[38;20m" + yellow = "\x1b[33;20m" + red = "\x1b[31;20m" + bold_red = "\x1b[31;1m" + reset = "\x1b[0m" + + default_fmt = "%(asctime)s.%(msecs)05d: %(name)s %(levelname)s: %(message)s" + default_datefmt = "%Y-%m-%d %H:%M:%S" + + FORMATS = { + logging.DEBUG: grey + default_fmt + reset, + logging.INFO: grey + default_fmt + reset, + logging.WARNING: yellow + default_fmt + reset, + logging.ERROR: red + default_fmt + reset, + logging.CRITICAL: bold_red + default_fmt + reset, + } + + def __init__(self): + """Initialize CustomFormatter.""" + super().__init__(fmt=self.default_fmt, datefmt=self.default_datefmt) + + # Pre-create Formatter objects for each level to improve efficiency + self.formatters = { + level: logging.Formatter(fmt) for level, fmt in self.FORMATS.items() + } + + def format(self, record): + """Override format method.""" + formatter = self.formatters.get(record.levelno, self.formatters[logging.INFO]) + return formatter.format(record) def get_logger(name: str) -> logging.Logger: @@ -18,7 +50,8 @@ def get_logger(name: str) -> logging.Logger: logger = logging.getLogger(name) handler = logging.StreamHandler() - handler.setFormatter(_formatter) + formatter = ColorFormatter() + handler.setFormatter(formatter) logger.addHandler(handler) logger.setLevel(os.environ.get("FRIENDLI_LOG_LEVEL", "INFO")) diff --git a/friendli/modules/converter/models/mixtral.py b/friendli/modules/converter/models/mixtral.py index b8a82980..5cf5a366 100644 --- a/friendli/modules/converter/models/mixtral.py +++ b/friendli/modules/converter/models/mixtral.py @@ -160,7 +160,7 @@ def decoder_convert_info_list( @property def model_type(self) -> str: """Model type.""" - return "mistral" + return "mixtral" @property def decoder_layer_num(self) -> int: diff --git a/friendli/settings.py b/friendli/settings.py new file mode 100644 index 00000000..342ae30e --- /dev/null +++ b/friendli/settings.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""CLI App Settings.""" + + +class Settings: + """CLI app settings.""" + + access_token_cookie_key = "" + refresh_token_cookie_key = "" + + +class ProductionSettings: + """Production CLI app settings.""" + + access_token_cookie_key = "sAccessTokenProduction" + refresh_token_cookie_key = "sRefreshTokenProduction" + + +class StagingSettings: + """Staging CLI app settings.""" + + access_token_cookie_key = "sAccessTokenStaging" + refresh_token_cookie_key = "sRefreshTokenStaging" + + +class DevSettings: + """Dev CLI app settings.""" + + access_token_cookie_key = "sAccessTokenDev" + refresh_token_cookie_key = "sRefreshTokenDev" diff --git a/friendli/utils/url.py b/friendli/utils/url.py index 2419c4fe..2176c764 100644 --- a/friendli/utils/url.py +++ b/friendli/utils/url.py @@ -22,6 +22,8 @@ def get_host(url: str) -> str: class URLProvider: """Service URL provider.""" + suite_url = "" + api_url = "" training_url = "" registry_url = "" serving_url = "" @@ -30,6 +32,11 @@ class URLProvider: observatory_url = "" web_backend_url = "" + @classmethod + def get_suite_uri(cls, path: str) -> str: + """Get Friendli Suite URI.""" + return urljoin(cls.suite_url, path) + @classmethod def get_auth_uri(cls, path: str) -> str: """Get PFA URI.""" @@ -46,6 +53,11 @@ def get_training_uri(cls, path: str) -> str: """Get PFT URI.""" return urljoin(cls.training_url, path) + @classmethod + def get_api_uri(cls, path: str) -> str: + """Get PFT URI.""" + return urljoin(cls.api_url, path) + @classmethod def get_serving_uri(cls, path: str) -> str: """Get PFS URI.""" @@ -70,37 +82,37 @@ def get_observatory_uri(cls, path: str) -> str: class ProductionURLProvider(URLProvider): """Production service URL provider.""" - training_url = "https://training.friendli.ai/api/" - training_ws_url = "wss://training-ws.friendli.ai/ws/" + suite_url = "https://suite.friendli.ai/" registry_url = "https://modelregistry.friendli.ai/" serving_url = "https://serving.friendli.ai/" auth_url = "https://auth.friendli.ai/" meter_url = "https://metering.friendli.ai/" observatory_url = "https://observatory.friendli.ai/" - web_backend_url = "https://cloud.friendli.ai/" + web_backend_url = "https://suite.friendli.ai/" + training_url = "https://training.friendli.ai/api/" class StagingURLProvider(URLProvider): """Staging service URL provider.""" - training_url = "https://api-staging.friendli.ai/api/" - training_ws_url = "wss://api-ws-staging.friendli.ai/ws/" + suite_url = "https://suite-staging.friendli.ai/" registry_url = "https://pfmodelregistry-staging.friendli.ai/" serving_url = "https://pfs-staging.friendli.ai/" auth_url = "https://pfauth-staging.friendli.ai/" meter_url = "https://pfmeter-staging.friendli.ai/" observatory_url = "https://pfo-staging.friendli.ai/" web_backend_url = "https://api-staging.friendli.ai/" + training_url = "https://api-staging.friendli.ai/api/" class DevURLProvider(URLProvider): """Dev service URL provider.""" - training_url = "https://api-dev.friendli.ai/api/" - training_ws_url = "wss://api-ws-dev.friendli.ai/ws/" + suite_url = "https://suite-dev.friendli.ai/" registry_url = "https://pfmodelregistry-dev.friendli.ai/" serving_url = "https://pfs-dev.friendli.ai/" auth_url = "https://pfauth-dev.friendli.ai/" meter_url = "https://pfmeter-dev.friendli.ai/" observatory_url = "https://pfo-dev.friendli.ai/" web_backend_url = "https://api-dev.friendli.ai/" + training_url = "https://api-dev.friendli.ai/api/" diff --git a/poetry.lock b/poetry.lock index f2d63831..619148a8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "accelerate" @@ -1221,11 +1221,30 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "fastapi" +version = "0.109.2" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastapi-0.109.2-py3-none-any.whl", hash = "sha256:2c9bab24667293b501cad8dd388c05240c850b58ec5876ee3283c47d6e1e3a4d"}, + {file = "fastapi-0.109.2.tar.gz", hash = "sha256:f3817eac96fe4f65a2ebb4baa000f394e55f5fccdaf7f75250804bc58f354f73"}, +] + +[package.dependencies] +pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" +starlette = ">=0.36.3,<0.37.0" +typing-extensions = ">=4.8.0" + +[package.extras] +all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] + [[package]] name = "filelock" version = "3.12.2" description = "A platform independent file lock." -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, @@ -2620,7 +2639,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -3016,6 +3034,24 @@ files = [ {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"}, ] +[[package]] +name = "starlette" +version = "0.36.3" +description = "The little ASGI library that shines." +optional = false +python-versions = ">=3.8" +files = [ + {file = "starlette-0.36.3-py3-none-any.whl", hash = "sha256:13d429aa93a61dc40bf503e8c801db1f1bca3dc706b10ef2434a36123568f044"}, + {file = "starlette-0.36.3.tar.gz", hash = "sha256:90a671733cfb35771d8cc605e0b679d23b992f8dcfad48cc60b38cb29aeb7080"}, +] + +[package.dependencies] +anyio = ">=3.4.0,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] + [[package]] name = "sympy" version = "1.12" @@ -3441,13 +3477,13 @@ files = [ [[package]] name = "typing-extensions" -version = "4.7.1" -description = "Backported and Experimental Type Hints for Python 3.7+" +version = "4.9.0" +description = "Backported and Experimental Type Hints for Python 3.8+" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, - {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, + {file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"}, + {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"}, ] [[package]] @@ -3478,62 +3514,24 @@ secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "p socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [[package]] -name = "websockets" -version = "10.1" -description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" +name = "uvicorn" +version = "0.27.0.post1" +description = "The lightning-fast ASGI server." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "websockets-10.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:38db6e2163b021642d0a43200ee2dec8f4980bdbda96db54fde72b283b54cbfc"}, - {file = "websockets-10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e1b60fd297adb9fc78375778a5220da7f07bf54d2a33ac781319650413fc6a60"}, - {file = "websockets-10.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3477146d1f87ead8df0f27e8960249f5248dceb7c2741e8bbec9aa5338d0c053"}, - {file = "websockets-10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb01ea7b5f52e7125bdc3c5807aeaa2d08a0553979cf2d96a8b7803ea33e15e7"}, - {file = "websockets-10.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9fd62c6dc83d5d35fb6a84ff82ec69df8f4657fff05f9cd6c7d9bec0dd57f0f6"}, - {file = "websockets-10.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3bbf080f3892ba1dc8838786ec02899516a9d227abe14a80ef6fd17d4fb57127"}, - {file = "websockets-10.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5560558b0dace8312c46aa8915da977db02738ac8ecffbc61acfbfe103e10155"}, - {file = "websockets-10.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:667c41351a6d8a34b53857ceb8343a45c85d438ee4fd835c279591db8aeb85be"}, - {file = "websockets-10.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:468f0031fdbf4d643f89403a66383247eb82803430b14fa27ce2d44d2662ca37"}, - {file = "websockets-10.1-cp310-cp310-win32.whl", hash = "sha256:d0d81b46a5c87d443e40ce2272436da8e6092aa91f5fbeb60d1be9f11eff5b4c"}, - {file = "websockets-10.1-cp310-cp310-win_amd64.whl", hash = "sha256:b68b6caecb9a0c6db537aa79750d1b592a841e4f1a380c6196091e65b2ad35f9"}, - {file = "websockets-10.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a249139abc62ef333e9e85064c27fefb113b16ffc5686cefc315bdaef3eefbc8"}, - {file = "websockets-10.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8877861e3dee38c8d302eee0d5dbefa6663de3b46dc6a888f70cd7e82562d1f7"}, - {file = "websockets-10.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e3872ae57acd4306ecf937d36177854e218e999af410a05c17168cd99676c512"}, - {file = "websockets-10.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b66e6d514f12c28d7a2d80bb2a48ef223342e99c449782d9831b0d29a9e88a17"}, - {file = "websockets-10.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9f304a22ece735a3da8a51309bc2c010e23961a8f675fae46fdf62541ed62123"}, - {file = "websockets-10.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:189ed478395967d6a98bb293abf04e8815349e17456a0a15511f1088b6cb26e4"}, - {file = "websockets-10.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:08a42856158307e231b199671c4fce52df5786dd3d703f36b5d8ac76b206c485"}, - {file = "websockets-10.1-cp37-cp37m-win32.whl", hash = "sha256:3ef6f73854cded34e78390dbdf40dfdcf0b89b55c0e282468ef92646fce8d13a"}, - {file = "websockets-10.1-cp37-cp37m-win_amd64.whl", hash = "sha256:89e985d40d407545d5f5e2e58e1fdf19a22bd2d8cd54d20a882e29f97e930a0a"}, - {file = "websockets-10.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:002071169d2e44ce8eb9e5ebac9fbce142ba4b5146eef1cfb16b177a27662657"}, - {file = "websockets-10.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cfae282c2aa7f0c4be45df65c248481f3509f8c40ca8b15ed96c35668ae0ff69"}, - {file = "websockets-10.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:97b4b68a2ddaf5c4707ae79c110bfd874c5be3c6ac49261160fb243fa45d8bbb"}, - {file = "websockets-10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c9407719f42cb77049975410490c58a705da6af541adb64716573e550e5c9db"}, - {file = "websockets-10.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1d858fb31e5ac992a2cdf17e874c95f8a5b1e917e1fb6b45ad85da30734b223f"}, - {file = "websockets-10.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7bdd3d26315db0a9cf8a0af30ca95e0aa342eda9c1377b722e71ccd86bc5d1dd"}, - {file = "websockets-10.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e259be0863770cb91b1a6ccf6907f1ac2f07eff0b7f01c249ed751865a70cb0d"}, - {file = "websockets-10.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6b014875fae19577a392372075e937ebfebf53fd57f613df07b35ab210f31534"}, - {file = "websockets-10.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:98de71f86bdb29430fd7ba9997f47a6b10866800e3ea577598a786a785701bb0"}, - {file = "websockets-10.1-cp38-cp38-win32.whl", hash = "sha256:3a02ab91d84d9056a9ee833c254895421a6333d7ae7fff94b5c68e4fa8095519"}, - {file = "websockets-10.1-cp38-cp38-win_amd64.whl", hash = "sha256:7d6673b2753f9c5377868a53445d0c321ef41ff3c8e3b6d57868e72054bfce5f"}, - {file = "websockets-10.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ddab2dc69ee5ae27c74dbfe9d7bb6fee260826c136dca257faa1a41d1db61a89"}, - {file = "websockets-10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:14e9cf68a08d1a5d42109549201aefba473b1d925d233ae19035c876dd845da9"}, - {file = "websockets-10.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e4819c6fb4f336fd5388372cb556b1f3a165f3f68e66913d1a2fc1de55dc6f58"}, - {file = "websockets-10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05e7f098c76b0a4743716590bb8f9706de19f1ef5148d61d0cf76495ec3edb9c"}, - {file = "websockets-10.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5bb6256de5a4fb1d42b3747b4e2268706c92965d75d0425be97186615bf2f24f"}, - {file = "websockets-10.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:888a5fa2a677e0c2b944f9826c756475980f1b276b6302e606f5c4ff5635be9e"}, - {file = "websockets-10.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6fdec1a0b3e5630c58e3d8704d2011c678929fce90b40908c97dfc47de8dca72"}, - {file = "websockets-10.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:531d8eb013a9bc6b3ad101588182aa9b6dd994b190c56df07f0d84a02b85d530"}, - {file = "websockets-10.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0d93b7cadc761347d98da12ec1930b5c71b2096f1ceed213973e3cda23fead9c"}, - {file = "websockets-10.1-cp39-cp39-win32.whl", hash = "sha256:d9b245db5a7e64c95816e27d72830e51411c4609c05673d1ae81eb5d23b0be54"}, - {file = "websockets-10.1-cp39-cp39-win_amd64.whl", hash = "sha256:882c0b8bdff3bf1bd7f024ce17c6b8006042ec4cceba95cf15df57e57efa471c"}, - {file = "websockets-10.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:10edd9d7d3581cfb9ff544ac09fc98cab7ee8f26778a5a8b2d5fd4b0684c5ba5"}, - {file = "websockets-10.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baa83174390c0ff4fc1304fbe24393843ac7a08fdd59295759c4b439e06b1536"}, - {file = "websockets-10.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:483edee5abed738a0b6a908025be47f33634c2ad8e737edd03ffa895bd600909"}, - {file = "websockets-10.1-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:816ae7dac2c6522cfa620947ead0ca95ac654916eebf515c94d7c28de5601a6e"}, - {file = "websockets-10.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:1dafe98698ece09b8ccba81b910643ff37198e43521d977be76caf37709cf62b"}, - {file = "websockets-10.1.tar.gz", hash = "sha256:181d2b25de5a437b36aefedaf006ecb6fa3aa1328ec0236cdde15f32f9d3ff6d"}, + {file = "uvicorn-0.27.0.post1-py3-none-any.whl", hash = "sha256:4b85ba02b8a20429b9b205d015cbeb788a12da527f731811b643fd739ef90d5f"}, + {file = "uvicorn-0.27.0.post1.tar.gz", hash = "sha256:54898fcd80c13ff1cd28bf77b04ec9dbd8ff60c5259b499b4b12bb0917f22907"}, ] +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + +[package.extras] +standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "wrapt" version = "1.15.0" @@ -3820,4 +3818,4 @@ mllib = ["accelerate", "datasets", "einops", "h5py", "peft", "transformers"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "041119a17bb7b489082c252360d9cf2bc11e1e7103ba60c07afeb5f6dd4c0a4d" +content-hash = "4a2eb692717176c07248be4eb0ae543653edb1a0c0d4da1e35213b79c1dec6cb" diff --git a/pyproject.toml b/pyproject.toml index e6b6374d..903868bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "friendli-client" -version = "1.2.0" +version = "1.2.1" description = "Client of Friendli Suite." license = "Apache-2.0" authors = ["FriendliAI teams "] @@ -26,57 +26,55 @@ priority = "primary" [tool.poetry.dependencies] python = "^3.8" -filelock = "3.12.2" -requests = "2.31.0" -websockets = "10.1" -PyYaml = "6.0.1" -typer = "0.9.0" -rich = "12.2.0" -jsonschema = "4.17.3" -boto3 = "1.22.8" -botocore = "1.25.8" -tqdm = "4.64.0" -azure-mgmt-storage = "20.1.0" -azure-storage-blob = "12.12.0" -packaging = "23.1" -pathspec = "0.9.0" -boto3-stubs = "1.26.90" -mypy-boto3-s3 = "1.26.163" -ruamel-yaml = "0.17.32" -pydantic = {extras = ["email"], version = "2.0.2"} +requests = "^2.31.0" +PyYaml = "^6.0.1" +typer = "^0.9.0" +rich = "^12.2.0" +jsonschema = "^4.17.3" +boto3 = "^1.22.8" +botocore = "^1.25.8" +tqdm = "^4.64.0" +azure-mgmt-storage = "^20.1.0" +azure-storage-blob = "^12.12.0" +pathspec = "^0.9.0" +boto3-stubs = "^1.26.90" +mypy-boto3-s3 = "^1.26.163" +ruamel-yaml = "^0.17.32" +pydantic = {extras = ["email"], version = "^2.0.2"} transformers = { version = "4.36.2", optional = true } -h5py = { version = "3.9.0", optional = true } -einops = { version = "0.6.1", optional = true } +h5py = { version = "^3.9.0", optional = true } +einops = { version = "^0.6.1", optional = true } accelerate = { version = "0.21.0", optional = true } datasets = { version = "2.16.0", optional = true } -injector = "0.21.0" -protobuf = "4.24.2" -types-protobuf = "4.24.0.1" +injector = "^0.21.0" +protobuf = "^4.24.2" +types-protobuf = "^4.24.0.1" peft = { version = "0.6.0", optional = true } -httpx = "0.26.0" +httpx = "^0.26.0" +fastapi = "^0.109.2" +uvicorn = "^0.27.0.post1" [tool.poetry.group.dev] optional = true [tool.poetry.group.dev.dependencies] -typer = "0.9.0" -pytest = "7.4.0" -coverage = "7.2.7" -pytest-asyncio = "0.15.1" -pytest-cov = "4.1.0" -requests-mock = "1.11.0" -black = "23.3.0" -isort = "5.12.0" -mypy = "1.4.1" -pydocstyle = "6.3.0" -pylint = "2.17.4" -toml = "0.10.2" -types-pyyaml = "6.0.12.10" -types-jsonschema = "4.17.0.8" -types-python-dateutil = "2.8.19.13" -types-requests = "2.31.0.1" -types-toml = "0.10.8.6" -types-tqdm = "4.65.0.1" +pytest = "^7.4.0" +coverage = "^7.2.7" +pytest-asyncio = "^0.15.1" +pytest-cov = "^4.1.0" +requests-mock = "^1.11.0" +black = "^23.3.0" +isort = "^5.12.0" +mypy = "^1.4.1" +pydocstyle = "^6.3.0" +pylint = "^2.17.4" +toml = "^0.10.2" +types-pyyaml = "^6.0.12.10" +types-jsonschema = "^4.17.0.8" +types-python-dateutil = "^2.8.19.13" +types-requests = "^2.31.0.1" +types-toml = "^0.10.8.6" +types-tqdm = "^4.65.0.1" [tool.poetry.extras] mllib = ["transformers", "h5py", "accelerate", "einops", "datasets", "peft"] diff --git a/tests/unit_tests/client/test_base.py b/tests/unit_tests/client/test_base.py index 69a2a872..de171594 100644 --- a/tests/unit_tests/client/test_base.py +++ b/tests/unit_tests/client/test_base.py @@ -85,8 +85,8 @@ def test_client_service_base( class TestClient(Client[int]): @property - def url_path(self) -> Template: - return Template(url_pattern) + def url_path(self) -> str: + return url_pattern client = TestClient(test_id=1)