From 8cfa0c75ee84d18900d5a3f866a0155a40721ea8 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets <6948036+aabmets@users.noreply.github.com> Date: Mon, 18 Mar 2024 01:57:44 +0200 Subject: [PATCH 01/32] Added column tz autoconversion to Table __init__ method --- piccolo/table.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/piccolo/table.py b/piccolo/table.py index d6735ac3..67d4eff2 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -5,6 +5,8 @@ import types import typing as t import warnings +from datetime import datetime +from zoneinfo import ZoneInfo from dataclasses import dataclass, field from piccolo.columns import Column @@ -17,6 +19,7 @@ ReferencedTable, Secret, Serial, + Timestamptz, ) from piccolo.columns.defaults.base import Default from piccolo.columns.indexes import IndexMethod @@ -436,6 +439,9 @@ def __init__( ): raise ValueError(f"{column._meta.name} wasn't provided") + if isinstance(column, Timestamptz) and isinstance(value, datetime): + value = value.astimezone(column.tz) + self[column._meta.name] = value unrecognized = kwargs.keys() From 95d8e4dd2e6b3de13236a61b3b5fee378260b965 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets <6948036+aabmets@users.noreply.github.com> Date: Mon, 18 Mar 2024 02:04:55 +0200 Subject: [PATCH 02/32] Added tz ability to Timestamptz --- piccolo/columns/column_types.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 7012355d..706db2cb 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -33,6 +33,7 @@ class Band(Table): import uuid from dataclasses import dataclass from datetime import date, datetime, time, timedelta +from zoneinfo import ZoneInfo from enum import Enum from piccolo.columns.base import ( @@ -1007,6 +1008,7 @@ class Concert(Table): """ value_type = datetime + tz_type = ZoneInfo # Currently just used by ModelBuilder, to know that we want a timezone # aware datetime. @@ -1015,7 +1017,10 @@ class Concert(Table): timedelta_delegate = TimedeltaDelegate() def __init__( - self, default: TimestamptzArg = TimestamptzNow(), **kwargs + self, + tz: ZoneInfo = ZoneInfo('UTC'), + default: TimestamptzArg = TimestamptzNow(), + **kwargs ) -> None: self._validate_default( default, TimestamptzArg.__args__ # type: ignore @@ -1025,10 +1030,11 @@ def __init__( default = TimestamptzCustom.from_datetime(default) if default == datetime.now: - default = TimestamptzNow() + default = TimestamptzNow(tz) + self.tz = tz self.default = default - kwargs.update({"default": default}) + kwargs.update({"tz": tz, "default": default}) super().__init__(**kwargs) ########################################################################### From 7dd76d8c959adeb31c718e6e85a32fef2a560925 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets <6948036+aabmets@users.noreply.github.com> Date: Mon, 18 Mar 2024 02:10:01 +0200 Subject: [PATCH 03/32] Added tz ability to TimestamptzNow --- piccolo/columns/defaults/timestamptz.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index 5db6ebd5..1c40b9c8 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -3,6 +3,7 @@ import datetime import typing as t from enum import Enum +from zoneinfo import ZoneInfo from .timestamp import TimestampCustom, TimestampNow, TimestampOffset @@ -27,12 +28,15 @@ def python(self): class TimestamptzNow(TimestampNow): + def __init__(self, tz: ZoneInfo = ZoneInfo('UTC')): + self._tz = tz + @property def cockroach(self): return "current_timestamp" def python(self): - return datetime.datetime.now(tz=datetime.timezone.utc) + return datetime.datetime.now(tz=self._tz) class TimestamptzCustom(TimestampCustom): From 06e062555e2fe36674e4d3b02a79bed8e7872f4d Mon Sep 17 00:00:00 2001 From: Mattias Aabmets <6948036+aabmets@users.noreply.github.com> Date: Wed, 20 Mar 2024 01:27:16 +0200 Subject: [PATCH 04/32] Added timezone awareness to TimestamptzOffset and TimestamptzCustom --- piccolo/columns/defaults/timestamptz.py | 36 ++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index 1c40b9c8..277f9aad 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -9,6 +9,20 @@ class TimestamptzOffset(TimestampOffset): + def __init__( + self, + days: int = 0, + hours: int = 0, + minutes: int = 0, + seconds: int = 0, + tz: ZoneInfo = ZoneInfo('UTC') + ): + self._tz = tz + super().__init__(**{ + k: v for k, v in locals().items() + if k not in ['self', 'tz'] + }) + @property def cockroach(self): interval_string = self.get_postgres_interval_string( @@ -18,7 +32,7 @@ def cockroach(self): def python(self): return datetime.datetime.now( - tz=datetime.timezone.utc + tz=self._tz ) + datetime.timedelta( days=self.days, hours=self.hours, @@ -40,6 +54,22 @@ def python(self): class TimestamptzCustom(TimestampCustom): + def __init__( + self, + year: int = 2000, + month: int = 1, + day: int = 1, + hour: int = 0, + second: int = 0, + microsecond: int = 0, + tz: ZoneInfo = ZoneInfo('UTC') + ): + self._tz = tz + super().__init__(**{ + k: v for k, v in locals().items() + if k not in ['self', 'tz'] + }) + @property def cockroach(self): return "'{}'".format(self.datetime.isoformat().replace("T", " ")) @@ -53,13 +83,13 @@ def datetime(self): hour=self.hour, second=self.second, microsecond=self.microsecond, - tzinfo=datetime.timezone.utc, + tzinfo=self._tz, ) @classmethod def from_datetime(cls, instance: datetime.datetime): # type: ignore if instance.tzinfo is not None: - instance = instance.astimezone(datetime.timezone.utc) + instance = instance.astimezone(self._tz) return cls( year=instance.year, month=instance.month, From c589d44e5842627e30c32e260280b4d49954e431 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets <6948036+aabmets@users.noreply.github.com> Date: Wed, 20 Mar 2024 01:31:00 +0200 Subject: [PATCH 05/32] Added zoneinfo fallback dependency for missing zone data --- requirements/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 0a5ee624..e09e7669 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -5,3 +5,4 @@ targ>=0.3.7 inflection>=0.5.1 typing-extensions>=4.3.0 pydantic[email]==2.* +tzdata>=2024.1 From fc708fcfc96b1458fedcbc3b338e85333614d7ea Mon Sep 17 00:00:00 2001 From: Mattias Aabmets <6948036+aabmets@users.noreply.github.com> Date: Wed, 20 Mar 2024 01:35:28 +0200 Subject: [PATCH 06/32] Added backports.zoneinfo --- requirements/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index e09e7669..73c8f4e6 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -6,3 +6,4 @@ inflection>=0.5.1 typing-extensions>=4.3.0 pydantic[email]==2.* tzdata>=2024.1 +backports.zoneinfo>=0.2.1 From 705629a4a2f5958b6f8f06a14c6e702aa31ec423 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets <6948036+aabmets@users.noreply.github.com> Date: Wed, 20 Mar 2024 01:39:05 +0200 Subject: [PATCH 07/32] Un-privatized tz attribute from default timestamptz classes --- piccolo/columns/defaults/timestamptz.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index 277f9aad..e52d4445 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -17,7 +17,7 @@ def __init__( seconds: int = 0, tz: ZoneInfo = ZoneInfo('UTC') ): - self._tz = tz + self.tz = tz super().__init__(**{ k: v for k, v in locals().items() if k not in ['self', 'tz'] @@ -32,7 +32,7 @@ def cockroach(self): def python(self): return datetime.datetime.now( - tz=self._tz + tz=self.tz ) + datetime.timedelta( days=self.days, hours=self.hours, @@ -43,14 +43,14 @@ def python(self): class TimestamptzNow(TimestampNow): def __init__(self, tz: ZoneInfo = ZoneInfo('UTC')): - self._tz = tz + self.tz = tz @property def cockroach(self): return "current_timestamp" def python(self): - return datetime.datetime.now(tz=self._tz) + return datetime.datetime.now(tz=self.tz) class TimestamptzCustom(TimestampCustom): @@ -64,7 +64,7 @@ def __init__( microsecond: int = 0, tz: ZoneInfo = ZoneInfo('UTC') ): - self._tz = tz + self.tz = tz super().__init__(**{ k: v for k, v in locals().items() if k not in ['self', 'tz'] @@ -83,13 +83,13 @@ def datetime(self): hour=self.hour, second=self.second, microsecond=self.microsecond, - tzinfo=self._tz, + tzinfo=self.tz, ) @classmethod def from_datetime(cls, instance: datetime.datetime): # type: ignore if instance.tzinfo is not None: - instance = instance.astimezone(self._tz) + instance = instance.astimezone(self.tz) return cls( year=instance.year, month=instance.month, From a25e814c8338dc650eb5cc8cefce35ea4908c76f Mon Sep 17 00:00:00 2001 From: Mattias Aabmets <6948036+aabmets@users.noreply.github.com> Date: Wed, 20 Mar 2024 01:52:53 +0200 Subject: [PATCH 08/32] Added python version constraint to backports.zoneinfo --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 73c8f4e6..9199ab8e 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -6,4 +6,4 @@ inflection>=0.5.1 typing-extensions>=4.3.0 pydantic[email]==2.* tzdata>=2024.1 -backports.zoneinfo>=0.2.1 +backports.zoneinfo>=0.2.1; python_version <= '3.8' From 6611a975f55fd0676c13776bd3d5b22f0e41d909 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets <6948036+aabmets@users.noreply.github.com> Date: Wed, 20 Mar 2024 08:38:37 +0200 Subject: [PATCH 09/32] Updated Timestamptz docstring and fixed a bug in TimestamptzCustom.from_datetime --- piccolo/columns/column_types.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 6292988f..ead6237f 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -978,30 +978,33 @@ def __set__(self, obj, value: t.Union[datetime, None]): class Timestamptz(Column): """ Used for storing timezone aware datetimes. Uses the ``datetime`` type for - values. The values are converted to UTC in the database, and are also - returned as UTC. + values. The values are converted to UTC when saved into the database and + are converted back into the timezone of the column on select queries. **Example** .. code-block:: python import datetime + from zoneinfo import ZoneInfo - class Concert(Table): - starts = Timestamptz() + class TallinnConcerts(Table): + event_start = Timestamptz(tz=ZoneInfo("Europe/Tallinn")) # Create - >>> await Concert( - ... starts=datetime.datetime( - ... year=2050, month=1, day=1, tzinfo=datetime.timezone.tz + >>> await TallinnConcerts( + ... event_start=datetime.datetime( + ... year=2050, month=1, day=1, hour=20 ... ) ... ).save() # Query - >>> await Concert.select(Concert.starts) + >>> await TallinnConcerts.select(TallinnConcerts.event_start) { - 'starts': datetime.datetime( - 2050, 1, 1, 0, 0, tzinfo=datetime.timezone.utc + 'event_start': datetime.datetime( + 2050, 1, 1, 20, 0, tzinfo=zoneinfo.ZoneInfo( + key='Europe/Tallinn' + ) ) } @@ -1027,7 +1030,7 @@ def __init__( ) if isinstance(default, datetime): - default = TimestamptzCustom.from_datetime(default) + default = TimestamptzCustom.from_datetime(default, tz) if default == datetime.now: default = TimestamptzNow(tz) From aa9f077b31138fb09b3e86ade6070f7dabc2682d Mon Sep 17 00:00:00 2001 From: Mattias Aabmets <6948036+aabmets@users.noreply.github.com> Date: Wed, 20 Mar 2024 08:41:46 +0200 Subject: [PATCH 10/32] Fixed bug in TimestamptzCustom.from_datetime --- piccolo/columns/defaults/timestamptz.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index e52d4445..0edbf3e8 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -87,9 +87,9 @@ def datetime(self): ) @classmethod - def from_datetime(cls, instance: datetime.datetime): # type: ignore + def from_datetime(cls, instance: datetime.datetime, tz: ZoneInfo = ZoneInfo('UTC')): # type: ignore if instance.tzinfo is not None: - instance = instance.astimezone(self.tz) + instance = instance.astimezone(tz) return cls( year=instance.year, month=instance.month, From 6fcb93fbfe171942bec5c0d26ebee8de461a2f6f Mon Sep 17 00:00:00 2001 From: Mattias Aabmets Date: Tue, 26 Mar 2024 19:28:34 +0200 Subject: [PATCH 11/32] Fixed linter and test issues (hopefully) --- piccolo/columns/column_types.py | 2 +- piccolo/columns/defaults/timestamptz.py | 22 ++++++++++++++-------- piccolo/table.py | 2 +- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index ead6237f..ee74c110 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -33,7 +33,6 @@ class Band(Table): import uuid from dataclasses import dataclass from datetime import date, datetime, time, timedelta -from zoneinfo import ZoneInfo from enum import Enum from piccolo.columns.base import ( @@ -64,6 +63,7 @@ class Band(Table): from piccolo.querystring import QueryString, Unquoted from piccolo.utils.encoding import dump_json from piccolo.utils.warnings import colored_warning +from zoneinfo import ZoneInfo if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns.base import ColumnMeta diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index 0edbf3e8..0dd29d33 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -18,10 +18,12 @@ def __init__( tz: ZoneInfo = ZoneInfo('UTC') ): self.tz = tz - super().__init__(**{ - k: v for k, v in locals().items() - if k not in ['self', 'tz'] - }) + super().__init__( + days=days, + hours=hours, + minutes=minutes, + seconds=seconds + ) @property def cockroach(self): @@ -65,10 +67,14 @@ def __init__( tz: ZoneInfo = ZoneInfo('UTC') ): self.tz = tz - super().__init__(**{ - k: v for k, v in locals().items() - if k not in ['self', 'tz'] - }) + super().__init__( + year=year, + month=month, + day=day, + hour=hour, + second=second, + microsecond=microsecond + ) @property def cockroach(self): diff --git a/piccolo/table.py b/piccolo/table.py index 67d4eff2..5b66b934 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -6,7 +6,6 @@ import typing as t import warnings from datetime import datetime -from zoneinfo import ZoneInfo from dataclasses import dataclass, field from piccolo.columns import Column @@ -57,6 +56,7 @@ from piccolo.utils.sql_values import convert_to_sql_value from piccolo.utils.sync import run_sync from piccolo.utils.warnings import colored_warning +from zoneinfo import ZoneInfo if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns import Selectable From 59a531e22349ba9742dad2df3a2fd1276e57a696 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets Date: Thu, 28 Mar 2024 12:26:49 +0200 Subject: [PATCH 12/32] Added backport.zoneinfo import with try-except clause, fixed imports ordering with isort --- piccolo/columns/column_types.py | 6 +++++- piccolo/columns/defaults/timestamptz.py | 6 +++++- piccolo/table.py | 8 ++++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index ee74c110..05353eab 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -63,7 +63,11 @@ class Band(Table): from piccolo.querystring import QueryString, Unquoted from piccolo.utils.encoding import dump_json from piccolo.utils.warnings import colored_warning -from zoneinfo import ZoneInfo + +try: + from zoneinfo import ZoneInfo +except ImportError: + from backports.zoneinfo import ZoneInfo if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns.base import ColumnMeta diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index 0dd29d33..c498594d 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -3,7 +3,11 @@ import datetime import typing as t from enum import Enum -from zoneinfo import ZoneInfo + +try: + from zoneinfo import ZoneInfo +except ImportError: + from backports.zoneinfo import ZoneInfo from .timestamp import TimestampCustom, TimestampNow, TimestampOffset diff --git a/piccolo/table.py b/piccolo/table.py index 5b66b934..ec76ee24 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -5,8 +5,8 @@ import types import typing as t import warnings -from datetime import datetime from dataclasses import dataclass, field +from datetime import datetime from piccolo.columns import Column from piccolo.columns.column_types import ( @@ -56,7 +56,11 @@ from piccolo.utils.sql_values import convert_to_sql_value from piccolo.utils.sync import run_sync from piccolo.utils.warnings import colored_warning -from zoneinfo import ZoneInfo + +try: + from zoneinfo import ZoneInfo +except ImportError: + from backports.zoneinfo import ZoneInfo if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns import Selectable From 7a3c58456b0f450c9e9467c38b508f15ea118958 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets Date: Thu, 28 Mar 2024 14:42:42 +0200 Subject: [PATCH 13/32] Fixed linting errors across codebase and fixed timestamptz test errors, migration tests are not fixed --- .../apps/migrations/auto/migration_manager.py | 18 +- piccolo/apps/migrations/auto/serialisation.py | 6 +- piccolo/apps/migrations/commands/backwards.py | 6 +- piccolo/apps/migrations/commands/base.py | 6 +- piccolo/apps/migrations/commands/forwards.py | 6 +- piccolo/apps/playground/commands/run.py | 1 + piccolo/apps/user/tables.py | 1 + piccolo/columns/base.py | 8 +- piccolo/columns/column_types.py | 186 ++++++------------ piccolo/columns/defaults/timestamptz.py | 45 ++--- piccolo/columns/reference.py | 1 + piccolo/conf/apps.py | 3 +- piccolo/query/methods/objects.py | 16 +- piccolo/query/methods/select.py | 12 +- piccolo/query/mixins.py | 6 +- piccolo/query/proxy.py | 5 +- piccolo/table.py | 16 +- tests/columns/test_reference.py | 1 + tests/columns/test_timestamptz.py | 106 +++++----- tests/conf/example.py | 1 + 20 files changed, 204 insertions(+), 246 deletions(-) diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index 772ec3ed..fca36e8e 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -737,9 +737,9 @@ async def _run_rename_columns(self, backwards: bool = False): async def _run_add_tables(self, backwards: bool = False): table_classes: t.List[t.Type[Table]] = [] for add_table in self.add_tables: - add_columns: t.List[ - AddColumnClass - ] = self.add_columns.for_table_class_name(add_table.class_name) + add_columns: t.List[AddColumnClass] = ( + self.add_columns.for_table_class_name(add_table.class_name) + ) _Table: t.Type[Table] = create_table_class( class_name=add_table.class_name, class_kwargs={ @@ -792,9 +792,9 @@ async def _run_add_columns(self, backwards: bool = False): if table_class_name in [i.class_name for i in self.add_tables]: continue # No need to add columns to new tables - add_columns: t.List[ - AddColumnClass - ] = self.add_columns.for_table_class_name(table_class_name) + add_columns: t.List[AddColumnClass] = ( + self.add_columns.for_table_class_name(table_class_name) + ) ############################################################### # Define the table, with the columns, so the metaclass @@ -838,9 +838,9 @@ async def _run_add_columns(self, backwards: bool = False): else: primary_key = existing_table._meta.primary_key - table_class_members[ - primary_key._meta.name - ] = primary_key + table_class_members[primary_key._meta.name] = ( + primary_key + ) break diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index d1fd5ee4..b3644b85 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -25,8 +25,7 @@ class CanConflictWithGlobalNames(abc.ABC): @abc.abstractmethod - def warn_if_is_conflicting_with_global_name(self): - ... + def warn_if_is_conflicting_with_global_name(self): ... class UniqueGlobalNamesMeta(type): @@ -237,8 +236,7 @@ def warn_if_is_conflicting_with_global_name(self): class Definition(CanConflictWithGlobalNames, abc.ABC): @abc.abstractmethod - def __repr__(self): - ... + def __repr__(self): ... ########################################################################### # To allow sorting: diff --git a/piccolo/apps/migrations/commands/backwards.py b/piccolo/apps/migrations/commands/backwards.py index a0a454d9..c84e5455 100644 --- a/piccolo/apps/migrations/commands/backwards.py +++ b/piccolo/apps/migrations/commands/backwards.py @@ -31,9 +31,9 @@ def __init__( super().__init__() async def run_migrations_backwards(self, app_config: AppConfig): - migration_modules: t.Dict[ - str, MigrationModule - ] = self.get_migration_modules(app_config.migrations_folder_path) + migration_modules: t.Dict[str, MigrationModule] = ( + self.get_migration_modules(app_config.migrations_folder_path) + ) ran_migration_ids = await Migration.get_migrations_which_ran( app_name=self.app_name diff --git a/piccolo/apps/migrations/commands/base.py b/piccolo/apps/migrations/commands/base.py index 3b4cee10..a3966f7c 100644 --- a/piccolo/apps/migrations/commands/base.py +++ b/piccolo/apps/migrations/commands/base.py @@ -88,9 +88,9 @@ async def get_migration_managers( migrations_folder = app_config.migrations_folder_path - migration_modules: t.Dict[ - str, MigrationModule - ] = self.get_migration_modules(migrations_folder) + migration_modules: t.Dict[str, MigrationModule] = ( + self.get_migration_modules(migrations_folder) + ) migration_ids = sorted(migration_modules.keys()) diff --git a/piccolo/apps/migrations/commands/forwards.py b/piccolo/apps/migrations/commands/forwards.py index f060b493..6d967dd5 100644 --- a/piccolo/apps/migrations/commands/forwards.py +++ b/piccolo/apps/migrations/commands/forwards.py @@ -32,9 +32,9 @@ async def run_migrations(self, app_config: AppConfig) -> MigrationResult: app_name=app_config.app_name ) - migration_modules: t.Dict[ - str, MigrationModule - ] = self.get_migration_modules(app_config.migrations_folder_path) + migration_modules: t.Dict[str, MigrationModule] = ( + self.get_migration_modules(app_config.migrations_folder_path) + ) ids = self.get_migration_ids(migration_modules) n = len(ids) diff --git a/piccolo/apps/playground/commands/run.py b/piccolo/apps/playground/commands/run.py index 32f840d5..b4bc23f7 100644 --- a/piccolo/apps/playground/commands/run.py +++ b/piccolo/apps/playground/commands/run.py @@ -2,6 +2,7 @@ Populates a database with an example schema and data, and launches a shell for interacting with the data using Piccolo. """ + import datetime import sys import typing as t diff --git a/piccolo/apps/user/tables.py b/piccolo/apps/user/tables.py index 878f0670..a9a38910 100644 --- a/piccolo/apps/user/tables.py +++ b/piccolo/apps/user/tables.py @@ -1,6 +1,7 @@ """ A User model, used for authentication. """ + from __future__ import annotations import datetime diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index d477dc99..886a0ee4 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -887,9 +887,11 @@ def get_sql_value(self, value: t.Any) -> t.Any: return ( "'{" + ", ".join( - f'"{i}"' - if isinstance(i, str) - else str(self.get_sql_value(i)) + ( + f'"{i}"' + if isinstance(i, str) + else str(self.get_sql_value(i)) + ) for i in value ) ) + "}'" diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 05353eab..68e6e0ce 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -67,7 +67,7 @@ class Band(Table): try: from zoneinfo import ZoneInfo except ImportError: - from backports.zoneinfo import ZoneInfo + from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns.base import ColumnMeta @@ -355,12 +355,10 @@ def __radd__(self, value: t.Union[str, Varchar, Text]) -> QueryString: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> str: - ... + def __get__(self, obj: Table, objtype=None) -> str: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Varchar: - ... + def __get__(self, obj: None, objtype=None) -> Varchar: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -394,12 +392,10 @@ def __init__(self, *args, **kwargs): # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> str: - ... + def __get__(self, obj: Table, objtype=None) -> str: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Secret: - ... + def __get__(self, obj: None, objtype=None) -> Secret: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -461,12 +457,10 @@ def __radd__(self, value: t.Union[str, Varchar, Text]) -> QueryString: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> str: - ... + def __get__(self, obj: Table, objtype=None) -> str: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Text: - ... + def __get__(self, obj: None, objtype=None) -> Text: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -526,12 +520,10 @@ def __init__(self, default: UUIDArg = UUID4(), **kwargs) -> None: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> uuid.UUID: - ... + def __get__(self, obj: Table, objtype=None) -> uuid.UUID: ... @t.overload - def __get__(self, obj: None, objtype=None) -> UUID: - ... + def __get__(self, obj: None, objtype=None) -> UUID: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -646,12 +638,10 @@ def __rfloordiv__( # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> int: - ... + def __get__(self, obj: Table, objtype=None) -> int: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Integer: - ... + def __get__(self, obj: None, objtype=None) -> Integer: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -704,12 +694,10 @@ def column_type(self): # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> int: - ... + def __get__(self, obj: Table, objtype=None) -> int: ... @t.overload - def __get__(self, obj: None, objtype=None) -> BigInt: - ... + def __get__(self, obj: None, objtype=None) -> BigInt: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -754,12 +742,10 @@ def column_type(self): # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> int: - ... + def __get__(self, obj: Table, objtype=None) -> int: ... @t.overload - def __get__(self, obj: None, objtype=None) -> SmallInt: - ... + def __get__(self, obj: None, objtype=None) -> SmallInt: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -806,12 +792,10 @@ def default(self): # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> int: - ... + def __get__(self, obj: Table, objtype=None) -> int: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Serial: - ... + def __get__(self, obj: None, objtype=None) -> Serial: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -840,12 +824,10 @@ def column_type(self): # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> int: - ... + def __get__(self, obj: Table, objtype=None) -> int: ... @t.overload - def __get__(self, obj: None, objtype=None) -> BigSerial: - ... + def __get__(self, obj: None, objtype=None) -> BigSerial: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -875,12 +857,10 @@ def __init__(self, **kwargs) -> None: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> int: - ... + def __get__(self, obj: Table, objtype=None) -> int: ... @t.overload - def __get__(self, obj: None, objtype=None) -> PrimaryKey: - ... + def __get__(self, obj: None, objtype=None) -> PrimaryKey: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -965,12 +945,10 @@ def __sub__(self, value: timedelta) -> QueryString: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> datetime: - ... + def __get__(self, obj: Table, objtype=None) -> datetime: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Timestamp: - ... + def __get__(self, obj: None, objtype=None) -> Timestamp: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -982,7 +960,7 @@ def __set__(self, obj, value: t.Union[datetime, None]): class Timestamptz(Column): """ Used for storing timezone aware datetimes. Uses the ``datetime`` type for - values. The values are converted to UTC when saved into the database and + values. The values are converted to UTC when saved into the database and are converted back into the timezone of the column on select queries. **Example** @@ -1025,9 +1003,9 @@ class TallinnConcerts(Table): def __init__( self, - tz: ZoneInfo = ZoneInfo('UTC'), - default: TimestamptzArg = TimestamptzNow(), - **kwargs + tz: ZoneInfo = ZoneInfo("UTC"), + default: TimestamptzArg = TimestamptzNow(), + **kwargs, ) -> None: self._validate_default( default, TimestamptzArg.__args__ # type: ignore @@ -1070,12 +1048,10 @@ def __sub__(self, value: timedelta) -> QueryString: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> datetime: - ... + def __get__(self, obj: Table, objtype=None) -> datetime: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Timestamptz: - ... + def __get__(self, obj: None, objtype=None) -> Timestamptz: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -1150,12 +1126,10 @@ def __sub__(self, value: timedelta) -> QueryString: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> date: - ... + def __get__(self, obj: Table, objtype=None) -> date: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Date: - ... + def __get__(self, obj: None, objtype=None) -> Date: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -1227,12 +1201,10 @@ def __sub__(self, value: timedelta) -> QueryString: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> time: - ... + def __get__(self, obj: Table, objtype=None) -> time: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Time: - ... + def __get__(self, obj: None, objtype=None) -> Time: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -1318,12 +1290,10 @@ def __sub__(self, value: timedelta) -> QueryString: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> timedelta: - ... + def __get__(self, obj: Table, objtype=None) -> timedelta: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Interval: - ... + def __get__(self, obj: None, objtype=None) -> Interval: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -1412,12 +1382,10 @@ def ne(self, value) -> Where: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> bool: - ... + def __get__(self, obj: Table, objtype=None) -> bool: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Boolean: - ... + def __get__(self, obj: None, objtype=None) -> Boolean: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -1515,12 +1483,10 @@ def __init__( # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> decimal.Decimal: - ... + def __get__(self, obj: Table, objtype=None) -> decimal.Decimal: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Numeric: - ... + def __get__(self, obj: None, objtype=None) -> Numeric: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -1538,12 +1504,10 @@ class Decimal(Numeric): # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> decimal.Decimal: - ... + def __get__(self, obj: Table, objtype=None) -> decimal.Decimal: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Decimal: - ... + def __get__(self, obj: None, objtype=None) -> Decimal: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -1589,12 +1553,10 @@ def __init__( # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> float: - ... + def __get__(self, obj: Table, objtype=None) -> float: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Real: - ... + def __get__(self, obj: None, objtype=None) -> Real: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -1612,12 +1574,10 @@ class Float(Real): # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> float: - ... + def __get__(self, obj: Table, objtype=None) -> float: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Float: - ... + def __get__(self, obj: None, objtype=None) -> Float: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -1639,12 +1599,10 @@ def column_type(self): # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> float: - ... + def __get__(self, obj: Table, objtype=None) -> float: ... @t.overload - def __get__(self, obj: None, objtype=None) -> DoublePrecision: - ... + def __get__(self, obj: None, objtype=None) -> DoublePrecision: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -1884,8 +1842,7 @@ def __init__( on_update: OnUpdate = OnUpdate.cascade, target_column: t.Union[str, Column, None] = None, **kwargs, - ) -> None: - ... + ) -> None: ... @t.overload def __init__( @@ -1897,8 +1854,7 @@ def __init__( on_update: OnUpdate = OnUpdate.cascade, target_column: t.Union[str, Column, None] = None, **kwargs, - ) -> None: - ... + ) -> None: ... @t.overload def __init__( @@ -1910,8 +1866,7 @@ def __init__( on_update: OnUpdate = OnUpdate.cascade, target_column: t.Union[str, Column, None] = None, **kwargs, - ) -> None: - ... + ) -> None: ... def __init__( self, @@ -2261,16 +2216,15 @@ def __getattribute__(self, name: str) -> t.Union[Column, t.Any]: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> t.Any: - ... + def __get__(self, obj: Table, objtype=None) -> t.Any: ... @t.overload - def __get__(self, obj: None, objtype=None) -> ForeignKey[ReferencedTable]: - ... + def __get__( + self, obj: None, objtype=None + ) -> ForeignKey[ReferencedTable]: ... @t.overload - def __get__(self, obj: t.Any, objtype=None) -> t.Any: - ... + def __get__(self, obj: t.Any, objtype=None) -> t.Any: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -2330,12 +2284,10 @@ def column_type(self): # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> str: - ... + def __get__(self, obj: Table, objtype=None) -> str: ... @t.overload - def __get__(self, obj: None, objtype=None) -> JSON: - ... + def __get__(self, obj: None, objtype=None) -> JSON: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -2400,12 +2352,10 @@ def ne(self, value) -> Where: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> str: - ... + def __get__(self, obj: Table, objtype=None) -> str: ... @t.overload - def __get__(self, obj: None, objtype=None) -> JSONB: - ... + def __get__(self, obj: None, objtype=None) -> JSONB: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -2473,12 +2423,10 @@ def __init__( # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> bytes: - ... + def __get__(self, obj: Table, objtype=None) -> bytes: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Bytea: - ... + def __get__(self, obj: None, objtype=None) -> Bytea: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -2496,12 +2444,10 @@ class Blob(Bytea): # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> bytes: - ... + def __get__(self, obj: Table, objtype=None) -> bytes: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Blob: - ... + def __get__(self, obj: None, objtype=None) -> Blob: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self @@ -2752,12 +2698,10 @@ def __add__(self, value: t.List[t.Any]) -> QueryString: # Descriptors @t.overload - def __get__(self, obj: Table, objtype=None) -> t.List[t.Any]: - ... + def __get__(self, obj: Table, objtype=None) -> t.List[t.Any]: ... @t.overload - def __get__(self, obj: None, objtype=None) -> Array: - ... + def __get__(self, obj: None, objtype=None) -> Array: ... def __get__(self, obj, objtype=None): return obj.__dict__[self._meta.name] if obj else self diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index c498594d..a055c032 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -1,34 +1,31 @@ from __future__ import annotations -import datetime +import datetime as pydatetime import typing as t from enum import Enum try: from zoneinfo import ZoneInfo except ImportError: - from backports.zoneinfo import ZoneInfo + from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 from .timestamp import TimestampCustom, TimestampNow, TimestampOffset class TimestamptzOffset(TimestampOffset): def __init__( - self, - days: int = 0, - hours: int = 0, - minutes: int = 0, + self, + days: int = 0, + hours: int = 0, + minutes: int = 0, seconds: int = 0, - tz: ZoneInfo = ZoneInfo('UTC') + tz: ZoneInfo = ZoneInfo("UTC"), ): self.tz = tz super().__init__( - days=days, - hours=hours, - minutes=minutes, - seconds=seconds + days=days, hours=hours, minutes=minutes, seconds=seconds ) - + @property def cockroach(self): interval_string = self.get_postgres_interval_string( @@ -37,9 +34,7 @@ def cockroach(self): return f"CURRENT_TIMESTAMP + INTERVAL '{interval_string}'" def python(self): - return datetime.datetime.now( - tz=self.tz - ) + datetime.timedelta( + return pydatetime.datetime.now(tz=self.tz) + pydatetime.timedelta( days=self.days, hours=self.hours, minutes=self.minutes, @@ -48,15 +43,15 @@ def python(self): class TimestamptzNow(TimestampNow): - def __init__(self, tz: ZoneInfo = ZoneInfo('UTC')): + def __init__(self, tz: ZoneInfo = ZoneInfo("UTC")): self.tz = tz - + @property def cockroach(self): return "current_timestamp" def python(self): - return datetime.datetime.now(tz=self.tz) + return pydatetime.datetime.now(tz=self.tz) class TimestamptzCustom(TimestampCustom): @@ -68,7 +63,7 @@ def __init__( hour: int = 0, second: int = 0, microsecond: int = 0, - tz: ZoneInfo = ZoneInfo('UTC') + tz: ZoneInfo = ZoneInfo("UTC"), ): self.tz = tz super().__init__( @@ -77,16 +72,16 @@ def __init__( day=day, hour=hour, second=second, - microsecond=microsecond + microsecond=microsecond, ) - + @property def cockroach(self): return "'{}'".format(self.datetime.isoformat().replace("T", " ")) @property def datetime(self): - return datetime.datetime( + return pydatetime.datetime( year=self.year, month=self.month, day=self.day, @@ -97,7 +92,9 @@ def datetime(self): ) @classmethod - def from_datetime(cls, instance: datetime.datetime, tz: ZoneInfo = ZoneInfo('UTC')): # type: ignore + def from_datetime( + cls, instance: pydatetime.datetime, tz: ZoneInfo = ZoneInfo("UTC") + ): # type: ignore if instance.tzinfo is not None: instance = instance.astimezone(tz) return cls( @@ -116,7 +113,7 @@ def from_datetime(cls, instance: datetime.datetime, tz: ZoneInfo = ZoneInfo('UTC TimestamptzOffset, Enum, None, - datetime.datetime, + pydatetime.datetime, ] diff --git a/piccolo/columns/reference.py b/piccolo/columns/reference.py index 841545ee..f6edcdd5 100644 --- a/piccolo/columns/reference.py +++ b/piccolo/columns/reference.py @@ -1,6 +1,7 @@ """ Dataclasses for storing lazy references between ForeignKey columns and tables. """ + from __future__ import annotations import importlib diff --git a/piccolo/conf/apps.py b/piccolo/conf/apps.py index c311e156..47631c47 100644 --- a/piccolo/conf/apps.py +++ b/piccolo/conf/apps.py @@ -25,8 +25,7 @@ class MigrationModule(ModuleType): @staticmethod @abstractmethod - async def forwards() -> MigrationManager: - ... + async def forwards() -> MigrationManager: ... class PiccoloAppModule(ModuleType): diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index 5b1c9600..7b8c3ad4 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -124,10 +124,10 @@ async def run( results = objects[0] if objects else None - modified_response: t.Optional[ - TableInstance - ] = await self.query.callback_delegate.invoke( - results=results, kind=CallbackType.success + modified_response: t.Optional[TableInstance] = ( + await self.query.callback_delegate.invoke( + results=results, kind=CallbackType.success + ) ) return modified_response @@ -355,10 +355,10 @@ async def run( # With callbacks, the user can return any data that they want. # Assume that most of the time they will still return a list of # Table instances. - modified: t.List[ - TableInstance - ] = await self.callback_delegate.invoke( - results, kind=CallbackType.success + modified: t.List[TableInstance] = ( + await self.callback_delegate.invoke( + results, kind=CallbackType.success + ) ) return modified else: diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index a00745e4..a2a77b15 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -604,20 +604,16 @@ def order_by( return self @t.overload - def output(self: Self, *, as_list: bool) -> SelectList: - ... + def output(self: Self, *, as_list: bool) -> SelectList: ... @t.overload - def output(self: Self, *, as_json: bool) -> SelectJSON: - ... + def output(self: Self, *, as_json: bool) -> SelectJSON: ... @t.overload - def output(self: Self, *, load_json: bool) -> Self: - ... + def output(self: Self, *, load_json: bool) -> Self: ... @t.overload - def output(self: Self, *, nested: bool) -> Self: - ... + def output(self: Self, *, nested: bool) -> Self: ... def output( self: Self, diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py index 56be35a8..8d7c6a4a 100644 --- a/piccolo/query/mixins.py +++ b/piccolo/query/mixins.py @@ -639,9 +639,9 @@ class OnConflictAction(str, Enum): class OnConflictItem: target: t.Optional[t.Union[str, Column, t.Tuple[Column, ...]]] = None action: t.Optional[OnConflictAction] = None - values: t.Optional[ - t.Sequence[t.Union[Column, t.Tuple[Column, t.Any]]] - ] = None + values: t.Optional[t.Sequence[t.Union[Column, t.Tuple[Column, t.Any]]]] = ( + None + ) where: t.Optional[Combinable] = None @property diff --git a/piccolo/query/proxy.py b/piccolo/query/proxy.py index 7ded47b8..30ce0608 100644 --- a/piccolo/query/proxy.py +++ b/piccolo/query/proxy.py @@ -8,8 +8,9 @@ class Runnable(Protocol): - async def run(self, node: t.Optional[str] = None, in_pool: bool = True): - ... + async def run( + self, node: t.Optional[str] = None, in_pool: bool = True + ): ... QueryType = t.TypeVar("QueryType", bound=Runnable) diff --git a/piccolo/table.py b/piccolo/table.py index ec76ee24..3772b9ab 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -60,7 +60,7 @@ try: from zoneinfo import ZoneInfo except ImportError: - from backports.zoneinfo import ZoneInfo + from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns import Selectable @@ -445,7 +445,7 @@ def __init__( if isinstance(column, Timestamptz) and isinstance(value, datetime): value = value.astimezone(column.tz) - + self[column._meta.name] = value unrecognized = kwargs.keys() @@ -576,12 +576,10 @@ def refresh( @t.overload def get_related( self, foreign_key: ForeignKey[ReferencedTable] - ) -> First[ReferencedTable]: - ... + ) -> First[ReferencedTable]: ... @t.overload - def get_related(self, foreign_key: str) -> First[Table]: - ... + def get_related(self, foreign_key: str) -> First[Table]: ... def get_related( self, foreign_key: t.Union[str, ForeignKey[ReferencedTable]] @@ -755,9 +753,9 @@ def to_dict(self, *columns: Column) -> t.Dict[str, t.Any]: if isinstance(value, Table): value = value.to_dict(*columns) - output[ - alias_names.get(column._meta.name) or column._meta.name - ] = value + output[alias_names.get(column._meta.name) or column._meta.name] = ( + value + ) return output def __setitem__(self, key: str, value: t.Any): diff --git a/tests/columns/test_reference.py b/tests/columns/test_reference.py index ea9887e0..21daa2f5 100644 --- a/tests/columns/test_reference.py +++ b/tests/columns/test_reference.py @@ -2,6 +2,7 @@ Most of the tests for piccolo/columns/reference.py are covered in piccolo/columns/test_foreignkey.py """ + from unittest import TestCase from piccolo.columns.reference import LazyTableReference diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index 8e239900..ee690938 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -1,8 +1,7 @@ import datetime +from operator import eq from unittest import TestCase -from dateutil import tz - from piccolo.columns.column_types import Timestamptz from piccolo.columns.defaults.timestamptz import ( TimestamptzCustom, @@ -11,9 +10,19 @@ ) from piccolo.table import Table +try: + from zoneinfo import ZoneInfo +except ImportError: + from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 + + +UTC_TZ = ZoneInfo("UTC") +LOCAL_TZ = ZoneInfo("Europe/Tallinn") + class MyTable(Table): - created_on = Timestamptz() + created_on_utc = Timestamptz(tz=UTC_TZ) + created_on_local = Timestamptz(tz=LOCAL_TZ) class MyTableDefault(Table): @@ -22,18 +31,19 @@ class MyTableDefault(Table): `Timestamptz`. """ - created_on = Timestamptz(default=TimestamptzNow()) - created_on_offset = Timestamptz(default=TimestamptzOffset(days=1)) - created_on_custom = Timestamptz(default=TimestamptzCustom(year=2021)) + created_on = Timestamptz(default=TimestamptzNow(tz=LOCAL_TZ), tz=LOCAL_TZ) + created_on_offset = Timestamptz( + default=TimestamptzOffset(days=1, tz=LOCAL_TZ), tz=LOCAL_TZ + ) + created_on_custom = Timestamptz( + default=TimestamptzCustom(year=2021, tz=LOCAL_TZ), tz=LOCAL_TZ + ) created_on_datetime = Timestamptz( - default=datetime.datetime(year=2020, month=1, day=1) + default=datetime.datetime(year=2020, month=1, day=1, tzinfo=LOCAL_TZ), + tz=LOCAL_TZ, ) -class CustomTimezone(datetime.tzinfo): - pass - - class TestTimestamptz(TestCase): def setUp(self): MyTable.create_table().run_sync() @@ -45,37 +55,32 @@ def test_timestamptz_timezone_aware(self): """ Test storing a timezone aware timestamp. """ - for tzinfo in ( - datetime.timezone.utc, - tz.gettz("America/New_York"), - ): - created_on = datetime.datetime( - year=2020, - month=1, - day=1, - hour=12, - minute=0, - second=0, - tzinfo=tzinfo, - ) - row = MyTable(created_on=created_on) - row.save().run_sync() - - # Fetch it back from the database - result = ( - MyTable.objects() - .where( - MyTable._meta.primary_key - == getattr(row, MyTable._meta.primary_key._meta.name) - ) - .first() - .run_sync() - ) - assert result is not None - self.assertEqual(result.created_on, created_on) - - # The database converts it to UTC - self.assertEqual(result.created_on.tzinfo, datetime.timezone.utc) + dt_args = dict(year=2020, month=1, day=1, hour=12, minute=0, second=0) + created_on_utc = datetime.datetime(**dt_args, tzinfo=ZoneInfo("UTC")) + created_on_local = datetime.datetime( + **dt_args, tzinfo=ZoneInfo("Europe/Tallinn") + ) + row = MyTable( + created_on_utc=created_on_utc, created_on_local=created_on_local + ) + row.save().run_sync() + + # Fetch it back from the database + p_key = MyTable._meta.primary_key + p_key_name = getattr(row, p_key._meta.name) + result = ( + MyTable.objects().where(eq(p_key, p_key_name)).first().run_sync() + ) + assert result is not None + self.assertEqual(result.created_on_utc, created_on_utc) + self.assertEqual(result.created_on_local, created_on_local) + + # The database stores the datetime of the column in UTC timezone, but + # the column converts it back to the timezone that is defined for it. + self.assertEqual(result.created_on_utc.tzinfo, created_on_utc.tzinfo) + self.assertEqual( + result.created_on_local.tzinfo, created_on_local.tzinfo + ) class TestTimestamptzDefault(TestCase): @@ -89,12 +94,25 @@ def test_timestamptz_default(self): """ Make sure the default value gets created, and can be retrieved. """ - created_on = datetime.datetime.now(tz=datetime.timezone.utc) + created_on = datetime.datetime.now(tz=LOCAL_TZ) row = MyTableDefault() row.save().run_sync() result = MyTableDefault.objects().first().run_sync() assert result is not None + delta = result.created_on - created_on self.assertLess(delta, datetime.timedelta(seconds=1)) - self.assertEqual(result.created_on.tzinfo, datetime.timezone.utc) + self.assertEqual(result.created_on.tzinfo, created_on.tzinfo) + + delta = result.created_on_offset - created_on + self.assertLessEqual(delta, datetime.timedelta(days=1)) + self.assertEqual(result.created_on_offset.tzinfo, created_on.tzinfo) + + delta = created_on - result.created_on_custom + self.assertGreaterEqual(delta, datetime.timedelta(days=delta.days)) + self.assertEqual(result.created_on_custom.tzinfo, created_on.tzinfo) + + delta = created_on - result.created_on_datetime + self.assertGreaterEqual(delta, datetime.timedelta(days=delta.days)) + self.assertEqual(result.created_on_datetime.tzinfo, created_on.tzinfo) diff --git a/tests/conf/example.py b/tests/conf/example.py index c7dce05a..ef454162 100644 --- a/tests/conf/example.py +++ b/tests/conf/example.py @@ -2,6 +2,7 @@ This file is used by test_apps.py to make sure we can exclude imported ``Table`` subclasses when using ``table_finder``. """ + from piccolo.apps.user.tables import BaseUser from piccolo.columns.column_types import ForeignKey, Varchar from piccolo.table import Table From 41a6d0c59ea3a4936c0e6ae14e81acc3ad03abd8 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets Date: Thu, 28 Mar 2024 17:47:01 +0200 Subject: [PATCH 14/32] Fixed zoneinfo module import for autogenerated migrations files --- piccolo/apps/migrations/auto/serialisation.py | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index b3644b85..9791d2f9 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -12,12 +12,22 @@ from dataclasses import dataclass, field from enum import Enum -from piccolo.columns import Column +from piccolo.columns import Column, Timestamptz from piccolo.columns.defaults.base import Default +from piccolo.columns.defaults.timestamptz import ( + TimestamptzCustom, + TimestamptzNow, + TimestamptzOffset, +) from piccolo.columns.reference import LazyTableReference from piccolo.table import Table from piccolo.utils.repr import repr_class_instance +try: + from zoneinfo import ZoneInfo +except ImportError: + from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 + from .serialisation_legacy import deserialise_legacy_params ############################################################################### @@ -546,8 +556,29 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: expect_conflict_with_global_name=UniqueGlobalNames.DEFAULT, ) ) + # ZoneInfo for Timestamptz* instances + in_group = ( + Timestamptz, TimestamptzNow, + TimestamptzCustom, TimestamptzOffset + ) + if isinstance(value, in_group): + extra_imports.append( + Import( + module=ZoneInfo.__module__, + target=None, + ) + ) continue + # ZoneInfo instances + if isinstance(value, ZoneInfo): + extra_imports.append( + Import( + module=value.__class__.__module__, + target=None, + ) + ) + # Dates and times if isinstance( value, (datetime.time, datetime.datetime, datetime.date) From ad0d1d15d9030e8b034639fa156f3373c1c3e74c Mon Sep 17 00:00:00 2001 From: Mattias Aabmets Date: Thu, 28 Mar 2024 20:53:01 +0200 Subject: [PATCH 15/32] I swear to god if this commit doesn't fix the linting and test errors --- piccolo/apps/migrations/auto/serialisation.py | 6 ++++-- tests/columns/test_timestamptz.py | 5 ++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index 9791d2f9..a06cf6bf 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -558,8 +558,10 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: ) # ZoneInfo for Timestamptz* instances in_group = ( - Timestamptz, TimestamptzNow, - TimestamptzCustom, TimestamptzOffset + Timestamptz, + TimestamptzNow, + TimestamptzCustom, + TimestamptzOffset, ) if isinstance(value, in_group): extra_imports.append( diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index ee690938..3e36b0cf 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -1,4 +1,5 @@ import datetime +import time from operator import eq from unittest import TestCase @@ -95,6 +96,8 @@ def test_timestamptz_default(self): Make sure the default value gets created, and can be retrieved. """ created_on = datetime.datetime.now(tz=LOCAL_TZ) + time.sleep(1e-5) + row = MyTableDefault() row.save().run_sync() @@ -106,7 +109,7 @@ def test_timestamptz_default(self): self.assertEqual(result.created_on.tzinfo, created_on.tzinfo) delta = result.created_on_offset - created_on - self.assertLessEqual(delta, datetime.timedelta(days=1)) + self.assertGreaterEqual(delta, datetime.timedelta(days=1)) self.assertEqual(result.created_on_offset.tzinfo, created_on.tzinfo) delta = created_on - result.created_on_custom From c816d8ec62b63ac388d8b5490b1843288d5dc6df Mon Sep 17 00:00:00 2001 From: Mattias Aabmets Date: Thu, 28 Mar 2024 21:56:23 +0200 Subject: [PATCH 16/32] Added more ZoneInfo import ignore rules for MyPy --- piccolo/apps/migrations/auto/serialisation.py | 2 +- piccolo/columns/column_types.py | 2 +- piccolo/columns/defaults/timestamptz.py | 2 +- piccolo/table.py | 2 +- tests/columns/test_timestamptz.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index a06cf6bf..81ee2380 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -24,7 +24,7 @@ from piccolo.utils.repr import repr_class_instance try: - from zoneinfo import ZoneInfo + from zoneinfo import ZoneInfo # type: ignore except ImportError: from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 68e6e0ce..428cb3c4 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -65,7 +65,7 @@ class Band(Table): from piccolo.utils.warnings import colored_warning try: - from zoneinfo import ZoneInfo + from zoneinfo import ZoneInfo # type: ignore except ImportError: from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index a055c032..ffb04ec5 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -5,7 +5,7 @@ from enum import Enum try: - from zoneinfo import ZoneInfo + from zoneinfo import ZoneInfo # type: ignore except ImportError: from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 diff --git a/piccolo/table.py b/piccolo/table.py index 3772b9ab..00d5c950 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -58,7 +58,7 @@ from piccolo.utils.warnings import colored_warning try: - from zoneinfo import ZoneInfo + from zoneinfo import ZoneInfo # type: ignore except ImportError: from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index 3e36b0cf..163919e7 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -12,7 +12,7 @@ from piccolo.table import Table try: - from zoneinfo import ZoneInfo + from zoneinfo import ZoneInfo # type: ignore except ImportError: from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 From cbf6d41b64a36b3c6c7e60430c2130ef88baac93 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets Date: Fri, 29 Mar 2024 00:04:26 +0200 Subject: [PATCH 17/32] Added pragma no cover to ZoneInfo import except clauses --- piccolo/apps/migrations/auto/serialisation.py | 2 +- piccolo/columns/column_types.py | 2 +- piccolo/columns/defaults/timestamptz.py | 2 +- piccolo/table.py | 2 +- tests/columns/test_timestamptz.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index 81ee2380..83ed4d22 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -25,7 +25,7 @@ try: from zoneinfo import ZoneInfo # type: ignore -except ImportError: +except ImportError: # pragma: no cover from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 from .serialisation_legacy import deserialise_legacy_params diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 428cb3c4..34ebb3f6 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -66,7 +66,7 @@ class Band(Table): try: from zoneinfo import ZoneInfo # type: ignore -except ImportError: +except ImportError: # pragma: no cover from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 if t.TYPE_CHECKING: # pragma: no cover diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index ffb04ec5..90ba0fa9 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -6,7 +6,7 @@ try: from zoneinfo import ZoneInfo # type: ignore -except ImportError: +except ImportError: # pragma: no cover from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 from .timestamp import TimestampCustom, TimestampNow, TimestampOffset diff --git a/piccolo/table.py b/piccolo/table.py index 00d5c950..c8e5fd47 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -59,7 +59,7 @@ try: from zoneinfo import ZoneInfo # type: ignore -except ImportError: +except ImportError: # pragma: no cover from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 if t.TYPE_CHECKING: # pragma: no cover diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index 163919e7..7d8ee3da 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -13,7 +13,7 @@ try: from zoneinfo import ZoneInfo # type: ignore -except ImportError: +except ImportError: # pragma: no cover from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 From db539394018bfdb1e03b2c67e295cabe61a03cf6 Mon Sep 17 00:00:00 2001 From: Mattias Aabmets Date: Fri, 29 Mar 2024 10:14:05 +0200 Subject: [PATCH 18/32] Removed ZoneInfo import from table.py --- piccolo/table.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/piccolo/table.py b/piccolo/table.py index c8e5fd47..2dc3e1ca 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -57,11 +57,6 @@ from piccolo.utils.sync import run_sync from piccolo.utils.warnings import colored_warning -try: - from zoneinfo import ZoneInfo # type: ignore -except ImportError: # pragma: no cover - from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 - if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns import Selectable From 36b0da6f5f6fb7147a159d9dbbcec69aab8de21a Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 14:06:37 +0100 Subject: [PATCH 19/32] move `tz` after `default` for better backwards compatibility --- piccolo/columns/column_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 770b88c5..990fc4e7 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -1003,8 +1003,8 @@ class TallinnConcerts(Table): def __init__( self, - tz: ZoneInfo = ZoneInfo("UTC"), default: TimestamptzArg = TimestamptzNow(), + tz: ZoneInfo = ZoneInfo("UTC"), **kwargs, ) -> None: self._validate_default( From 9fda29d2c6e8032c7133937d35dc5429d80d1c95 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 14:07:44 +0100 Subject: [PATCH 20/32] drop `tz_type` - I don't think we need it for now --- piccolo/columns/column_types.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 990fc4e7..0eaf9d8a 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -993,7 +993,6 @@ class TallinnConcerts(Table): """ value_type = datetime - tz_type = ZoneInfo # Currently just used by ModelBuilder, to know that we want a timezone # aware datetime. @@ -1019,7 +1018,7 @@ def __init__( self.tz = tz self.default = default - kwargs.update({"tz": tz, "default": default}) + kwargs.update({"default": default, "tz": tz}) super().__init__(**kwargs) ########################################################################### From 3733218898ab57e99c5fad225e67dcf1ee8c8df1 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 14:15:13 +0100 Subject: [PATCH 21/32] add missing continues --- piccolo/apps/migrations/auto/serialisation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index 83ed4d22..87286408 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -580,6 +580,7 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: target=None, ) ) + continue # Dates and times if isinstance( @@ -666,6 +667,7 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams: extra_imports.append( Import(module=module_name, target=type_.__name__) ) + continue # Functions if inspect.isfunction(value): From 7c6ce55a57e165075cd6e3da9656e900ff78a5e7 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 15:14:05 +0100 Subject: [PATCH 22/32] add `at_timezone` clause --- piccolo/columns/base.py | 7 +++++-- piccolo/columns/column_types.py | 30 +++++++++++++++++++++++++++++- piccolo/table.py | 3 --- 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index 886a0ee4..016b0490 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -945,8 +945,8 @@ def ddl(self) -> str: return query - def copy(self) -> Column: - column: Column = copy.copy(self) + def copy(self: Self) -> Self: + column = copy.copy(self) column._meta = self._meta.copy() return column @@ -971,3 +971,6 @@ def __repr__(self): f"{table_class_name}.{self._meta.name} - " f"{self.__class__.__name__}" ) + + +Self = t.TypeVar("Self", bound=Column) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 0eaf9d8a..66b2a6af 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -1021,6 +1021,34 @@ def __init__( kwargs.update({"default": default, "tz": tz}) super().__init__(**kwargs) + ########################################################################### + + def at_time_zone(self, tz: t.Union[ZoneInfo, str]) -> Timestamptz: + """ + By default, the database returns the value in UTC. This lets us get + the value converted to the specified timezone. + """ + tz = ZoneInfo(tz) if isinstance(tz, str) else tz + instance = self.copy() + instance.tz = tz + return instance + + ########################################################################### + + def get_select_string( + self, engine_type: str, with_alias: bool = True + ) -> str: + select_string = self._meta.get_full_name(with_alias=False) + + if self.tz is not None: + select_string += f" AT TIME ZONE '{self.tz.key}'" + + if with_alias: + alias = self._alias or self._meta.get_default_alias() + select_string += f' AS "{alias}"' + + return select_string + ########################################################################### # For update queries @@ -2317,7 +2345,7 @@ def arrow(self, key: str) -> JSONB: Allows part of the JSON structure to be returned - for example, for {"a": 1}, and a key value of "a", then 1 will be returned. """ - instance = t.cast(JSONB, self.copy()) + instance = self.copy() instance.json_operator = f"-> '{key}'" return instance diff --git a/piccolo/table.py b/piccolo/table.py index 2dc3e1ca..b44d6469 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -438,9 +438,6 @@ def __init__( ): raise ValueError(f"{column._meta.name} wasn't provided") - if isinstance(column, Timestamptz) and isinstance(value, datetime): - value = value.astimezone(column.tz) - self[column._meta.name] = value unrecognized = kwargs.keys() From 755b059ab81484b892e6435fa176b3a9cf59b65d Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 15:19:26 +0100 Subject: [PATCH 23/32] remove unused imports --- piccolo/table.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/piccolo/table.py b/piccolo/table.py index b44d6469..b4fcbf94 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -6,7 +6,6 @@ import typing as t import warnings from dataclasses import dataclass, field -from datetime import datetime from piccolo.columns import Column from piccolo.columns.column_types import ( @@ -18,7 +17,6 @@ ReferencedTable, Secret, Serial, - Timestamptz, ) from piccolo.columns.defaults.base import Default from piccolo.columns.indexes import IndexMethod From 026b182444033ae0e8bb70840c1a6cf9f5d63811 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 15:21:40 +0100 Subject: [PATCH 24/32] check `!= ZoneInfo("UTC")` --- piccolo/columns/column_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 66b2a6af..6ee7fdbe 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -1040,7 +1040,7 @@ def get_select_string( ) -> str: select_string = self._meta.get_full_name(with_alias=False) - if self.tz is not None: + if self.tz != ZoneInfo("UTC"): select_string += f" AT TIME ZONE '{self.tz.key}'" if with_alias: From e9b02be37dd4e1f373b1d2e7b812b331316ab17b Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 18:03:56 +0100 Subject: [PATCH 25/32] centralised imports for ZoneInfo --- piccolo/apps/migrations/auto/serialisation.py | 6 +----- piccolo/columns/column_types.py | 6 +----- piccolo/columns/defaults/timestamptz.py | 5 +---- piccolo/utils/zoneinfo.py | 7 +++++++ tests/columns/test_timestamptz.py | 7 +------ 5 files changed, 11 insertions(+), 20 deletions(-) create mode 100644 piccolo/utils/zoneinfo.py diff --git a/piccolo/apps/migrations/auto/serialisation.py b/piccolo/apps/migrations/auto/serialisation.py index 87286408..ac4944e6 100644 --- a/piccolo/apps/migrations/auto/serialisation.py +++ b/piccolo/apps/migrations/auto/serialisation.py @@ -22,11 +22,7 @@ from piccolo.columns.reference import LazyTableReference from piccolo.table import Table from piccolo.utils.repr import repr_class_instance - -try: - from zoneinfo import ZoneInfo # type: ignore -except ImportError: # pragma: no cover - from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 +from piccolo.utils.zoneinfo import ZoneInfo from .serialisation_legacy import deserialise_legacy_params diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 6ee7fdbe..130b31a6 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -63,11 +63,7 @@ class Band(Table): from piccolo.querystring import QueryString, Unquoted from piccolo.utils.encoding import dump_json from piccolo.utils.warnings import colored_warning - -try: - from zoneinfo import ZoneInfo # type: ignore -except ImportError: # pragma: no cover - from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 +from piccolo.utils.zoneinfo import ZoneInfo if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns.base import ColumnMeta diff --git a/piccolo/columns/defaults/timestamptz.py b/piccolo/columns/defaults/timestamptz.py index 90ba0fa9..6e3b5187 100644 --- a/piccolo/columns/defaults/timestamptz.py +++ b/piccolo/columns/defaults/timestamptz.py @@ -4,10 +4,7 @@ import typing as t from enum import Enum -try: - from zoneinfo import ZoneInfo # type: ignore -except ImportError: # pragma: no cover - from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 +from piccolo.utils.zoneinfo import ZoneInfo from .timestamp import TimestampCustom, TimestampNow, TimestampOffset diff --git a/piccolo/utils/zoneinfo.py b/piccolo/utils/zoneinfo.py new file mode 100644 index 00000000..a2981aae --- /dev/null +++ b/piccolo/utils/zoneinfo.py @@ -0,0 +1,7 @@ +try: + from zoneinfo import ZoneInfo # type: ignore +except ImportError: # pragma: no cover + from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 + + +__all__ = ("ZoneInfo",) diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index 7d8ee3da..681a9466 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -10,12 +10,7 @@ TimestamptzOffset, ) from piccolo.table import Table - -try: - from zoneinfo import ZoneInfo # type: ignore -except ImportError: # pragma: no cover - from backports.zoneinfo import ZoneInfo # type: ignore # noqa: F401 - +from piccolo.utils.zoneinfo import ZoneInfo UTC_TZ = ZoneInfo("UTC") LOCAL_TZ = ZoneInfo("Europe/Tallinn") From 4327178420b3eca82920aa6ca81d9dc2bb42046f Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 18:04:19 +0100 Subject: [PATCH 26/32] added method for getting column alias --- piccolo/columns/base.py | 3 +++ piccolo/columns/column_types.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index 016b0490..e48c0a97 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -763,6 +763,9 @@ def as_alias(self, name: str) -> Column: column._alias = name return column + def _get_alias(self) -> str: + return self._alias or self._meta.get_default_alias() + def join_on(self, column: Column) -> ForeignKey: """ Joins are typically performed via foreign key columns. For example, diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 130b31a6..655440a0 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -1040,7 +1040,7 @@ def get_select_string( select_string += f" AT TIME ZONE '{self.tz.key}'" if with_alias: - alias = self._alias or self._meta.get_default_alias() + alias = self._get_alias() select_string += f' AS "{alias}"' return select_string @@ -2354,7 +2354,7 @@ def get_select_string( select_string += f" {self.json_operator}" if with_alias: - alias = self._alias or self._meta.get_default_alias() + alias = self._get_alias() select_string += f' AS "{alias}"' return select_string @@ -2659,7 +2659,7 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str: select_string += f"[{self.index}]" if with_alias: - alias = self._alias or self._meta.get_default_alias() + alias = self._get_alias() select_string += f' AS "{alias}"' return select_string From e9fae63f9928590c5c61049042391172a084db24 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 18:04:42 +0100 Subject: [PATCH 27/32] make sure all Timestamptz values have a tz set --- piccolo/query/methods/select.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index a2a77b15..62456e26 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -1,12 +1,13 @@ from __future__ import annotations +import datetime import decimal import itertools import typing as t from collections import OrderedDict from piccolo.columns import Column, Selectable -from piccolo.columns.column_types import JSON, JSONB, PrimaryKey +from piccolo.columns.column_types import JSON, JSONB, PrimaryKey, Timestamptz from piccolo.columns.m2m import M2MSelect from piccolo.columns.readable import Readable from piccolo.custom_types import TableInstance @@ -31,6 +32,7 @@ from piccolo.utils.dictionary import make_nested from piccolo.utils.encoding import dump_json, load_json from piccolo.utils.warnings import colored_warning +from piccolo.utils.zoneinfo import ZoneInfo if t.TYPE_CHECKING: # pragma: no cover from piccolo.custom_types import Combinable @@ -574,6 +576,33 @@ async def response_handler(self, response): ####################################################################### + # Make sure any Timestamptz values are timezone aware. + # This happens when we use `AS TIME ZONE` which returns a naive + # datetime. + + timestamptz_columns = [ + i + for i in self.columns_delegate.selected_columns + if isinstance(i, Timestamptz) + ] + + if timestamptz_columns: + for column in timestamptz_columns: + if column.tz != ZoneInfo("UTC"): + continue + + alias = column._get_alias() + + for row in response: + timestamp_value = row.get(alias) + if ( + isinstance(timestamp_value, datetime.datetime) + and timestamp_value.tzinfo is None + ): + timestamp_value.replace(tzinfo=column.tz) + + ####################################################################### + # If no columns were specified, it's a select *, so we know that # no columns were selected from related tables. was_select_star = len(self.columns_delegate.selected_columns) == 0 From 4416adc7316c5dd0d6b802ef60a791e867e7e0b4 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 18:11:40 +0100 Subject: [PATCH 28/32] fix typos --- piccolo/query/methods/select.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 62456e26..cedacd53 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -588,7 +588,7 @@ async def response_handler(self, response): if timestamptz_columns: for column in timestamptz_columns: - if column.tz != ZoneInfo("UTC"): + if column.tz == ZoneInfo("UTC"): continue alias = column._get_alias() @@ -599,7 +599,7 @@ async def response_handler(self, response): isinstance(timestamp_value, datetime.datetime) and timestamp_value.tzinfo is None ): - timestamp_value.replace(tzinfo=column.tz) + row[alias] = timestamp_value.replace(tzinfo=column.tz) ####################################################################### From 0ab2d2493d827822a2d4eb44ee72d4cec9d2176f Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 19:18:43 +0100 Subject: [PATCH 29/32] fix tests --- piccolo/query/methods/objects.py | 18 ++++++++++-------- piccolo/query/methods/select.py | 8 +++++--- tests/columns/test_timestamptz.py | 9 +++++++-- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index 7b8c3ad4..7a9ab295 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -22,7 +22,6 @@ ) from piccolo.query.proxy import Proxy from piccolo.querystring import QueryString -from piccolo.utils.dictionary import make_nested from piccolo.utils.sync import run_sync if t.TYPE_CHECKING: # pragma: no cover @@ -306,13 +305,12 @@ async def batch( return await self.table._meta.db.batch(self, **kwargs) async def response_handler(self, response): - if self.output_delegate._output.nested: - return [make_nested(i) for i in response] - else: - return response + return await self._get_select_query().response_handler( + response=response + ) - @property - def default_querystrings(self) -> t.Sequence[QueryString]: + # TODO - would be good to cache this somehow + def _get_select_query(self) -> Select: select = Select(table=self.table) for attr in ( @@ -339,7 +337,11 @@ def default_querystrings(self) -> t.Sequence[QueryString]: select.output_delegate.output(nested=True) - return select.querystrings + return select + + @property + def default_querystrings(self) -> t.Sequence[QueryString]: + return self._get_select_query().querystrings ########################################################################### diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index cedacd53..850bb192 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -580,10 +580,12 @@ async def response_handler(self, response): # This happens when we use `AS TIME ZONE` which returns a naive # datetime. + selected_columns = ( + self.columns_delegate.selected_columns or self.table.all_columns() + ) + timestamptz_columns = [ - i - for i in self.columns_delegate.selected_columns - if isinstance(i, Timestamptz) + i for i in selected_columns if isinstance(i, Timestamptz) ] if timestamptz_columns: diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index 681a9466..4ed3d5c9 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -52,9 +52,13 @@ def test_timestamptz_timezone_aware(self): Test storing a timezone aware timestamp. """ dt_args = dict(year=2020, month=1, day=1, hour=12, minute=0, second=0) - created_on_utc = datetime.datetime(**dt_args, tzinfo=ZoneInfo("UTC")) + created_on_utc = datetime.datetime( + **dt_args, + tzinfo=datetime.timezone.utc, + ) created_on_local = datetime.datetime( - **dt_args, tzinfo=ZoneInfo("Europe/Tallinn") + **dt_args, + tzinfo=ZoneInfo("Europe/Tallinn"), ) row = MyTable( created_on_utc=created_on_utc, created_on_local=created_on_local @@ -68,6 +72,7 @@ def test_timestamptz_timezone_aware(self): MyTable.objects().where(eq(p_key, p_key_name)).first().run_sync() ) assert result is not None + self.assertEqual(result.created_on_utc, created_on_utc) self.assertEqual(result.created_on_local, created_on_local) From c64efa93555f0d5943ef668372be359ffec6e2e1 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Fri, 5 Apr 2024 21:27:48 +0100 Subject: [PATCH 30/32] fix sqlite --- piccolo/columns/column_types.py | 5 ++++- piccolo/query/methods/objects.py | 2 ++ piccolo/query/methods/select.py | 24 ++++++++++++++++++------ 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 655440a0..9828aa98 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -1037,7 +1037,10 @@ def get_select_string( select_string = self._meta.get_full_name(with_alias=False) if self.tz != ZoneInfo("UTC"): - select_string += f" AT TIME ZONE '{self.tz.key}'" + # SQLite doesn't support `AT TIME ZONE`, so we have to do it in + # Python instead (see ``Select.response_handler``). + if self._meta.engine_type in ("postgres", "cockroach"): + select_string += f" AT TIME ZONE '{self.tz.key}'" if with_alias: alias = self._get_alias() diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index 7a9ab295..e1a38031 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -310,6 +310,8 @@ async def response_handler(self, response): ) # TODO - would be good to cache this somehow + # I only really need the bit that sets the columns, so could move that + # into a separate method. Or pass in args for which delegates to copy? def _get_select_query(self) -> Select: select = Select(table=self.table) diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index 850bb192..bdb0c9c5 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -575,7 +575,6 @@ async def response_handler(self, response): ) ####################################################################### - # Make sure any Timestamptz values are timezone aware. # This happens when we use `AS TIME ZONE` which returns a naive # datetime. @@ -589,19 +588,32 @@ async def response_handler(self, response): ] if timestamptz_columns: + + is_sqlite = self.table._meta.db.engine_type == "sqlite" + for column in timestamptz_columns: if column.tz == ZoneInfo("UTC"): + # The values already come back as UTC, so nothing to do. continue alias = column._get_alias() for row in response: timestamp_value = row.get(alias) - if ( - isinstance(timestamp_value, datetime.datetime) - and timestamp_value.tzinfo is None - ): - row[alias] = timestamp_value.replace(tzinfo=column.tz) + if isinstance(timestamp_value, datetime.datetime): + if is_sqlite: + # SQLite doesn't support the `AT TIME ZONE` clause + # so we're just getting the values back as UTC, + # so we need to convert them here. + row[alias] = timestamp_value.astimezone(column.tz) + else: + # Postgres and Cockroach support the + # `AT TIME ZONE` clause, so the values are already + # correct, but the datetime object doesn't contain + # a tz value, so set it here. + row[alias] = timestamp_value.replace( + tzinfo=column.tz + ) ####################################################################### From 1d98c3662c2dd44087946021f37983d79d148e26 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Sat, 6 Apr 2024 11:35:01 +0100 Subject: [PATCH 31/32] refactor `objects` queries to proxy to `select` --- piccolo/query/methods/objects.py | 123 +++++++++---------------------- 1 file changed, 36 insertions(+), 87 deletions(-) diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index e1a38031..752af60d 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -8,18 +8,7 @@ from piccolo.engine.base import Batch from piccolo.query.base import Query from piccolo.query.methods.select import Select -from piccolo.query.mixins import ( - AsOfDelegate, - CallbackDelegate, - CallbackType, - LimitDelegate, - OffsetDelegate, - OrderByDelegate, - OrderByRaw, - OutputDelegate, - PrefetchDelegate, - WhereDelegate, -) +from piccolo.query.mixins import CallbackType, OrderByRaw, PrefetchDelegate from piccolo.query.proxy import Proxy from piccolo.querystring import QueryString from piccolo.utils.sync import run_sync @@ -124,7 +113,7 @@ async def run( results = objects[0] if objects else None modified_response: t.Optional[TableInstance] = ( - await self.query.callback_delegate.invoke( + await self.query._select_query.callback_delegate.invoke( results=results, kind=CallbackType.success ) ) @@ -184,15 +173,8 @@ class Objects( """ __slots__ = ( - "nested", - "as_of_delegate", - "limit_delegate", - "offset_delegate", - "order_by_delegate", - "output_delegate", - "callback_delegate", + "_select_query", "prefetch_delegate", - "where_delegate", ) def __init__( @@ -202,19 +184,16 @@ def __init__( **kwargs, ): super().__init__(table, **kwargs) - self.as_of_delegate = AsOfDelegate() - self.limit_delegate = LimitDelegate() - self.offset_delegate = OffsetDelegate() - self.order_by_delegate = OrderByDelegate() - self.output_delegate = OutputDelegate() - self.output_delegate._output.as_objects = True - self.callback_delegate = CallbackDelegate() + self._select_query = Select(table=self.table) + self._select_query.output_delegate._output.as_objects = True self.prefetch_delegate = PrefetchDelegate() self.prefetch(*prefetch) - self.where_delegate = WhereDelegate() + + ########################################################################### + # Proxying to select query def output(self: Self, load_json: bool = False) -> Self: - self.output_delegate.output( + self._select_query.output_delegate.output( as_list=False, as_json=False, load_json=load_json ) return self @@ -225,55 +204,48 @@ def callback( *, on: CallbackType = CallbackType.success, ) -> Self: - self.callback_delegate.callback(callbacks, on=on) + self._select_query.callback(callbacks, on=on) return self def as_of(self, interval: str = "-1s") -> Objects: - if self.engine_type != "cockroach": - raise NotImplementedError("Only CockroachDB supports AS OF") - self.as_of_delegate.as_of(interval) + self._select_query.as_of(interval=interval) return self def limit(self: Self, number: int) -> Self: - self.limit_delegate.limit(number) - return self - - def prefetch( - self: Self, *fk_columns: t.Union[ForeignKey, t.List[ForeignKey]] - ) -> Self: - self.prefetch_delegate.prefetch(*fk_columns) + self._select_query.limit(number=number) return self def offset(self: Self, number: int) -> Self: - self.offset_delegate.offset(number) + self._select_query.offset(number=number) return self def order_by( self: Self, *columns: t.Union[Column, str, OrderByRaw], ascending=True ) -> Self: - _columns: t.List[t.Union[Column, OrderByRaw]] = [] - for column in columns: - if isinstance(column, str): - _columns.append(self.table._meta.get_column_by_name(column)) - else: - _columns.append(column) - - self.order_by_delegate.order_by(*_columns, ascending=ascending) + self._select_query.order_by(*columns, ascending=ascending) return self def where(self: Self, *where: Combinable) -> Self: - self.where_delegate.where(*where) + self._select_query.where(*where) + return self + + ########################################################################### + + def prefetch( + self: Self, *fk_columns: t.Union[ForeignKey, t.List[ForeignKey]] + ) -> Self: + self.prefetch_delegate.prefetch(*fk_columns) return self ########################################################################### def first(self: Self) -> First[TableInstance]: - self.limit_delegate.limit(1) + self._select_query.limit(1) return First[TableInstance](query=self) def get(self: Self, where: Combinable) -> Get[TableInstance]: - self.where_delegate.where(where) - self.limit_delegate.limit(1) + self._select_query.where(where) + self._select_query.limit(1) return Get[TableInstance](query=First[TableInstance](query=self)) def get_or_create( @@ -298,32 +270,17 @@ async def batch( node: t.Optional[str] = None, **kwargs, ) -> Batch: - if batch_size: - kwargs.update(batch_size=batch_size) - if node: - kwargs.update(node=node) - return await self.table._meta.db.batch(self, **kwargs) + return await self._get_select_query().batch( + batch_size=batch_size, node=node, **kwargs + ) async def response_handler(self, response): return await self._get_select_query().response_handler( response=response ) - # TODO - would be good to cache this somehow - # I only really need the bit that sets the columns, so could move that - # into a separate method. Or pass in args for which delegates to copy? def _get_select_query(self) -> Select: - select = Select(table=self.table) - - for attr in ( - "as_of_delegate", - "limit_delegate", - "where_delegate", - "offset_delegate", - "output_delegate", - "order_by_delegate", - ): - setattr(select, attr, getattr(self, attr)) + select = self._select_query if self.prefetch_delegate.fk_columns: select.columns(*self.table.all_columns()) @@ -353,20 +310,12 @@ async def run( in_pool: bool = True, use_callbacks: bool = True, ) -> t.List[TableInstance]: - results = await super().run(node=node, in_pool=in_pool) - - if use_callbacks: - # With callbacks, the user can return any data that they want. - # Assume that most of the time they will still return a list of - # Table instances. - modified: t.List[TableInstance] = ( - await self.callback_delegate.invoke( - results, kind=CallbackType.success - ) - ) - return modified - else: - return results + results = await self._get_select_query().run( + node=node, + in_pool=in_pool, + use_callbacks=use_callbacks, + ) + return t.cast(t.List[TableInstance], results) def __await__( self, From 2c75e7081905b9bdf4603944cb67f9747f5d42f8 Mon Sep 17 00:00:00 2001 From: Daniel Townsend Date: Sat, 6 Apr 2024 11:35:36 +0100 Subject: [PATCH 32/32] rename `tz` to `at_time_zone` --- piccolo/columns/column_types.py | 24 +++++++++++++----------- piccolo/query/methods/select.py | 9 +++++---- tests/columns/test_timestamptz.py | 17 +++++++++++------ 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 9828aa98..e429e061 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -967,7 +967,7 @@ class Timestamptz(Column): from zoneinfo import ZoneInfo class TallinnConcerts(Table): - event_start = Timestamptz(tz=ZoneInfo("Europe/Tallinn")) + event_start = Timestamptz(at_time_zone=ZoneInfo("Europe/Tallinn")) # Create >>> await TallinnConcerts( @@ -999,7 +999,7 @@ class TallinnConcerts(Table): def __init__( self, default: TimestamptzArg = TimestamptzNow(), - tz: ZoneInfo = ZoneInfo("UTC"), + at_time_zone: ZoneInfo = ZoneInfo("UTC"), **kwargs, ) -> None: self._validate_default( @@ -1007,26 +1007,28 @@ def __init__( ) if isinstance(default, datetime): - default = TimestamptzCustom.from_datetime(default, tz) + default = TimestamptzCustom.from_datetime(default, at_time_zone) if default == datetime.now: - default = TimestamptzNow(tz) + default = TimestamptzNow(tz=at_time_zone) - self.tz = tz + self._at_time_zone = at_time_zone self.default = default - kwargs.update({"default": default, "tz": tz}) + kwargs.update({"default": default, "at_time_zone": at_time_zone}) super().__init__(**kwargs) ########################################################################### - def at_time_zone(self, tz: t.Union[ZoneInfo, str]) -> Timestamptz: + def at_time_zone(self, time_zone: t.Union[ZoneInfo, str]) -> Timestamptz: """ By default, the database returns the value in UTC. This lets us get the value converted to the specified timezone. """ - tz = ZoneInfo(tz) if isinstance(tz, str) else tz + time_zone = ( + ZoneInfo(time_zone) if isinstance(time_zone, str) else time_zone + ) instance = self.copy() - instance.tz = tz + instance._at_time_zone = time_zone return instance ########################################################################### @@ -1036,11 +1038,11 @@ def get_select_string( ) -> str: select_string = self._meta.get_full_name(with_alias=False) - if self.tz != ZoneInfo("UTC"): + if self._at_time_zone != ZoneInfo("UTC"): # SQLite doesn't support `AT TIME ZONE`, so we have to do it in # Python instead (see ``Select.response_handler``). if self._meta.engine_type in ("postgres", "cockroach"): - select_string += f" AT TIME ZONE '{self.tz.key}'" + select_string += f" AT TIME ZONE '{self._at_time_zone.key}'" if with_alias: alias = self._get_alias() diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index bdb0c9c5..644c23eb 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -588,11 +588,10 @@ async def response_handler(self, response): ] if timestamptz_columns: - is_sqlite = self.table._meta.db.engine_type == "sqlite" for column in timestamptz_columns: - if column.tz == ZoneInfo("UTC"): + if column._at_time_zone == ZoneInfo("UTC"): # The values already come back as UTC, so nothing to do. continue @@ -605,14 +604,16 @@ async def response_handler(self, response): # SQLite doesn't support the `AT TIME ZONE` clause # so we're just getting the values back as UTC, # so we need to convert them here. - row[alias] = timestamp_value.astimezone(column.tz) + row[alias] = timestamp_value.astimezone( + column._at_time_zone + ) else: # Postgres and Cockroach support the # `AT TIME ZONE` clause, so the values are already # correct, but the datetime object doesn't contain # a tz value, so set it here. row[alias] = timestamp_value.replace( - tzinfo=column.tz + tzinfo=column._at_time_zone ) ####################################################################### diff --git a/tests/columns/test_timestamptz.py b/tests/columns/test_timestamptz.py index 4ed3d5c9..d0748a1b 100644 --- a/tests/columns/test_timestamptz.py +++ b/tests/columns/test_timestamptz.py @@ -17,8 +17,8 @@ class MyTable(Table): - created_on_utc = Timestamptz(tz=UTC_TZ) - created_on_local = Timestamptz(tz=LOCAL_TZ) + created_on_utc = Timestamptz(at_time_zone=UTC_TZ) + created_on_local = Timestamptz(at_time_zone=LOCAL_TZ) class MyTableDefault(Table): @@ -27,16 +27,21 @@ class MyTableDefault(Table): `Timestamptz`. """ - created_on = Timestamptz(default=TimestamptzNow(tz=LOCAL_TZ), tz=LOCAL_TZ) + created_on = Timestamptz( + default=TimestamptzNow(tz=LOCAL_TZ), + at_time_zone=LOCAL_TZ, + ) created_on_offset = Timestamptz( - default=TimestamptzOffset(days=1, tz=LOCAL_TZ), tz=LOCAL_TZ + default=TimestamptzOffset(days=1, tz=LOCAL_TZ), + at_time_zone=LOCAL_TZ, ) created_on_custom = Timestamptz( - default=TimestamptzCustom(year=2021, tz=LOCAL_TZ), tz=LOCAL_TZ + default=TimestamptzCustom(year=2021, tz=LOCAL_TZ), + at_time_zone=LOCAL_TZ, ) created_on_datetime = Timestamptz( default=datetime.datetime(year=2020, month=1, day=1, tzinfo=LOCAL_TZ), - tz=LOCAL_TZ, + at_time_zone=LOCAL_TZ, )