Skip to content

Commit

Permalink
Update patch version v1.2.3
Browse files Browse the repository at this point in the history
  • Loading branch information
kooyunmo committed Feb 16, 2024
1 parent dc05114 commit 88d61c8
Show file tree
Hide file tree
Showing 20 changed files with 119 additions and 53 deletions.
5 changes: 3 additions & 2 deletions friendli/cli/api/chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from friendli.schema.api.v1.chat.completions import MessageParam
from friendli.sdk.client import Friendli
from friendli.utils.compat import model_dump
from friendli.utils.decorator import check_api
from friendli.utils.format import secho_error_and_exit

Expand Down Expand Up @@ -136,7 +137,7 @@ def create(
)
for chunk in stream:
if n is not None and n > 1:
typer.echo(chunk.model_dump())
typer.echo(model_dump(chunk))
else:
typer.echo(chunk.choices[0].delta.content or "", nl=False)
else:
Expand All @@ -152,7 +153,7 @@ def create(
temperature=temperature,
top_p=top_p,
)
typer.echo(chat_completion.model_dump())
typer.echo(model_dump(chat_completion))


def _prepare_messages(messages: List[str]) -> List[MessageParam]:
Expand Down
5 changes: 3 additions & 2 deletions friendli/cli/api/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import typer

from friendli.sdk.client import Friendli
from friendli.utils.compat import model_dump
from friendli.utils.decorator import check_api

