diff --git a/friendli/cli/api/chat_completions.py b/friendli/cli/api/chat_completions.py index 072081d6..40337107 100644 --- a/friendli/cli/api/chat_completions.py +++ b/friendli/cli/api/chat_completions.py @@ -12,6 +12,7 @@ from friendli.schema.api.v1.chat.completions import MessageParam from friendli.sdk.client import Friendli +from friendli.utils.compat import model_dump from friendli.utils.decorator import check_api from friendli.utils.format import secho_error_and_exit @@ -136,7 +137,7 @@ def create( ) for chunk in stream: if n is not None and n > 1: - typer.echo(chunk.model_dump()) + typer.echo(model_dump(chunk)) else: typer.echo(chunk.choices[0].delta.content or "", nl=False) else: @@ -152,7 +153,7 @@ def create( temperature=temperature, top_p=top_p, ) - typer.echo(chat_completion.model_dump()) + typer.echo(model_dump(chat_completion)) def _prepare_messages(messages: List[str]) -> List[MessageParam]: diff --git a/friendli/cli/api/completions.py b/friendli/cli/api/completions.py index 2b6402f0..5bd86810 100644 --- a/friendli/cli/api/completions.py +++ b/friendli/cli/api/completions.py @@ -11,6 +11,7 @@ import typer from friendli.sdk.client import Friendli +from friendli.utils.compat import model_dump from friendli.utils.decorator import check_api app = typer.Typer( @@ -129,7 +130,7 @@ def create( ) for chunk in stream: if n is not None and n > 1: - typer.echo(chunk.model_dump()) + typer.echo(model_dump(chunk)) else: typer.echo(chunk.text, nl=False) else: @@ -145,4 +146,4 @@ def create( temperature=temperature, top_p=top_p, ) - typer.echo(completion.model_dump()) + typer.echo(model_dump(completion)) diff --git a/friendli/cli/api/text_to_image.py b/friendli/cli/api/text_to_image.py index 53c735e2..f89d779d 100644 --- a/friendli/cli/api/text_to_image.py +++ b/friendli/cli/api/text_to_image.py @@ -12,6 +12,7 @@ from friendli.enums import ResponseFormat from friendli.sdk.client import Friendli +from friendli.utils.compat import model_dump from friendli.utils.decorator import check_api app = typer.Typer( @@ -106,4 +107,4 @@ def create( seed=seed, response_format=response_format, ) - typer.echo(image.model_dump()) + typer.echo(model_dump(image)) diff --git a/friendli/cli/checkpoint.py b/friendli/cli/checkpoint.py index 1c0e20fb..eb607e80 100644 --- a/friendli/cli/checkpoint.py +++ b/friendli/cli/checkpoint.py @@ -20,6 +20,7 @@ TableFormatter, TreeFormatter, ) +from friendli.utils.compat import model_parse from friendli.utils.format import secho_error_and_exit app = typer.Typer( @@ -236,8 +237,8 @@ def convert( quant_config_dict = cast(dict, yaml.safe_load(quant_config_file.read())) except yaml.YAMLError as err: secho_error_and_exit(f"Failed to load the quant config file: {err}") - quant_config = QuantConfig.model_validate( - {"config": quant_config_dict} + quant_config = model_parse( + QuantConfig, {"config": quant_config_dict} ).config else: quant_config = AWQConfig() diff --git a/friendli/cli/deployment.py b/friendli/cli/deployment.py index 6ec6415f..03616223 100644 --- a/friendli/cli/deployment.py +++ b/friendli/cli/deployment.py @@ -12,6 +12,7 @@ from friendli.formatter import PanelFormatter, TableFormatter from friendli.schema.config.deployment import DeploymentConfig from friendli.sdk.client import Friendli +from friendli.utils.compat import model_dump, model_parse from friendli.utils.decorator import check_api app = typer.Typer( @@ -105,7 +106,7 @@ def create( """Creates a deployment object by using model checkpoint.""" client = Friendli() config_dict = yaml.safe_load(config_file) - config = DeploymentConfig.model_validate(config_dict) + config = model_parse(DeploymentConfig, config_dict) deployment = client.deployment.create( name=config.name, @@ -115,7 +116,7 @@ def create( adapter_eids=config.adapters, launch_config=config.launch_config, ) - deployment_panel.render(deployment.model_dump()) + deployment_panel.render(model_dump(deployment)) @app.command("list") @@ -124,5 +125,5 @@ def list_deployments(): """List deployments.""" client = Friendli() deployments = client.deployment.list() - deployments_ = [deployment.model_dump() for deployment in iter(deployments)] + deployments_ = [model_dump(deployment) for deployment in iter(deployments)] deployment_table.render(deployments_) diff --git a/friendli/schema/api/v1/images/image.py b/friendli/schema/api/v1/images/image.py index 17ffda9e..c57e9732 100644 --- a/friendli/schema/api/v1/images/image.py +++ b/friendli/schema/api/v1/images/image.py @@ -6,8 +6,8 @@ from typing import List, Literal, Union -from pydantic import AnyHttpUrl, Base64Bytes, BaseModel, Field -from typing_extensions import TypeAlias +from pydantic import AnyHttpUrl, BaseModel, Field +from typing_extensions import Annotated, TypeAlias ImageResponseFormatParam: TypeAlias = Union[str, Literal["url", "png", "jpeg", "raw"]] @@ -25,10 +25,15 @@ class ImageDataB64(BaseModel): format: Literal["png", "jpeg", "raw"] seed: int - b64_json: Base64Bytes + b64_json: str + + +_ImageData = Annotated[ + Union[ImageDataUrl, ImageDataB64], Field(..., discriminator="format") +] class Image(BaseModel): """Image data.""" - data: List[Union[ImageDataUrl, ImageDataB64]] = Field(..., discriminator="format") + data: List[_ImageData] diff --git a/friendli/schema/resource/v1/attributes.py b/friendli/schema/resource/v1/attributes.py index 3183003a..59cb0063 100644 --- a/friendli/schema/resource/v1/attributes.py +++ b/friendli/schema/resource/v1/attributes.py @@ -10,12 +10,16 @@ from typing_extensions import Annotated from friendli.enums import CheckpointDataType, QuantMode +from friendli.utils.compat import PYDANTIC_V2 class V1CommonAttributes(BaseModel): """V1 checkpoint attributes schema.""" - model_config = ConfigDict(protected_namespaces=(), extra=Extra.forbid) + if PYDANTIC_V2: + model_config = ConfigDict(protected_namespaces=(), extra=Extra.forbid) # type: ignore + else: + model_config = ConfigDict(extra=Extra.forbid) dtype: CheckpointDataType quant_scheme: Optional[QuantMode] = None diff --git a/friendli/schema/resource/v1/checkpoint.py b/friendli/schema/resource/v1/checkpoint.py index a7f331a9..e718841d 100644 --- a/friendli/schema/resource/v1/checkpoint.py +++ b/friendli/schema/resource/v1/checkpoint.py @@ -16,6 +16,7 @@ CheckpointStatus, CheckpointValidationStatus, ) +from friendli.utils.compat import PYDANTIC_V2 class V1Catalog(BaseModel): @@ -77,7 +78,8 @@ class V1CheckpointOwnership(BaseModel): class V1Checkpoint(BaseModel): """V1 checkpoint schema.""" - model_config = ConfigDict(protected_namespaces=()) + if PYDANTIC_V2: + model_config = ConfigDict(protected_namespaces=()) # type: ignore id: UUID user_id: UUID diff --git a/friendli/sdk/api/chat/completions.py b/friendli/sdk/api/chat/completions.py index 101d2953..2b02575d 100644 --- a/friendli/sdk/api/chat/completions.py +++ b/friendli/sdk/api/chat/completions.py @@ -24,6 +24,7 @@ GenerationStream, ServingAPI, ) +from friendli.utils.compat import model_parse class Completions(ServingAPI[Type[V1ChatCompletionsRequest]]): @@ -135,7 +136,7 @@ def create( if stream: return ChatCompletionStream(response=response) - return ChatCompletion.model_validate(response.json()) + return model_parse(ChatCompletion, response.json()) class AsyncCompletions(AsyncServingAPI[Type[V1ChatCompletionsRequest]]): @@ -247,7 +248,7 @@ async def create( if stream: return AsyncChatCompletionStream(response=response) - return ChatCompletion.model_validate(response.json()) + return model_parse(ChatCompletion, response.json()) class ChatCompletionStream(GenerationStream[ChatCompletionLine]): @@ -264,7 +265,7 @@ def __next__(self) -> ChatCompletionLine: # noqa: D105 parsed = json.loads(data) try: - return ChatCompletionLine.model_validate(parsed) + return model_parse(ChatCompletionLine, parsed) except ValidationError as exc: raise InvalidGenerationError( f"Generation result has invalid schema: {str(exc)}" @@ -285,7 +286,7 @@ async def __anext__(self) -> ChatCompletionLine: # noqa: D105 parsed = json.loads(data) try: - return ChatCompletionLine.model_validate(parsed) + return model_parse(ChatCompletionLine, parsed) except ValidationError as exc: raise InvalidGenerationError( f"Generation result has invalid schema: {str(exc)}" diff --git a/friendli/sdk/api/completions.py b/friendli/sdk/api/completions.py index cba7deef..60908edb 100644 --- a/friendli/sdk/api/completions.py +++ b/friendli/sdk/api/completions.py @@ -25,6 +25,7 @@ GenerationStream, ServingAPI, ) +from friendli.utils.compat import model_parse class Completions(ServingAPI[Type[V1CompletionsRequest]]): @@ -323,7 +324,7 @@ def create( if stream: return CompletionStream(response=response) - return Completion.model_validate(response.json()) + return model_parse(Completion, response.json()) class AsyncCompletions(AsyncServingAPI[Type[V1CompletionsRequest]]): @@ -610,7 +611,7 @@ async def main() -> None: if stream: return AsyncCompletionStream(response=response) - return Completion.model_validate(response.json()) + return model_parse(Completion, response.json()) class CompletionStream(GenerationStream[CompletionLine]): @@ -623,11 +624,11 @@ def __next__(self) -> CompletionLine: # noqa: D105 parsed = json.loads(line.strip("data: ")) try: - return CompletionLine.model_validate(parsed) + return model_parse(CompletionLine, parsed) except ValidationError as exc: try: # The last iteration of the stream returns a response with `V1Completion` schema. - Completion.model_validate(parsed) + model_parse(Completion, parsed) raise StopIteration from exc except ValidationError: raise InvalidGenerationError( @@ -649,11 +650,11 @@ def wait(self) -> Optional[Completion]: parsed = json.loads(line.strip("data: ")) try: # The last iteration of the stream returns a response with `V1Completion` schema. - return Completion.model_validate(parsed) + return model_parse(Completion, parsed) except ValidationError as exc: try: # Skip the line response. - CompletionLine.model_validate(parsed) + model_parse(CompletionLine, parsed) except ValidationError: raise InvalidGenerationError( f"Generation result has invalid schema: {str(exc)}" @@ -671,11 +672,11 @@ async def __anext__(self) -> CompletionLine: # noqa: D105 parsed = json.loads(line.strip("data: ")) try: - return CompletionLine.model_validate(parsed) + return model_parse(CompletionLine, parsed) except ValidationError as exc: try: # The last iteration of the stream returns a response with `V1Completion` schema. - Completion.model_validate(parsed) + model_parse(Completion, parsed) raise StopAsyncIteration from exc except ValidationError: raise InvalidGenerationError( @@ -697,11 +698,11 @@ async def wait(self) -> Optional[Completion]: # noqa: D105 parsed = json.loads(line.strip("data: ")) try: # The last iteration of the stream returns a response with `V1Completion` schema. - return Completion.model_validate(parsed) + return model_parse(Completion, parsed) except ValidationError as exc: try: # Skip the line response. - CompletionLine.model_validate(parsed) + model_parse(CompletionLine, parsed) except ValidationError: raise InvalidGenerationError( f"Generation result has invalid schema: {str(exc)}" diff --git a/friendli/sdk/api/images/text_to_image.py b/friendli/sdk/api/images/text_to_image.py index c4d4368f..16913be4 100644 --- a/friendli/sdk/api/images/text_to_image.py +++ b/friendli/sdk/api/images/text_to_image.py @@ -12,6 +12,7 @@ from friendli.schema.api.v1.codegen.text_to_image_pb2 import V1TextToImageRequest from friendli.schema.api.v1.images.image import Image, ImageResponseFormatParam from friendli.sdk.api.base import AsyncServingAPI, ServingAPI +from friendli.utils.compat import model_parse class TextToImage(ServingAPI[Type[V1TextToImageRequest]]): @@ -73,7 +74,7 @@ def create( } response = self._request(data=request_dict, stream=False, model=model) - return Image.model_validate(response.json()) + return model_parse(Image, response.json()) class AsyncTextToImage(AsyncServingAPI[Type[V1TextToImageRequest]]): @@ -135,4 +136,4 @@ async def create( } response = await self._request(data=request_dict, stream=False, model=model) - return Image.model_validate(response.json()) + return model_parse(Image, response.json()) diff --git a/friendli/sdk/resource/base.py b/friendli/sdk/resource/base.py index 30ad5bd3..de2127be 100644 --- a/friendli/sdk/resource/base.py +++ b/friendli/sdk/resource/base.py @@ -16,6 +16,7 @@ from friendli.client.graphql.base import GqlClient from friendli.context import get_current_project_id, get_current_team_id from friendli.errors import AuthorizationError +from friendli.utils.compat import model_parse _Resource = TypeVar("_Resource", bound=pydantic.BaseModel) _ResourceId = TypeVar("_ResourceId") @@ -47,21 +48,19 @@ def list(self, *args, **kwargs) -> List[_Resource]: """Lists reousrces.""" @overload - def _model_validate(self, data: Dict[str, Any]) -> _Resource: + def _model_parse(self, data: Dict[str, Any]) -> _Resource: ... @overload - def _model_validate(self, data: List[Dict[str, Any]]) -> List[_Resource]: + def _model_parse(self, data: List[Dict[str, Any]]) -> List[_Resource]: ... - def _model_validate( + def _model_parse( self, data: Union[Dict[str, Any], List[Dict[str, Any]]] ) -> Union[_Resource, List[_Resource]]: if isinstance(data, list): - return [ - self._resource_model.model_validate(entry["node"]) for entry in data - ] - return self._resource_model.model_validate(data) + return [model_parse(self._resource_model, entry["node"]) for entry in data] + return model_parse(self._resource_model, data) def _get_project_id(self) -> str: project_id = get_current_project_id() diff --git a/friendli/sdk/resource/deployment.py b/friendli/sdk/resource/deployment.py index c93764a0..78271646 100644 --- a/friendli/sdk/resource/deployment.py +++ b/friendli/sdk/resource/deployment.py @@ -44,7 +44,7 @@ def create( adapter_eids=adapter_eids or [], launch_config=launch_config or {}, ) - deployment = self._model_validate(data) + deployment = self._model_parse(data) return deployment @@ -55,5 +55,5 @@ def get(self, eid: str, *args, **kwargs) -> Deployment: def list(self) -> List[Deployment]: """List deployments.""" data = self.client.get_deployments(project_eid=self._get_project_id()) - deployments = self._model_validate(data) + deployments = self._model_parse(data) return deployments diff --git a/friendli/utils/compat.py b/friendli/utils/compat.py new file mode 100644 index 00000000..1f86eb87 --- /dev/null +++ b/friendli/utils/compat.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024-present, FriendliAI Inc. All rights reserved. + +"""Compatibility Utils.""" + +from __future__ import annotations + +from typing import Any, Dict, Type, TypeVar, cast + +import pydantic + +_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) + +PYDANTIC_V2 = pydantic.VERSION.startswith("2.") + + +def model_parse(model: Type[_ModelT], data: Any) -> _ModelT: + """Parse a pydantic model from data.""" + if PYDANTIC_V2: + return model.model_validate(data) # type: ignore + return model.parse_obj(data) # type: ignore + + +def model_dump( + model: pydantic.BaseModel, + *, + exclude_unset: bool = False, + exclude_defaults: bool = False, +) -> dict[str, Any]: + """Dump data from a pydantic model.""" + if PYDANTIC_V2: + return model.model_dump( # type: ignore + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + ) + return cast( + Dict[str, Any], + model.dict( # type: ignore + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + ), + ) diff --git a/friendli/utils/transfer.py b/friendli/utils/transfer.py index 3f8e9b7d..f675070d 100644 --- a/friendli/utils/transfer.py +++ b/friendli/utils/transfer.py @@ -36,6 +36,7 @@ UploadedPartETag, UploadTask, ) +from friendli.utils.compat import model_dump from friendli.utils.fs import get_file_size, storage_path_to_local_path from friendli.utils.request import DEFAULT_REQ_TIMEOUT @@ -340,7 +341,7 @@ def multipart_upload_file( ) for fut in done: part_etag = fut.result() - uploaded_part_etags.append(part_etag.model_dump()) + uploaded_part_etags.append(model_dump(part_etag)) complete_callback( upload_task.path, upload_task.upload_id, uploaded_part_etags ) diff --git a/friendli/utils/validate.py b/friendli/utils/validate.py index 9d9b3e34..05c36b54 100644 --- a/friendli/utils/validate.py +++ b/friendli/utils/validate.py @@ -20,6 +20,7 @@ NotSupportedError, ) from friendli.schema.resource.v1.attributes import V1AttributesValidationModel +from friendli.utils.compat import model_parse from friendli.utils.format import secho_error_and_exit from friendli.utils.version import ( FRIENDLI_PACKAGE_NAME, @@ -87,7 +88,7 @@ def check_package_version() -> None: def validate_checkpoint_attributes(attr: Dict[str, Any]) -> None: """Validate checkpoint attributes schema.""" try: - V1AttributesValidationModel.model_validate({"attr": attr}) + model_parse(V1AttributesValidationModel, {"attr": attr}) except ValidationError as exc: msgs = [] for error in exc.errors(): diff --git a/poetry.lock b/poetry.lock index 7fcb1995..401ffd6a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2874,6 +2874,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4297,5 +4298,5 @@ mllib = ["accelerate", "datasets", "einops", "h5py", "peft", "transformers"] [metadata] lock-version = "2.0" -python-versions = "^3.8" -content-hash = "5f7baaef7cb712cefece6e740a698dac3964dec638d2ad4a3e6f38c6707b0ccf" +python-versions = "^3.8.1" +content-hash = "b0989b17724d419c7880686b4ba0c938dd2f1045c1d2e40055d1eb805267dcb6" diff --git a/pyproject.toml b/pyproject.toml index d696ea96..32db3ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "friendli-client" -version = "1.2.2" +version = "1.2.3" description = "Client of Friendli Suite." license = "Apache-2.0" authors = ["FriendliAI teams "] @@ -25,21 +25,21 @@ name = "PyPI" priority = "primary" [tool.poetry.dependencies] -python = "^3.8" -requests = "^2.31.0" +python = "^3.8.1" +requests = "^2" PyYaml = "^6.0.1" typer = "^0.9.0" rich = "^12.2.0" jsonschema = "^4.17.3" boto3 = "^1.22.8" botocore = "^1.25.8" -tqdm = "^4.64.0" +tqdm = "^4.48.0" azure-mgmt-storage = "^20.1.0" azure-storage-blob = "^12.12.0" pathspec = "^0.9.0" boto3-stubs = "^1.26.90" mypy-boto3-s3 = "^1.26.163" -pydantic = {extras = ["email"], version = "^2.0.2"} +pydantic = {extras = ["email"], version = ">=1.9.0, <3"} transformers = { version = "4.36.2", optional = true } h5py = { version = "^3.9.0", optional = true } einops = { version = "^0.6.1", optional = true } @@ -52,7 +52,7 @@ peft = { version = "0.6.0", optional = true } httpx = "^0.24.1" fastapi = "^0.104.0" uvicorn = "^0.23.2" -gql = "^3.5.0" +gql = "^3.4.1" [tool.poetry.group.dev] optional = true @@ -121,6 +121,7 @@ disable = [ "consider-using-set-comprehension", "redefined-outer-name" ] +extension-pkg-whitelist = "pydantic" [tool.pylint.TYPECHECK] generated-members = [ diff --git a/tests/unit_tests/modules/helpers/spec.py b/tests/unit_tests/modules/helpers/spec.py index 1fe5e7ae..127d8d59 100644 --- a/tests/unit_tests/modules/helpers/spec.py +++ b/tests/unit_tests/modules/helpers/spec.py @@ -13,6 +13,8 @@ from jinja2.environment import Template as JinjaTemplate from pydantic import BaseModel +from friendli.utils.compat import model_parse + class InvalidSpecFormatError(Exception): """Invalid model spec format that can be handled by users.""" @@ -149,7 +151,7 @@ def _get_param_info( return res if node_type == SpecNodeType.REPEAT_GROUP: try: - repeat_range = RepeatRange.model_validate(spec["range"]) # type: ignore + repeat_range = model_parse(RepeatRange, spec["range"]) # type: ignore except KeyError as exc: raise InvalidSpecFormatError from exc res = {} diff --git a/tests/unit_tests/modules/helpers/utils.py b/tests/unit_tests/modules/helpers/utils.py index 04b1b19a..be6b73b9 100644 --- a/tests/unit_tests/modules/helpers/utils.py +++ b/tests/unit_tests/modules/helpers/utils.py @@ -28,6 +28,7 @@ from friendli.modules.quantizer.schema.config import AWQConfig from friendli.modules.quantizer.schema.data import QuantInput from friendli.modules.quantizer.smoothquant.base import SmoothQuantQuantizer +from friendli.utils.compat import model_dump from tests.unit_tests.modules.helpers.spec import ModelSpecParser, ParamInfo, Template @@ -89,7 +90,7 @@ def get_param_specs( ) -> Dict[str, ParamInfo]: file_path = f"{SPEC_PATH_PREFIX}{spec_folder}/{model_name}.yaml" template = Template.from_file(file_path) - render_config = model_config.model_dump() + render_config = model_dump(model_config) rendered = template.render(**render_config) assert isinstance(rendered, dict) parser = ModelSpecParser(model_spec=rendered)