diff --git a/friendli/auth.py b/friendli/auth.py
index 928a1b54..c1cd0170 100644
--- a/friendli/auth.py
+++ b/friendli/auth.py
@@ -15,6 +15,7 @@
import friendli
from friendli.di.injector import get_injector
from friendli.errors import APIError, AuthorizationError, AuthTokenNotFoundError
+from friendli.logging import logger
from friendli.utils.fs import get_friendli_directory
from friendli.utils.request import DEFAULT_REQ_TIMEOUT, decode_http_err
from friendli.utils.url import URLProvider
@@ -52,16 +53,30 @@ def get_auth_header(
"""
token_: Optional[str]
+ token_from_cfg = get_token(TokenType.ACCESS)
+ token_from_env = friendli.token
+
if token is not None:
token_ = token
- elif friendli.token:
- token_ = friendli.token
+ elif token_from_env:
+ if token_from_cfg:
+ logger.warning(
+ "You've entered your login information in two places - through the "
+ "'FRIENDLI_TOKEN' environment variable and the 'friendli login' CLI "
+ "command. We will use the access token from the 'FRIENDLI_TOKEN' "
+ "environment variable and ignore the login session details. This might "
+ "lead to unexpected authorization errors. If you prefer to use the "
+ "login session instead, unset the 'FRIENDLI_TOKEN' environment "
+ "variable. If you don't want to see this warning again, run "
+ "'friendli logout' to remove the login session."
+ )
+ token_ = token_from_env
else:
- token_ = get_token(TokenType.ACCESS)
+ token_ = token_from_cfg
if token_ is None:
raise AuthTokenNotFoundError(
- "Should set FRIENDLI_TOKEN environment variable or sign in with 'friendli login'."
+ "Should set 'FRIENDLI_TOKEN' environment variable or sign in with 'friendli login'."
)
headers = {"Authorization": f"Bearer {token_}"}
diff --git a/friendli/cli/api/chat_completions.py b/friendli/cli/api/chat_completions.py
index 5496af42..da840d2a 100644
--- a/friendli/cli/api/chat_completions.py
+++ b/friendli/cli/api/chat_completions.py
@@ -58,6 +58,16 @@ def create(
min=1,
help="The maximum number of tokens to generate.",
),
+ stop: Optional[List[str]] = typer.Option(
+ None,
+ "--stop",
+ "-S",
+ help=(
+ "When one of the stop phrases appears in the generation result, the API "
+ "will stop generation. The stop phrases are excluded from the result. "
+ "Repeat this option to use multiple stop phrases."
+ ),
+ ),
temperature: Optional[float] = typer.Option(
None,
"--temperature",
@@ -120,6 +130,7 @@ def create(
presence_penalty=presence_penalty,
max_tokens=max_tokens,
n=n,
+ stop=stop,
temperature=temperature,
top_p=top_p,
)
@@ -137,6 +148,7 @@ def create(
presence_penalty=presence_penalty,
max_tokens=max_tokens,
n=n,
+ stop=stop,
temperature=temperature,
top_p=top_p,
)
diff --git a/friendli/cli/api/completions.py b/friendli/cli/api/completions.py
index d2df461c..230509ba 100644
--- a/friendli/cli/api/completions.py
+++ b/friendli/cli/api/completions.py
@@ -6,7 +6,7 @@
from __future__ import annotations
-from typing import Optional
+from typing import List, Optional
import typer
@@ -53,6 +53,16 @@ def create(
min=1,
help="The maximum number of tokens to generate.",
),
+ stop: Optional[List[str]] = typer.Option(
+ None,
+ "--stop",
+ "-S",
+ help=(
+ "When one of the stop phrases appears in the generation result, the API "
+ "will stop generation. The stop phrases are excluded from the result. "
+ "Repeat this option to use multiple stop phrases."
+ ),
+ ),
temperature: Optional[float] = typer.Option(
None,
"--temperature",
@@ -113,6 +123,7 @@ def create(
presence_penalty=presence_penalty,
max_tokens=max_tokens,
n=n,
+ stop=stop,
temperature=temperature,
top_p=top_p,
)
@@ -130,6 +141,7 @@ def create(
presence_penalty=presence_penalty,
max_tokens=max_tokens,
n=n,
+ stop=stop,
temperature=temperature,
top_p=top_p,
)
diff --git a/friendli/cli/login.py b/friendli/cli/login.py
new file mode 100644
index 00000000..1c226ca8
--- /dev/null
+++ b/friendli/cli/login.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved.
+
+"""CLI command to sign in Friendli."""
+
+from __future__ import annotations
+
+import threading
+import time
+import webbrowser
+from contextlib import contextmanager
+from typing import Iterator, Tuple
+
+import typer
+import uvicorn
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import HTMLResponse
+
+from friendli.client.login import LoginClient
+from friendli.di.injector import get_injector
+from friendli.utils.url import URLProvider
+
+server_app = FastAPI()
+
+
+@contextmanager
+def run_server(port: int) -> Iterator[None]:
+ """Run temporary local server to handle SSO redirection."""
+ config = uvicorn.Config(
+ app=server_app, host="127.0.0.1", port=port, log_level="error"
+ )
+ server = uvicorn.Server(config)
+ thread = threading.Thread(target=server.run)
+ thread.start()
+ try:
+ yield
+ finally:
+ server.should_exit = True
+ thread.join()
+
+
+def oauth2_login() -> Tuple[str, str]:
+ """Login with SSO."""
+ injector = get_injector()
+ url_provider = injector.get(URLProvider)
+ authorization_url = url_provider.get_suite_uri("/login/cli")
+
+ access_token = None
+ refresh_token = None
+
+ @server_app.get("/sso")
+ async def callback(request: Request) -> HTMLResponse:
+ nonlocal access_token
+ nonlocal refresh_token
+
+ access_token = request.query_params.get("access_token")
+ refresh_token = request.query_params.get("refresh_token")
+
+ if not access_token:
+ raise HTTPException(
+ status_code=400, detail="Access token not found in cookies"
+ )
+
+ success_page = r"""
+
+
+
+
+ SSO Login Success
+
+
+
+
+
Authentication was successful
+
You can now close this window and return to CLI.
+
Redirecting to Friendli Documentation in 10 seconds.
+
+
+
+
+"""
+ return HTMLResponse(content=success_page, status_code=200)
+
+ typer.secho(
+ f"Opening browser for authentication: {authorization_url}", fg=typer.colors.BLUE
+ )
+
+ webbrowser.open(authorization_url)
+
+ with run_server(33333):
+ while access_token is None or refresh_token is None:
+ time.sleep(1)
+
+ return access_token, refresh_token
+
+
+def pwd_login(email: str, pwd: str) -> Tuple[str, str]:
+ """Login with email and password."""
+ client = LoginClient()
+ return client.login(email, pwd)
diff --git a/friendli/cli/main.py b/friendli/cli/main.py
index b52e23c9..949b3a9a 100644
--- a/friendli/cli/main.py
+++ b/friendli/cli/main.py
@@ -6,26 +6,18 @@
from __future__ import annotations
-import requests
import typer
-from requests import HTTPError, Response
-from friendli.auth import TokenType, clear_tokens, get_token, update_token
+import friendli
+from friendli.auth import TokenType, clear_tokens, update_token
from friendli.cli import api, checkpoint
-from friendli.client.project import ProjectClient
-from friendli.client.user import UserClient, UserGroupClient, UserMFAClient
-from friendli.context import (
- get_current_project_id,
- project_context_path,
- set_current_group_id,
-)
-from friendli.di.injector import get_injector
+from friendli.cli.login import oauth2_login, pwd_login
+from friendli.client.user import UserClient
from friendli.errors import AuthTokenNotFoundError
from friendli.formatter import PanelFormatter
+from friendli.graphql.user import UserGqlClient
from friendli.utils.decorator import check_api
from friendli.utils.format import secho_error_and_exit
-from friendli.utils.request import DEFAULT_REQ_TIMEOUT
-from friendli.utils.url import URLProvider
from friendli.utils.version import get_installed_version
app = typer.Typer(
@@ -60,7 +52,7 @@
def whoami():
"""Show my user info."""
try:
- client = UserClient()
+ client = UserGqlClient()
info = client.get_current_user_info()
except AuthTokenNotFoundError as exc:
secho_error_and_exit(str(exc))
@@ -71,55 +63,27 @@ def whoami():
# @app.command()
@check_api
def login(
- email: str = typer.Option(..., prompt="Enter your email"),
- password: str = typer.Option(..., prompt="Enter your password", hide_input=True),
+ use_sso: bool = typer.Option(False, "--sso", help="Use SSO login."),
):
- """Sign in."""
- injector = get_injector()
- url_provider = injector.get(URLProvider)
- r = requests.post(
- url_provider.get_web_backend_uri("/api/auth/cli/access_token"),
- json={"username": email, "password": password},
- timeout=DEFAULT_REQ_TIMEOUT,
- )
- try:
- resp = r.json()
- except requests.exceptions.JSONDecodeError:
- if r.status_code != 200:
- secho_error_and_exit(r.content.decode())
- secho_error_and_exit("Invalid response format.")
-
- if "code" in resp and resp["code"] == "mfa_required":
- mfa_token = resp["mfaToken"]
- client = UserMFAClient()
- # TODO: MFA type currently defaults to totp, need changes when new options are added
- client.initiate_mfa(mfa_type="totp", mfa_token=mfa_token)
- update_token(token_type=TokenType.MFA, token=mfa_token)
- typer.run(_mfa_verify)
+ """Sign in Friendli."""
+ if friendli.token:
+ typer.secho(
+ "You've already set the 'FRIENDLI_TOKEN' environment variable for "
+ "authentication, which takes precedence over the login session. Using both "
+ "methods of authentication simultaneously could lead to unexpected issues. "
+ "We suggest removing the 'FRIENDLI_TOKEN' environment variable if you "
+ "prefer to log in through the standard login session.",
+ fg=typer.colors.RED,
+ )
+
+ if use_sso:
+ access_token, refresh_token = oauth2_login()
else:
- _handle_login_response(r, False)
+ email = typer.prompt("Enter your email")
+ pwd = typer.prompt("Enter your password", hide_input=True)
+ access_token, refresh_token = pwd_login(email, pwd)
- # Save user's organiztion context
- project_client = ProjectClient()
- user_group_client = UserGroupClient()
-
- try:
- org = user_group_client.get_group_info()
- except IndexError:
- secho_error_and_exit("You are not included in any organization.")
- org_id = org["id"]
-
- project_id = get_current_project_id()
- if project_id is not None:
- if project_client.check_project_membership(pf_project_id=project_id):
- project_org_id = project_client.get_project(pf_project_id=project_id)[
- "pf_group_id"
- ]
- if project_org_id != org_id:
- project_context_path.unlink(missing_ok=True)
- else:
- project_context_path.unlink(missing_ok=True)
- set_current_group_id(org_id)
+ _display_login_success(access_token, refresh_token)
# @app.command()
@@ -163,42 +127,15 @@ def version():
typer.echo(installed_version)
-def _mfa_verify(_, code: str = typer.Option(..., prompt="Enter MFA Code")):
- injector = get_injector()
- url_provider = injector.get(URLProvider)
-
- mfa_token = get_token(TokenType.MFA)
- # TODO: MFA type currently defaults to totp, need changes when new options are added
- mfa_type = "totp"
- username = f"mfa://{mfa_type}/{mfa_token}"
- r = requests.post(
- url_provider.get_web_backend_uri("/api/auth/cli/access_token"),
- json={"username": username, "password": code},
- timeout=DEFAULT_REQ_TIMEOUT,
- )
- _handle_login_response(r, True)
-
-
-def _handle_login_response(r: Response, mfa: bool):
- try:
- r.raise_for_status()
- update_token(token_type=TokenType.ACCESS, token=r.json()["accessToken"])
- update_token(token_type=TokenType.REFRESH, token=r.json()["refreshToken"])
+def _display_login_success(access_token: str, refresh_token: str):
+ update_token(token_type=TokenType.ACCESS, token=access_token)
+ update_token(token_type=TokenType.REFRESH, token=refresh_token)
- typer.echo("\n\nLogin success!")
- typer.echo("Welcome back to...")
- typography = r"""
+ typography = r"""
_____ _ _
| ___|_ _(_) ___ _ __ _| || |(_)
| |__ | '__| |/ _ \| '__ \/ _ || || |
| __|| | | | __/| | | | (_) || || |
|_| |_| |_|\___||_| |_|\___/ |_||_|
"""
- typer.secho(typography, fg=typer.colors.BLUE)
- except HTTPError:
- if mfa:
- secho_error_and_exit("Login failed... Invalid MFA Code.")
- else:
- secho_error_and_exit(
- "Login failed... Please check your email and password."
- )
+ typer.secho(f"\nLOGIN SUCCESS!\n{typography}", fg=typer.colors.BLUE)
diff --git a/friendli/client/base.py b/friendli/client/base.py
index e3d1c267..4b639f44 100644
--- a/friendli/client/base.py
+++ b/friendli/client/base.py
@@ -151,14 +151,14 @@ class Client(ABC, Generic[T], RequestInterface):
def __init__(self, **kwargs):
"""Initialize client."""
- injector = get_injector()
- self.url_provider = injector.get(URLProvider)
- self.url_template = URLTemplate(self.url_path)
+ self.injector = get_injector()
+ self.url_provider = self.injector.get(URLProvider)
+ self.url_template = URLTemplate(Template(self.url_path))
self.url_kwargs = kwargs
@property
@abstractmethod
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""URL path template to render."""
@property
diff --git a/friendli/client/catalog.py b/friendli/client/catalog.py
index 05939870..918b0dd1 100644
--- a/friendli/client/catalog.py
+++ b/friendli/client/catalog.py
@@ -4,7 +4,6 @@
from __future__ import annotations
-from string import Template
from typing import Any, Dict, List, Optional
from uuid import UUID
@@ -30,9 +29,9 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_mr_uri("catalogs/"))
+ return self.url_provider.get_mr_uri("catalogs/")
def get_catalog(self, catalog_id: UUID) -> Dict[str, Any]:
"""Get a public checkpoint in catalog."""
diff --git a/friendli/client/checkpoint.py b/friendli/client/checkpoint.py
index 143ae026..20d41abd 100644
--- a/friendli/client/checkpoint.py
+++ b/friendli/client/checkpoint.py
@@ -6,7 +6,6 @@
from __future__ import annotations
-from string import Template
from typing import Any, Dict, List, Optional
from uuid import UUID
@@ -24,9 +23,9 @@ class CheckpointClient(Client[UUID]):
"""Checkpoint client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_mr_uri("models/"))
+ return self.url_provider.get_mr_uri("models/")
def get_checkpoint(self, checkpoint_id: UUID) -> Dict[str, Any]:
"""Get a checkpoint info."""
@@ -57,9 +56,9 @@ class CheckpointFormClient(UploadableClient[UUID]):
"""Checkpoint form client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_mr_uri("model_forms/"))
+ return self.url_provider.get_mr_uri("model_forms/")
def update_checkpoint_files(
self,
@@ -89,11 +88,9 @@ def __init__(self, **kwargs):
super().__init__(group_id=self.group_id, project_id=self.project_id, **kwargs)
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_mr_uri("orgs/$group_id/prjs/$project_id/models/")
- )
+ return self.url_provider.get_mr_uri("orgs/$group_id/prjs/$project_id/models/")
def list_checkpoints(
self, category: Optional[CheckpointCategory], limit: int, deleted: bool
diff --git a/friendli/client/credential.py b/friendli/client/credential.py
index cc08257e..ea6e7f66 100644
--- a/friendli/client/credential.py
+++ b/friendli/client/credential.py
@@ -5,7 +5,6 @@
from __future__ import annotations
-from string import Template
from typing import Any, Dict, Optional
from uuid import UUID
@@ -18,9 +17,9 @@ class CredentialClient(Client[UUID]):
"""Credential client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_auth_uri("credential"))
+ return self.url_provider.get_auth_uri("credential")
def get_credential(self, credential_id: UUID) -> Dict[str, Any]:
"""Get a credential info."""
@@ -62,11 +61,9 @@ class CredentialTypeClient(Client):
"""Credential type client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_training_uri("credential_type/")
- ) # TODO: move this out of the training API
+ return self.url_provider.get_training_uri("credential_type/")
def get_schema_by_type(self, cred_type: CredType) -> Optional[Dict[str, Any]]:
"""Get a credential JSON schema."""
diff --git a/friendli/client/deployment.py b/friendli/client/deployment.py
index 70d8c351..211b30d3 100644
--- a/friendli/client/deployment.py
+++ b/friendli/client/deployment.py
@@ -5,7 +5,6 @@
from __future__ import annotations
from datetime import datetime
-from string import Template
from typing import Any, Dict, List, Optional
from friendli.client.base import Client, ProjectRequestMixin
@@ -15,9 +14,9 @@ class DeploymentClient(Client[str]):
"""Deployment client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_serving_uri("deployment/"))
+ return self.url_provider.get_serving_uri("deployment/")
def get_deployment(self, deployment_id: str) -> Dict[str, Any]:
"""Get a deployment info."""
@@ -78,11 +77,9 @@ class DeploymentLogClient(Client[str]):
"""Deployment log client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_serving_uri("deployment/$deployment_id/log/")
- )
+ return self.url_provider.get_serving_uri("deployment/$deployment_id/log/")
def get_deployment_logs(self, replica_index: int) -> List[Dict[str, Any]]:
"""Get logs from a deployment."""
@@ -97,11 +94,9 @@ class DeploymentMetricsClient(Client):
"""Deployment metrics client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_serving_uri("deployment/$deployment_id/metrics/")
- )
+ return self.url_provider.get_serving_uri("deployment/$deployment_id/metrics/")
def get_metrics(
self, start: datetime, end: datetime, time_window: int
@@ -123,11 +118,9 @@ class DeploymentEventClient(Client):
"""Deployment event client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_serving_uri("deployment/$deployment_id/event/")
- )
+ return self.url_provider.get_serving_uri("deployment/$deployment_id/event/")
def get_events(self) -> List[Dict[str, Any]]:
"""Get deployment events."""
@@ -141,12 +134,10 @@ class DeploymentReqRespClient(Client):
"""Deployment request-response client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_serving_uri(
- "deployment/$deployment_id/req_resp/download/"
- )
+ return self.url_provider.get_serving_uri(
+ "deployment/$deployment_id/req_resp/download/"
)
def get_download_urls(self, start: datetime, end: datetime) -> List[Dict[str, str]]:
@@ -171,11 +162,9 @@ def __init__(self, **kwargs):
super().__init__(project_id=self.project_id, **kwargs)
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_serving_uri("usage/project/$project_id/duration")
- )
+ return self.url_provider.get_serving_uri("usage/project/$project_id/duration")
def get_project_deployment_durations(
self,
@@ -198,9 +187,9 @@ class PFSVMClient(Client):
"""VM client for serving."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_serving_uri("vm/"))
+ return self.url_provider.get_serving_uri("vm/")
def list_vms(self) -> List[Dict[str, Any]]:
"""List all VM info."""
diff --git a/friendli/client/file.py b/friendli/client/file.py
index eb2e5ae3..11b1273a 100644
--- a/friendli/client/file.py
+++ b/friendli/client/file.py
@@ -4,7 +4,6 @@
from __future__ import annotations
-from string import Template
from typing import Any, Dict
from uuid import UUID
@@ -20,9 +19,9 @@ class FileClient(Client[UUID]):
"""File client service."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_mr_uri("files/"))
+ return self.url_provider.get_mr_uri("files/")
def get_misc_file_upload_url(self, misc_file_id: UUID) -> str:
"""Get an URL to upload file.
@@ -77,11 +76,9 @@ def __init__(self, **kwargs):
super().__init__(group_id=self.group_id, project_id=self.project_id, **kwargs)
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_mr_uri("orgs/$group_id/prjs/$project_id/files/")
- )
+ return self.url_provider.get_mr_uri("orgs/$group_id/prjs/$project_id/files/")
def create_misc_file(self, file_info: Dict[str, Any]) -> Dict[str, Any]:
"""Request to create a misc file.
diff --git a/friendli/client/group.py b/friendli/client/group.py
index 844baf82..5f0517b7 100644
--- a/friendli/client/group.py
+++ b/friendli/client/group.py
@@ -8,7 +8,6 @@
import json
import uuid
-from string import Template
from typing import Any, Dict, List
from friendli.client.base import Client, GroupRequestMixin
@@ -18,9 +17,9 @@ class GroupClient(Client):
"""Organization client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_auth_uri("pf_group"))
+ return self.url_provider.get_auth_uri("pf_group")
def create_group(self, name: str) -> Dict[str, Any]:
"""Create a new organization."""
@@ -62,11 +61,9 @@ def __init__(self, **kwargs):
super().__init__(pf_group_id=self.group_id, **kwargs)
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_auth_uri("pf_group/$pf_group_id/pf_project")
- )
+ return self.url_provider.get_auth_uri("pf_group/$pf_group_id/pf_project")
def create_project(self, name: str) -> Dict[str, Any]:
"""Create a new project in the organization."""
diff --git a/friendli/client/login.py b/friendli/client/login.py
new file mode 100644
index 00000000..ab40ea4f
--- /dev/null
+++ b/friendli/client/login.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved.
+
+"""Login Client."""
+
+from __future__ import annotations
+
+from typing import Tuple
+
+from friendli.client.base import Client
+from friendli.settings import Settings
+
+
+class LoginClient(Client):
+ """Login client."""
+
+ @property
+ def url_path(self) -> str:
+ """Get an URL path."""
+ return self.url_provider.get_web_backend_uri("/api/auth/login")
+
+ def login(self, email: str, pwd: str) -> Tuple[str, str]:
+ """Send request to sign in with email and password."""
+ settings = self.injector.get(Settings)
+ payload = {
+ "email": email,
+ "password": pwd,
+ }
+ headers = {"Accept": "application/json"}
+ resp = self.bare_post(json=payload, headers=headers)
+ cookies = resp.cookies
+ access_token = cookies[settings.access_token_cookie_key]
+ refresh_token = cookies[settings.refresh_token_cookie_key]
+ return access_token, refresh_token
diff --git a/friendli/client/project.py b/friendli/client/project.py
index 4d585248..e7e5e0ed 100644
--- a/friendli/client/project.py
+++ b/friendli/client/project.py
@@ -6,7 +6,6 @@
from __future__ import annotations
-from string import Template
from typing import Any, Dict, List, Optional
from uuid import UUID
@@ -30,9 +29,9 @@ class ProjectClient(Client[UUID]):
"""Project client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_auth_uri("pf_project"))
+ return self.url_provider.get_auth_uri("pf_project")
def get_project(self, pf_project_id: UUID) -> Dict[str, Any]:
"""Get project info."""
@@ -69,11 +68,9 @@ def __init__(self, **kwargs):
super().__init__(project_id=self.project_id, **kwargs)
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_auth_uri("pf_project/$project_id/credential")
- )
+ return self.url_provider.get_auth_uri("pf_project/$project_id/credential")
def list_credentials(
self, cred_type: Optional[CredType] = None
diff --git a/friendli/client/user.py b/friendli/client/user.py
index 818fc81f..bd96b11b 100644
--- a/friendli/client/user.py
+++ b/friendli/client/user.py
@@ -4,7 +4,6 @@
from __future__ import annotations
-from string import Template
from typing import Any, Dict, List
from uuid import UUID
@@ -16,9 +15,9 @@ class UserMFAClient(Client):
"""User MFA client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_auth_uri("mfa"))
+ return self.url_provider.get_auth_uri("mfa")
def initiate_mfa(self, mfa_type: str, mfa_token: str) -> None:
"""Authenticate by MFA token."""
@@ -29,9 +28,9 @@ class UserSignUpClient(Client):
"""User sign-up client."""
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_auth_uri("pf_user/self_signup"))
+ return self.url_provider.get_auth_uri("pf_user/self_signup")
def verify(self, token: str, key: str) -> None:
"""Verify the email account with the token to sign up."""
@@ -47,9 +46,9 @@ def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_auth_uri("pf_user"))
+ return self.url_provider.get_auth_uri("pf_user")
def change_password(self, old_password: str, new_password: str) -> None:
"""Change password."""
@@ -114,9 +113,9 @@ def __init__(self, **kwargs):
super().__init__(pf_user_id=self.user_id, **kwargs)
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_auth_uri("pf_user/$pf_user_id/pf_group"))
+ return self.url_provider.get_auth_uri("pf_user/$pf_user_id/pf_group")
def get_group_info(self) -> Dict[str, Any]:
"""Get organization info where user belongs to."""
@@ -134,12 +133,10 @@ def __init__(self, **kwargs):
super().__init__(pf_user_id=self.user_id, pf_group_id=self.group_id, **kwargs)
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(
- self.url_provider.get_auth_uri(
- "pf_user/$pf_user_id/pf_group/$pf_group_id/pf_project"
- )
+ return self.url_provider.get_auth_uri(
+ "pf_user/$pf_user_id/pf_group/$pf_group_id/pf_project"
)
def list_projects(self) -> List[Dict[str, Any]]:
@@ -157,9 +154,9 @@ def __init__(self, **kwargs):
super().__init__(pf_user_id=self.user_id, **kwargs)
@property
- def url_path(self) -> Template:
+ def url_path(self) -> str:
"""Get an URL path."""
- return Template(self.url_provider.get_auth_uri("pf_user"))
+ return self.url_provider.get_auth_uri("pf_user")
def create_access_key(self, name: str) -> Dict[str, Any]:
"""Create a new access key."""
diff --git a/friendli/di/modules.py b/friendli/di/modules.py
index 7a045a4e..14c57866 100644
--- a/friendli/di/modules.py
+++ b/friendli/di/modules.py
@@ -6,15 +6,17 @@
from injector import Binder, Module
+from friendli import settings
from friendli.utils import url
-class URLModule(Module):
+class SettingsModule(Module):
"""Friendli client module."""
def configure(self, binder: Binder) -> None:
"""Configures bindings for clients."""
binder.bind(url.URLProvider, to=url.ProductionURLProvider) # type: ignore
+ binder.bind(settings.Settings, to=settings.ProductionSettings) # type: ignore
-default_modules = [URLModule]
+default_modules = [SettingsModule]
diff --git a/friendli/graphql/__init__.py b/friendli/graphql/__init__.py
new file mode 100644
index 00000000..1b0661fd
--- /dev/null
+++ b/friendli/graphql/__init__.py
@@ -0,0 +1,3 @@
+# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved.
+
+"""Friendli graphql clients to interact with Friendli system."""
diff --git a/friendli/graphql/base.py b/friendli/graphql/base.py
new file mode 100644
index 00000000..52abbc55
--- /dev/null
+++ b/friendli/graphql/base.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved.
+
+"""Friendli GQL Client Service."""
+
+from __future__ import annotations
+
+from typing import Any, Dict, Optional
+
+from friendli.client.base import Client
+
+
+class GqlClient(Client):
+ """Base interface of graphql client to Friendli system."""
+
+ @property
+ def url_path(self) -> str:
+ """URL path template to render."""
+ return self.url_provider.get_web_backend_uri("api/graphql")
+
+ def run(
+ self, query: str, variables: Optional[Dict[str, Any]] = None
+ ) -> Dict[str, Any]:
+ """Run graphql."""
+ return self.post(
+ json={
+ "query": query,
+ "variables": variables,
+ }
+ )["data"]
diff --git a/friendli/graphql/user.py b/friendli/graphql/user.py
new file mode 100644
index 00000000..f23bcf18
--- /dev/null
+++ b/friendli/graphql/user.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2022-present, FriendliAI Inc. All rights reserved.
+
+"""Friendli User GQL Clients."""
+
+from __future__ import annotations
+
+from typing import Any, Dict
+
+from friendli.graphql.base import GqlClient
+
+CurrUserInfoGql = """
+query GetclientSession {
+ clientSession {
+ user {
+ id
+ name
+ email
+ }
+ }
+}
+"""
+
+
+class UserGqlClient(GqlClient):
+ """User gql client."""
+
+ def get_current_user_info(self) -> Dict[str, Any]:
+ """Get current user info."""
+ response = self.run(query=CurrUserInfoGql)
+ return response["clientSession"]["user"]
diff --git a/friendli/logging.py b/friendli/logging.py
index e1359bb9..b1d27e65 100644
--- a/friendli/logging.py
+++ b/friendli/logging.py
@@ -7,10 +7,42 @@
import logging
import os
-_formatter = logging.Formatter(
- fmt="%(asctime)s.%(msecs)05d: %(name)s %(levelname)s: %(message)s",
- datefmt="%Y-%m-%d %H:%M:%S",
-)
+_formatter = logging.Formatter()
+
+
+class ColorFormatter(logging.Formatter):
+ """Customized formatter with ANSI color."""
+
+ grey = "\x1b[38;20m"
+ yellow = "\x1b[33;20m"
+ red = "\x1b[31;20m"
+ bold_red = "\x1b[31;1m"
+ reset = "\x1b[0m"
+
+ default_fmt = "%(asctime)s.%(msecs)05d: %(name)s %(levelname)s: %(message)s"
+ default_datefmt = "%Y-%m-%d %H:%M:%S"
+
+ FORMATS = {
+ logging.DEBUG: grey + default_fmt + reset,
+ logging.INFO: grey + default_fmt + reset,
+ logging.WARNING: yellow + default_fmt + reset,
+ logging.ERROR: red + default_fmt + reset,
+ logging.CRITICAL: bold_red + default_fmt + reset,
+ }
+
+ def __init__(self):
+ """Initialize CustomFormatter."""
+ super().__init__(fmt=self.default_fmt, datefmt=self.default_datefmt)
+
+ # Pre-create Formatter objects for each level to improve efficiency
+ self.formatters = {
+ level: logging.Formatter(fmt) for level, fmt in self.FORMATS.items()
+ }
+
+ def format(self, record):
+ """Override format method."""
+ formatter = self.formatters.get(record.levelno, self.formatters[logging.INFO])
+ return formatter.format(record)
def get_logger(name: str) -> logging.Logger:
@@ -18,7 +50,8 @@ def get_logger(name: str) -> logging.Logger:
logger = logging.getLogger(name)
handler = logging.StreamHandler()
- handler.setFormatter(_formatter)
+ formatter = ColorFormatter()
+ handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(os.environ.get("FRIENDLI_LOG_LEVEL", "INFO"))
diff --git a/friendli/modules/converter/models/mixtral.py b/friendli/modules/converter/models/mixtral.py
index b8a82980..5cf5a366 100644
--- a/friendli/modules/converter/models/mixtral.py
+++ b/friendli/modules/converter/models/mixtral.py
@@ -160,7 +160,7 @@ def decoder_convert_info_list(
@property
def model_type(self) -> str:
"""Model type."""
- return "mistral"
+ return "mixtral"
@property
def decoder_layer_num(self) -> int:
diff --git a/friendli/settings.py b/friendli/settings.py
new file mode 100644
index 00000000..342ae30e
--- /dev/null
+++ b/friendli/settings.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved.
+
+"""CLI App Settings."""
+
+
+class Settings:
+ """CLI app settings."""
+
+ access_token_cookie_key = ""
+ refresh_token_cookie_key = ""
+
+
+class ProductionSettings:
+ """Production CLI app settings."""
+
+ access_token_cookie_key = "sAccessTokenProduction"
+ refresh_token_cookie_key = "sRefreshTokenProduction"
+
+
+class StagingSettings:
+ """Staging CLI app settings."""
+
+ access_token_cookie_key = "sAccessTokenStaging"
+ refresh_token_cookie_key = "sRefreshTokenStaging"
+
+
+class DevSettings:
+ """Dev CLI app settings."""
+
+ access_token_cookie_key = "sAccessTokenDev"
+ refresh_token_cookie_key = "sRefreshTokenDev"
diff --git a/friendli/utils/url.py b/friendli/utils/url.py
index 2419c4fe..2176c764 100644
--- a/friendli/utils/url.py
+++ b/friendli/utils/url.py
@@ -22,6 +22,8 @@ def get_host(url: str) -> str:
class URLProvider:
"""Service URL provider."""
+ suite_url = ""
+ api_url = ""
training_url = ""
registry_url = ""
serving_url = ""
@@ -30,6 +32,11 @@ class URLProvider:
observatory_url = ""
web_backend_url = ""
+ @classmethod
+ def get_suite_uri(cls, path: str) -> str:
+ """Get Friendli Suite URI."""
+ return urljoin(cls.suite_url, path)
+
@classmethod
def get_auth_uri(cls, path: str) -> str:
"""Get PFA URI."""
@@ -46,6 +53,11 @@ def get_training_uri(cls, path: str) -> str:
"""Get PFT URI."""
return urljoin(cls.training_url, path)
+ @classmethod
+ def get_api_uri(cls, path: str) -> str:
+ """Get PFT URI."""
+ return urljoin(cls.api_url, path)
+
@classmethod
def get_serving_uri(cls, path: str) -> str:
"""Get PFS URI."""
@@ -70,37 +82,37 @@ def get_observatory_uri(cls, path: str) -> str:
class ProductionURLProvider(URLProvider):
"""Production service URL provider."""
- training_url = "https://training.friendli.ai/api/"
- training_ws_url = "wss://training-ws.friendli.ai/ws/"
+ suite_url = "https://suite.friendli.ai/"
registry_url = "https://modelregistry.friendli.ai/"
serving_url = "https://serving.friendli.ai/"
auth_url = "https://auth.friendli.ai/"
meter_url = "https://metering.friendli.ai/"
observatory_url = "https://observatory.friendli.ai/"
- web_backend_url = "https://cloud.friendli.ai/"
+ web_backend_url = "https://suite.friendli.ai/"
+ training_url = "https://training.friendli.ai/api/"
class StagingURLProvider(URLProvider):
"""Staging service URL provider."""
- training_url = "https://api-staging.friendli.ai/api/"
- training_ws_url = "wss://api-ws-staging.friendli.ai/ws/"
+ suite_url = "https://suite-staging.friendli.ai/"
registry_url = "https://pfmodelregistry-staging.friendli.ai/"
serving_url = "https://pfs-staging.friendli.ai/"
auth_url = "https://pfauth-staging.friendli.ai/"
meter_url = "https://pfmeter-staging.friendli.ai/"
observatory_url = "https://pfo-staging.friendli.ai/"
web_backend_url = "https://api-staging.friendli.ai/"
+ training_url = "https://api-staging.friendli.ai/api/"
class DevURLProvider(URLProvider):
"""Dev service URL provider."""
- training_url = "https://api-dev.friendli.ai/api/"
- training_ws_url = "wss://api-ws-dev.friendli.ai/ws/"
+ suite_url = "https://suite-dev.friendli.ai/"
registry_url = "https://pfmodelregistry-dev.friendli.ai/"
serving_url = "https://pfs-dev.friendli.ai/"
auth_url = "https://pfauth-dev.friendli.ai/"
meter_url = "https://pfmeter-dev.friendli.ai/"
observatory_url = "https://pfo-dev.friendli.ai/"
web_backend_url = "https://api-dev.friendli.ai/"
+ training_url = "https://api-dev.friendli.ai/api/"
diff --git a/poetry.lock b/poetry.lock
index f2d63831..619148a8 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "accelerate"
@@ -1221,11 +1221,30 @@ files = [
[package.extras]
test = ["pytest (>=6)"]
+[[package]]
+name = "fastapi"
+version = "0.109.2"
+description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "fastapi-0.109.2-py3-none-any.whl", hash = "sha256:2c9bab24667293b501cad8dd388c05240c850b58ec5876ee3283c47d6e1e3a4d"},
+ {file = "fastapi-0.109.2.tar.gz", hash = "sha256:f3817eac96fe4f65a2ebb4baa000f394e55f5fccdaf7f75250804bc58f354f73"},
+]
+
+[package.dependencies]
+pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
+starlette = ">=0.36.3,<0.37.0"
+typing-extensions = ">=4.8.0"
+
+[package.extras]
+all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
+
[[package]]
name = "filelock"
version = "3.12.2"
description = "A platform independent file lock."
-optional = false
+optional = true
python-versions = ">=3.7"
files = [
{file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"},
@@ -2620,7 +2639,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
- {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -3016,6 +3034,24 @@ files = [
{file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"},
]
+[[package]]
+name = "starlette"
+version = "0.36.3"
+description = "The little ASGI library that shines."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "starlette-0.36.3-py3-none-any.whl", hash = "sha256:13d429aa93a61dc40bf503e8c801db1f1bca3dc706b10ef2434a36123568f044"},
+ {file = "starlette-0.36.3.tar.gz", hash = "sha256:90a671733cfb35771d8cc605e0b679d23b992f8dcfad48cc60b38cb29aeb7080"},
+]
+
+[package.dependencies]
+anyio = ">=3.4.0,<5"
+typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""}
+
+[package.extras]
+full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
+
[[package]]
name = "sympy"
version = "1.12"
@@ -3441,13 +3477,13 @@ files = [
[[package]]
name = "typing-extensions"
-version = "4.7.1"
-description = "Backported and Experimental Type Hints for Python 3.7+"
+version = "4.9.0"
+description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"},
- {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"},
+ {file = "typing_extensions-4.9.0-py3-none-any.whl", hash = "sha256:af72aea155e91adfc61c3ae9e0e342dbc0cba726d6cba4b6c72c1f34e47291cd"},
+ {file = "typing_extensions-4.9.0.tar.gz", hash = "sha256:23478f88c37f27d76ac8aee6c905017a143b0b1b886c3c9f66bc2fd94f9f5783"},
]
[[package]]
@@ -3478,62 +3514,24 @@ secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "p
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
[[package]]
-name = "websockets"
-version = "10.1"
-description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
+name = "uvicorn"
+version = "0.27.0.post1"
+description = "The lightning-fast ASGI server."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "websockets-10.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:38db6e2163b021642d0a43200ee2dec8f4980bdbda96db54fde72b283b54cbfc"},
- {file = "websockets-10.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e1b60fd297adb9fc78375778a5220da7f07bf54d2a33ac781319650413fc6a60"},
- {file = "websockets-10.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3477146d1f87ead8df0f27e8960249f5248dceb7c2741e8bbec9aa5338d0c053"},
- {file = "websockets-10.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb01ea7b5f52e7125bdc3c5807aeaa2d08a0553979cf2d96a8b7803ea33e15e7"},
- {file = "websockets-10.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9fd62c6dc83d5d35fb6a84ff82ec69df8f4657fff05f9cd6c7d9bec0dd57f0f6"},
- {file = "websockets-10.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:3bbf080f3892ba1dc8838786ec02899516a9d227abe14a80ef6fd17d4fb57127"},
- {file = "websockets-10.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5560558b0dace8312c46aa8915da977db02738ac8ecffbc61acfbfe103e10155"},
- {file = "websockets-10.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:667c41351a6d8a34b53857ceb8343a45c85d438ee4fd835c279591db8aeb85be"},
- {file = "websockets-10.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:468f0031fdbf4d643f89403a66383247eb82803430b14fa27ce2d44d2662ca37"},
- {file = "websockets-10.1-cp310-cp310-win32.whl", hash = "sha256:d0d81b46a5c87d443e40ce2272436da8e6092aa91f5fbeb60d1be9f11eff5b4c"},
- {file = "websockets-10.1-cp310-cp310-win_amd64.whl", hash = "sha256:b68b6caecb9a0c6db537aa79750d1b592a841e4f1a380c6196091e65b2ad35f9"},
- {file = "websockets-10.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a249139abc62ef333e9e85064c27fefb113b16ffc5686cefc315bdaef3eefbc8"},
- {file = "websockets-10.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8877861e3dee38c8d302eee0d5dbefa6663de3b46dc6a888f70cd7e82562d1f7"},
- {file = "websockets-10.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e3872ae57acd4306ecf937d36177854e218e999af410a05c17168cd99676c512"},
- {file = "websockets-10.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:b66e6d514f12c28d7a2d80bb2a48ef223342e99c449782d9831b0d29a9e88a17"},
- {file = "websockets-10.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:9f304a22ece735a3da8a51309bc2c010e23961a8f675fae46fdf62541ed62123"},
- {file = "websockets-10.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:189ed478395967d6a98bb293abf04e8815349e17456a0a15511f1088b6cb26e4"},
- {file = "websockets-10.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:08a42856158307e231b199671c4fce52df5786dd3d703f36b5d8ac76b206c485"},
- {file = "websockets-10.1-cp37-cp37m-win32.whl", hash = "sha256:3ef6f73854cded34e78390dbdf40dfdcf0b89b55c0e282468ef92646fce8d13a"},
- {file = "websockets-10.1-cp37-cp37m-win_amd64.whl", hash = "sha256:89e985d40d407545d5f5e2e58e1fdf19a22bd2d8cd54d20a882e29f97e930a0a"},
- {file = "websockets-10.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:002071169d2e44ce8eb9e5ebac9fbce142ba4b5146eef1cfb16b177a27662657"},
- {file = "websockets-10.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cfae282c2aa7f0c4be45df65c248481f3509f8c40ca8b15ed96c35668ae0ff69"},
- {file = "websockets-10.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:97b4b68a2ddaf5c4707ae79c110bfd874c5be3c6ac49261160fb243fa45d8bbb"},
- {file = "websockets-10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c9407719f42cb77049975410490c58a705da6af541adb64716573e550e5c9db"},
- {file = "websockets-10.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:1d858fb31e5ac992a2cdf17e874c95f8a5b1e917e1fb6b45ad85da30734b223f"},
- {file = "websockets-10.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7bdd3d26315db0a9cf8a0af30ca95e0aa342eda9c1377b722e71ccd86bc5d1dd"},
- {file = "websockets-10.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e259be0863770cb91b1a6ccf6907f1ac2f07eff0b7f01c249ed751865a70cb0d"},
- {file = "websockets-10.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6b014875fae19577a392372075e937ebfebf53fd57f613df07b35ab210f31534"},
- {file = "websockets-10.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:98de71f86bdb29430fd7ba9997f47a6b10866800e3ea577598a786a785701bb0"},
- {file = "websockets-10.1-cp38-cp38-win32.whl", hash = "sha256:3a02ab91d84d9056a9ee833c254895421a6333d7ae7fff94b5c68e4fa8095519"},
- {file = "websockets-10.1-cp38-cp38-win_amd64.whl", hash = "sha256:7d6673b2753f9c5377868a53445d0c321ef41ff3c8e3b6d57868e72054bfce5f"},
- {file = "websockets-10.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ddab2dc69ee5ae27c74dbfe9d7bb6fee260826c136dca257faa1a41d1db61a89"},
- {file = "websockets-10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:14e9cf68a08d1a5d42109549201aefba473b1d925d233ae19035c876dd845da9"},
- {file = "websockets-10.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e4819c6fb4f336fd5388372cb556b1f3a165f3f68e66913d1a2fc1de55dc6f58"},
- {file = "websockets-10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05e7f098c76b0a4743716590bb8f9706de19f1ef5148d61d0cf76495ec3edb9c"},
- {file = "websockets-10.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5bb6256de5a4fb1d42b3747b4e2268706c92965d75d0425be97186615bf2f24f"},
- {file = "websockets-10.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:888a5fa2a677e0c2b944f9826c756475980f1b276b6302e606f5c4ff5635be9e"},
- {file = "websockets-10.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6fdec1a0b3e5630c58e3d8704d2011c678929fce90b40908c97dfc47de8dca72"},
- {file = "websockets-10.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:531d8eb013a9bc6b3ad101588182aa9b6dd994b190c56df07f0d84a02b85d530"},
- {file = "websockets-10.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0d93b7cadc761347d98da12ec1930b5c71b2096f1ceed213973e3cda23fead9c"},
- {file = "websockets-10.1-cp39-cp39-win32.whl", hash = "sha256:d9b245db5a7e64c95816e27d72830e51411c4609c05673d1ae81eb5d23b0be54"},
- {file = "websockets-10.1-cp39-cp39-win_amd64.whl", hash = "sha256:882c0b8bdff3bf1bd7f024ce17c6b8006042ec4cceba95cf15df57e57efa471c"},
- {file = "websockets-10.1-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:10edd9d7d3581cfb9ff544ac09fc98cab7ee8f26778a5a8b2d5fd4b0684c5ba5"},
- {file = "websockets-10.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:baa83174390c0ff4fc1304fbe24393843ac7a08fdd59295759c4b439e06b1536"},
- {file = "websockets-10.1-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:483edee5abed738a0b6a908025be47f33634c2ad8e737edd03ffa895bd600909"},
- {file = "websockets-10.1-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:816ae7dac2c6522cfa620947ead0ca95ac654916eebf515c94d7c28de5601a6e"},
- {file = "websockets-10.1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:1dafe98698ece09b8ccba81b910643ff37198e43521d977be76caf37709cf62b"},
- {file = "websockets-10.1.tar.gz", hash = "sha256:181d2b25de5a437b36aefedaf006ecb6fa3aa1328ec0236cdde15f32f9d3ff6d"},
+ {file = "uvicorn-0.27.0.post1-py3-none-any.whl", hash = "sha256:4b85ba02b8a20429b9b205d015cbeb788a12da527f731811b643fd739ef90d5f"},
+ {file = "uvicorn-0.27.0.post1.tar.gz", hash = "sha256:54898fcd80c13ff1cd28bf77b04ec9dbd8ff60c5259b499b4b12bb0917f22907"},
]
+[package.dependencies]
+click = ">=7.0"
+h11 = ">=0.8"
+typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""}
+
+[package.extras]
+standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
+
[[package]]
name = "wrapt"
version = "1.15.0"
@@ -3820,4 +3818,4 @@ mllib = ["accelerate", "datasets", "einops", "h5py", "peft", "transformers"]
[metadata]
lock-version = "2.0"
python-versions = "^3.8"
-content-hash = "041119a17bb7b489082c252360d9cf2bc11e1e7103ba60c07afeb5f6dd4c0a4d"
+content-hash = "4a2eb692717176c07248be4eb0ae543653edb1a0c0d4da1e35213b79c1dec6cb"
diff --git a/pyproject.toml b/pyproject.toml
index e6b6374d..903868bd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "friendli-client"
-version = "1.2.0"
+version = "1.2.1"
description = "Client of Friendli Suite."
license = "Apache-2.0"
authors = ["FriendliAI teams "]
@@ -26,57 +26,55 @@ priority = "primary"
[tool.poetry.dependencies]
python = "^3.8"
-filelock = "3.12.2"
-requests = "2.31.0"
-websockets = "10.1"
-PyYaml = "6.0.1"
-typer = "0.9.0"
-rich = "12.2.0"
-jsonschema = "4.17.3"
-boto3 = "1.22.8"
-botocore = "1.25.8"
-tqdm = "4.64.0"
-azure-mgmt-storage = "20.1.0"
-azure-storage-blob = "12.12.0"
-packaging = "23.1"
-pathspec = "0.9.0"
-boto3-stubs = "1.26.90"
-mypy-boto3-s3 = "1.26.163"
-ruamel-yaml = "0.17.32"
-pydantic = {extras = ["email"], version = "2.0.2"}
+requests = "^2.31.0"
+PyYaml = "^6.0.1"
+typer = "^0.9.0"
+rich = "^12.2.0"
+jsonschema = "^4.17.3"
+boto3 = "^1.22.8"
+botocore = "^1.25.8"
+tqdm = "^4.64.0"
+azure-mgmt-storage = "^20.1.0"
+azure-storage-blob = "^12.12.0"
+pathspec = "^0.9.0"
+boto3-stubs = "^1.26.90"
+mypy-boto3-s3 = "^1.26.163"
+ruamel-yaml = "^0.17.32"
+pydantic = {extras = ["email"], version = "^2.0.2"}
transformers = { version = "4.36.2", optional = true }
-h5py = { version = "3.9.0", optional = true }
-einops = { version = "0.6.1", optional = true }
+h5py = { version = "^3.9.0", optional = true }
+einops = { version = "^0.6.1", optional = true }
accelerate = { version = "0.21.0", optional = true }
datasets = { version = "2.16.0", optional = true }
-injector = "0.21.0"
-protobuf = "4.24.2"
-types-protobuf = "4.24.0.1"
+injector = "^0.21.0"
+protobuf = "^4.24.2"
+types-protobuf = "^4.24.0.1"
peft = { version = "0.6.0", optional = true }
-httpx = "0.26.0"
+httpx = "^0.26.0"
+fastapi = "^0.109.2"
+uvicorn = "^0.27.0.post1"
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
-typer = "0.9.0"
-pytest = "7.4.0"
-coverage = "7.2.7"
-pytest-asyncio = "0.15.1"
-pytest-cov = "4.1.0"
-requests-mock = "1.11.0"
-black = "23.3.0"
-isort = "5.12.0"
-mypy = "1.4.1"
-pydocstyle = "6.3.0"
-pylint = "2.17.4"
-toml = "0.10.2"
-types-pyyaml = "6.0.12.10"
-types-jsonschema = "4.17.0.8"
-types-python-dateutil = "2.8.19.13"
-types-requests = "2.31.0.1"
-types-toml = "0.10.8.6"
-types-tqdm = "4.65.0.1"
+pytest = "^7.4.0"
+coverage = "^7.2.7"
+pytest-asyncio = "^0.15.1"
+pytest-cov = "^4.1.0"
+requests-mock = "^1.11.0"
+black = "^23.3.0"
+isort = "^5.12.0"
+mypy = "^1.4.1"
+pydocstyle = "^6.3.0"
+pylint = "^2.17.4"
+toml = "^0.10.2"
+types-pyyaml = "^6.0.12.10"
+types-jsonschema = "^4.17.0.8"
+types-python-dateutil = "^2.8.19.13"
+types-requests = "^2.31.0.1"
+types-toml = "^0.10.8.6"
+types-tqdm = "^4.65.0.1"
[tool.poetry.extras]
mllib = ["transformers", "h5py", "accelerate", "einops", "datasets", "peft"]
diff --git a/tests/unit_tests/client/test_base.py b/tests/unit_tests/client/test_base.py
index 69a2a872..de171594 100644
--- a/tests/unit_tests/client/test_base.py
+++ b/tests/unit_tests/client/test_base.py
@@ -85,8 +85,8 @@ def test_client_service_base(
class TestClient(Client[int]):
@property
- def url_path(self) -> Template:
- return Template(url_pattern)
+ def url_path(self) -> str:
+ return url_pattern
client = TestClient(test_id=1)