app = typer.Typer(
Expand Down Expand Up @@ -129,7 +130,7 @@ def create(
)
for chunk in stream:
if n is not None and n > 1:
typer.echo(chunk.model_dump())
typer.echo(model_dump(chunk))
else:
typer.echo(chunk.text, nl=False)
else:
Expand All @@ -145,4 +146,4 @@ def create(
temperature=temperature,
top_p=top_p,
)
typer.echo(completion.model_dump())
typer.echo(model_dump(completion))
3 changes: 2 additions & 1 deletion friendli/cli/api/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from friendli.enums import ResponseFormat
from friendli.sdk.client import Friendli
from friendli.utils.compat import model_dump
from friendli.utils.decorator import check_api

app = typer.Typer(
Expand Down Expand Up @@ -106,4 +107,4 @@ def create(
seed=seed,
response_format=response_format,
)
typer.echo(image.model_dump())
typer.echo(model_dump(image))
5 changes: 3 additions & 2 deletions friendli/cli/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
TableFormatter,
TreeFormatter,
)
from friendli.utils.compat import model_parse
from friendli.utils.format import secho_error_and_exit

app = typer.Typer(
Expand Down Expand Up @@ -236,8 +237,8 @@ def convert(
quant_config_dict = cast(dict, yaml.safe_load(quant_config_file.read()))
except yaml.YAMLError as err:
secho_error_and_exit(f"Failed to load the quant config file: {err}")
quant_config = QuantConfig.model_validate(
{"config": quant_config_dict}
quant_config = model_parse(
QuantConfig, {"config": quant_config_dict}
).config
else:
quant_config = AWQConfig()
Expand Down
7 changes: 4 additions & 3 deletions friendli/cli/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from friendli.formatter import PanelFormatter, TableFormatter
from friendli.schema.config.deployment import DeploymentConfig
from friendli.sdk.client import Friendli
from friendli.utils.compat import model_dump, model_parse
from friendli.utils.decorator import check_api

app = typer.Typer(
Expand Down Expand Up @@ -105,7 +106,7 @@ def create(
"""Creates a deployment object by using model checkpoint."""
client = Friendli()
config_dict = yaml.safe_load(config_file)
config = DeploymentConfig.model_validate(config_dict)
config = model_parse(DeploymentConfig, config_dict)

deployment = client.deployment.create(
name=config.name,
Expand All @@ -115,7 +116,7 @@ def create(
adapter_eids=config.adapters,
launch_config=config.launch_config,
)
deployment_panel.render(deployment.model_dump())
deployment_panel.render(model_dump(deployment))


@app.command("list")
Expand All @@ -124,5 +125,5 @@ def list_deployments():
"""List deployments."""
client = Friendli()
deployments = client.deployment.list()
deployments_ = [deployment.model_dump() for deployment in iter(deployments)]
deployments_ = [model_dump(deployment) for deployment in iter(deployments)]
deployment_table.render(deployments_)
13 changes: 9 additions & 4 deletions friendli/schema/api/v1/images/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from typing import List, Literal, Union

from pydantic import AnyHttpUrl, Base64Bytes, BaseModel, Field
from typing_extensions import TypeAlias
from pydantic import AnyHttpUrl, BaseModel, Field
from typing_extensions import Annotated, TypeAlias

ImageResponseFormatParam: TypeAlias = Union[str, Literal["url", "png", "jpeg", "raw"]]

Expand All @@ -25,10 +25,15 @@ class ImageDataB64(BaseModel):

format: Literal["png", "jpeg", "raw"]
seed: int
b64_json: Base64Bytes
b64_json: str


_ImageData = Annotated[
Union[ImageDataUrl, ImageDataB64], Field(..., discriminator="format")
]


class Image(BaseModel):
"""Image data."""

data: List[Union[ImageDataUrl, ImageDataB64]] = Field(..., discriminator="format")
data: List[_ImageData]
6 changes: 5 additions & 1 deletion friendli/schema/resource/v1/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
from typing_extensions import Annotated

from friendli.enums import CheckpointDataType, QuantMode
from friendli.utils.compat import PYDANTIC_V2


class V1CommonAttributes(BaseModel):
"""V1 checkpoint attributes schema."""

model_config = ConfigDict(protected_namespaces=(), extra=Extra.forbid)
if PYDANTIC_V2:
model_config = ConfigDict(protected_namespaces=(), extra=Extra.forbid) # type: ignore
else:
model_config = ConfigDict(extra=Extra.forbid)

dtype: CheckpointDataType
quant_scheme: Optional[QuantMode] = None
Expand Down
4 changes: 3 additions & 1 deletion friendli/schema/resource/v1/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CheckpointStatus,
CheckpointValidationStatus,
)
from friendli.utils.compat import PYDANTIC_V2


class V1Catalog(BaseModel):
Expand Down Expand Up @@ -77,7 +78,8 @@ class V1CheckpointOwnership(BaseModel):
class V1Checkpoint(BaseModel):
"""V1 checkpoint schema."""

model_config = ConfigDict(protected_namespaces=())
if PYDANTIC_V2:
model_config = ConfigDict(protected_namespaces=()) # type: ignore

id: UUID
user_id: UUID
Expand Down
9 changes: 5 additions & 4 deletions friendli/sdk/api/chat/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
GenerationStream,
ServingAPI,
)
from friendli.utils.compat import model_parse


class Completions(ServingAPI[Type[V1ChatCompletionsRequest]]):
Expand Down Expand Up @@ -135,7 +136,7 @@ def create(

if stream:
return ChatCompletionStream(response=response)
return ChatCompletion.model_validate(response.json())
return model_parse(ChatCompletion, response.json())


class AsyncCompletions(AsyncServingAPI[Type[V1ChatCompletionsRequest]]):
Expand Down Expand Up @@ -247,7 +248,7 @@ async def create(

if stream:
return AsyncChatCompletionStream(response=response)
return ChatCompletion.model_validate(response.json())
return model_parse(ChatCompletion, response.json())


class ChatCompletionStream(GenerationStream[ChatCompletionLine]):
Expand All @@ -264,7 +265,7 @@ def __next__(self) -> ChatCompletionLine: # noqa: D105
parsed = json.loads(data)

try:
return ChatCompletionLine.model_validate(parsed)
return model_parse(ChatCompletionLine, parsed)
except ValidationError as exc:
raise InvalidGenerationError(
f"Generation result has invalid schema: {str(exc)}"
Expand All @@ -285,7 +286,7 @@ async def __anext__(self) -> ChatCompletionLine: # noqa: D105
parsed = json.loads(data)

try:
return ChatCompletionLine.model_validate(parsed)
return model_parse(ChatCompletionLine, parsed)
except ValidationError as exc:
raise InvalidGenerationError(
f"Generation result has invalid schema: {str(exc)}"
Expand Down
21 changes: 11 additions & 10 deletions friendli/sdk/api/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
GenerationStream,
ServingAPI,
)
from friendli.utils.compat import model_parse


class Completions(ServingAPI[Type[V1CompletionsRequest]]):
Expand Down Expand Up @@ -323,7 +324,7 @@ def create(

if stream:
return CompletionStream(response=response)
return Completion.model_validate(response.json())
return model_parse(Completion, response.json())


class AsyncCompletions(AsyncServingAPI[Type[V1CompletionsRequest]]):
Expand Down Expand Up @@ -610,7 +611,7 @@ async def main() -> None:

if stream:
return AsyncCompletionStream(response=response)
return Completion.model_validate(response.json())
return model_parse(Completion, response.json())


class CompletionStream(GenerationStream[CompletionLine]):
Expand All @@ -623,11 +624,11 @@ def __next__(self) -> CompletionLine: # noqa: D105

parsed = json.loads(line.strip("data: "))
try:
return CompletionLine.model_validate(parsed)
return model_parse(CompletionLine, parsed)
except ValidationError as exc:
try:
# The last iteration of the stream returns a response with `V1Completion` schema.
Completion.model_validate(parsed)
model_parse(Completion, parsed)
raise StopIteration from exc
except ValidationError:
raise InvalidGenerationError(
Expand All @@ -649,11 +650,11 @@ def wait(self) -> Optional[Completion]:
parsed = json.loads(line.strip("data: "))
try:
# The last iteration of the stream returns a response with `V1Completion` schema.
return Completion.model_validate(parsed)
return model_parse(Completion, parsed)
except ValidationError as exc:
try:
# Skip the line response.
CompletionLine.model_validate(parsed)
model_parse(CompletionLine, parsed)
except ValidationError:
raise InvalidGenerationError(
f"Generation result has invalid schema: {str(exc)}"
Expand All @@ -671,11 +672,11 @@ async def __anext__(self) -> CompletionLine: # noqa: D105

parsed = json.loads(line.strip("data: "))
try:
return CompletionLine.model_validate(parsed)
return model_parse(CompletionLine, parsed)
except ValidationError as exc:
try:
# The last iteration of the stream returns a response with `V1Completion` schema.
Completion.model_validate(parsed)
model_parse(Completion, parsed)
raise StopAsyncIteration from exc
except ValidationError:
raise InvalidGenerationError(
Expand All @@ -697,11 +698,11 @@ async def wait(self) -> Optional[Completion]: # noqa: D105
parsed = json.loads(line.strip("data: "))
try:
# The last iteration of the stream returns a response with `V1Completion` schema.
return Completion.model_validate(parsed)
return model_parse(Completion, parsed)
except ValidationError as exc:
try:
# Skip the line response.
CompletionLine.model_validate(parsed)
model_parse(CompletionLine, parsed)
except ValidationError:
raise InvalidGenerationError(
f"Generation result has invalid schema: {str(exc)}"
Expand Down
5 changes: 3 additions & 2 deletions friendli/sdk/api/images/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from friendli.schema.api.v1.codegen.text_to_image_pb2 import V1TextToImageRequest
from friendli.schema.api.v1.images.image import Image, ImageResponseFormatParam
from friendli.sdk.api.base import AsyncServingAPI, ServingAPI
from friendli.utils.compat import model_parse


class TextToImage(ServingAPI[Type[V1TextToImageRequest]]):
Expand Down Expand Up @@ -73,7 +74,7 @@ def create(
}
response = self._request(data=request_dict, stream=False, model=model)

return Image.model_validate(response.json())
return model_parse(Image, response.json())


class AsyncTextToImage(AsyncServingAPI[Type[V1TextToImageRequest]]):
Expand Down Expand Up @@ -135,4 +136,4 @@ async def create(
}
response = await self._request(data=request_dict, stream=False, model=model)

return Image.model_validate(response.json())
return model_parse(Image, response.json())
13 changes: 6 additions & 7 deletions friendli/sdk/resource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from friendli.client.graphql.base import GqlClient
from friendli.context import get_current_project_id, get_current_team_id
from friendli.errors import AuthorizationError
from friendli.utils.compat import model_parse

_Resource = TypeVar("_Resource", bound=pydantic.BaseModel)
_ResourceId = TypeVar("_ResourceId")
Expand Down Expand Up @@ -47,21 +48,19 @@ def list(self, *args, **kwargs) -> List[_Resource]:
"""Lists reousrces."""

@overload
def _model_validate(self, data: Dict[str, Any]) -> _Resource:
def _model_parse(self, data: Dict[str, Any]) -> _Resource:
...

@overload
def _model_validate(self, data: List[Dict[str, Any]]) -> List[_Resource]:
def _model_parse(self, data: List[Dict[str, Any]]) -> List[_Resource]:
...

def _model_validate(
def _model_parse(
self, data: Union[Dict[str, Any], List[Dict[str, Any]]]
) -> Union[_Resource, List[_Resource]]:
if isinstance(data, list):
return [
self._resource_model.model_validate(entry["node"]) for entry in data
]
return self._resource_model.model_validate(data)
return [model_parse(self._resource_model, entry["node"]) for entry in data]
return model_parse(self._resource_model, data)

def _get_project_id(self) -> str:
project_id = get_current_project_id()
Expand Down
4 changes: 2 additions & 2 deletions friendli/sdk/resource/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def create(
adapter_eids=adapter_eids or [],
launch_config=launch_config or {},
)
deployment = self._model_validate(data)
deployment = self._model_parse(data)

return deployment

Expand All @@ -55,5 +55,5 @@ def get(self, eid: str, *args, **kwargs) -> Deployment:
def list(self) -> List[Deployment]:
"""List deployments."""
data = self.client.get_deployments(project_eid=self._get_project_id())
deployments = self._model_validate(data)
deployments = self._model_parse(data)
return deployments
Loading

0 comments on commit 88d61c8

Please sign in to comment.