Skip to content

Commit

Permalink
Merge pull request #36 from flaport/pydantic-v2
Browse files Browse the repository at this point in the history
Bump to Pydantic v2. Add caching, validation and serialization improvements.
  • Loading branch information
flaport committed Jun 20, 2024
2 parents fcbe1b3 + 799c887 commit 010e51b
Show file tree
Hide file tree
Showing 38 changed files with 28,369 additions and 1,232 deletions.
4 changes: 2 additions & 2 deletions examples/00_introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@
"\n",
"cells = mw.create_cells(\n",
" structures=structures,\n",
" mesh=mw.Mesh2d(\n",
" mesh=mw.Mesh2D(\n",
" x=np.linspace(-1, 1, 101),\n",
" y=np.linspace(-1, 1, 101),\n",
" # specify possible conformal mesh specifications here:\n",
Expand Down Expand Up @@ -454,7 +454,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.11.9"
},
"papermill": {
"default_parameters": {},
Expand Down
4 changes: 2 additions & 2 deletions examples/01_gds_taper.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@
"\n",
"cells = mw.create_cells(\n",
" structures=structs,\n",
" mesh=mw.Mesh2d(\n",
" mesh=mw.Mesh2D(\n",
" x=np.linspace(-0.75, 0.75, mesh + 1),\n",
" y=np.linspace(-0.3, 0.5, mesh + 1),\n",
" ),\n",
Expand Down Expand Up @@ -532,7 +532,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.11.9"
},
"papermill": {
"default_parameters": {},
Expand Down
4 changes: 2 additions & 2 deletions examples/02_taper_length_sweep.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@
" structures = create_structures(length=length)\n",
" cells = mw.create_cells(\n",
" structures=structures,\n",
" mesh=mw.Mesh2d(\n",
" mesh=mw.Mesh2D(\n",
" x=np.linspace(-2, 2, 101),\n",
" y=np.linspace(-2, 2, 101),\n",
" # specify possible conformal mesh specifications here:\n",
Expand Down Expand Up @@ -450,7 +450,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.11.9"
},
"papermill": {
"default_parameters": {},
Expand Down
4 changes: 2 additions & 2 deletions examples/03_unequal_number_of_modes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@
" structures = create_structures(length=length)\n",
" cells = mw.create_cells(\n",
" structures=structures,\n",
" mesh=mw.Mesh2d(\n",
" mesh=mw.Mesh2D(\n",
" x=np.linspace(-2, 2, 101),\n",
" y=np.linspace(-2, 2, 101),\n",
" # specify possible conformal mesh specifications here:\n",
Expand Down Expand Up @@ -366,7 +366,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.11.9"
},
"papermill": {
"default_parameters": {},
Expand Down
71 changes: 11 additions & 60 deletions meow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,57 +3,10 @@
__author__ = "Floris Laporte"
__version__ = "0.10.0"

import warnings

# Silence Excessive Logging...

try:
from loguru import logger

logger.disable("gdsfactory")
except ImportError:
pass

try:
from numexpr.utils import log

log.setLevel("CRITICAL")
except ImportError:
pass

try:
from rich import pretty

old_install = pretty.install
pretty.install = lambda *_, **__: None
import tidy3d

pretty.install = old_install
except ImportError:
pass


try:
import sax

warnings.filterwarnings(action="ignore", module="sax")
except ImportError:
pass

from . import base_model as base_model
from . import cell as cell
from . import cross_section as cross_section
from . import eme as eme
from . import environment as environment
from . import fde as fde
from . import gds_structures as gds_structures
from . import geometries as geometries
from . import materials as materials
from . import mesh as mesh
from . import mode as mode
from . import structures as structures

# from . import visualize as visualize
from .array import Dim as Dim
from .array import DType as DType
from .array import NDArray as NDArray
from .array import Shape as Shape
from .base_model import BaseModel as BaseModel
from .cell import Cell as Cell
from .cell import create_cells as create_cells
Expand All @@ -63,7 +16,6 @@
from .eme import compute_propagation_s_matrices as compute_propagation_s_matrices
from .eme import compute_propagation_s_matrix as compute_propagation_s_matrix
from .eme import compute_s_matrix as compute_s_matrix
from .eme import compute_s_matrix_sax as compute_s_matrix_sax
from .eme import select_ports as select_ports
from .environment import Environment as Environment
from .fde import compute_modes as compute_modes
Expand All @@ -73,35 +25,34 @@
from .gds_structures import extrude_gds as extrude_gds
from .geometries import Box as Box
from .geometries import Geometry2D as Geometry2D
from .geometries import Geometry2DBase as Geometry2DBase
from .geometries import Geometry3D as Geometry3D
from .geometries import Geometry3DBase as Geometry3DBase
from .geometries import Prism as Prism
from .geometries import Rectangle as Rectangle
from .integrate import integrate_2d as integrate_2d
from .integrate import integrate_interpolate_2d as integrate_interpolate_2d
from .materials import IndexMaterial as IndexMaterial
from .materials import Material as Material
from .materials import MaterialBase as MaterialBase
from .materials import SampledMaterial as SampledMaterial
from .materials import TidyMaterial as TidyMaterial
from .materials import silicon as silicon
from .materials import silicon_nitride as silicon_nitride
from .materials import silicon_oxide as silicon_oxide
from .mesh import Mesh as Mesh
from .mesh import Mesh2D as Mesh2D
from .mesh import Mesh2d as Mesh2d
from .mode import Mode as Mode
from .mode import electric_energy as electric_energy
from .mode import electric_energy_density as electric_energy_density
from .mode import energy as energy
from .mode import energy_density as energy_density
from .mode import inner_product as inner_product
from .mode import inner_product_conj as inner_product_conj
from .mode import invert_mode as invert_mode
from .mode import is_pml_mode as is_pml_mode
from .mode import magnetic_energy as magnetic_energy
from .mode import magnetic_energy_density as magnetic_energy_density
from .mode import normalize_energy as normalize_energy
from .mode import normalize_product as normalize_product
from .mode import te_fraction as te_fraction
from .mode import zero_phase as zero_phase
from .structures import Structure as Structure
from .structures import Structure2D as Structure2D
from .structures import Structure3D as Structure3D
from .visualize import vis as vis
from .visualize import visualize as vis
from .visualize import visualize as visualize
162 changes: 162 additions & 0 deletions meow/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
""" meow array tools for pydantic models """

from __future__ import annotations

from functools import partial
from typing import Annotated, Any

import numpy as np
from pydantic import (
AfterValidator,
BaseModel,
BeforeValidator,
GetPydanticSchema,
InstanceOf,
PlainSerializer,
)


class SerializedArray(BaseModel):
values: list[Any]
shape: tuple[int, ...]
dtype: str

@classmethod
def from_array(cls, x: np.ndarray):
x = np.asarray(x)
shape = x.shape
dtype = str(x.dtype)
if dtype == "complex64":
_x = x.ravel().view("float32")
elif dtype == "complex128":
_x = x.ravel().view("float64")
else:
_x = x.ravel()
return cls(shape=shape, dtype=dtype, values=_x.tolist())

def to_array(self):
if self.dtype == "complex128":
arr = np.asarray(self.values, dtype="float64").view("complex128")
elif self.dtype == "complex64":
arr = np.asarray(self.values, dtype="float32").view("complex64")
else:
arr = np.asarray(self.values, dtype=self.dtype)

if not self.shape:
return arr
else:
return arr.reshape(*self.shape)


def _validate_ndarray(x: Any):
if isinstance(x, dict):
return SerializedArray.model_validate(x).to_array()
elif isinstance(x, SerializedArray):
return x.to_array()
else:
try:
return np.asarray(x)
except Exception:
raise ValueError(f"Could not validate {x} as an array")


def _serialize_ndarray(x: np.ndarray):
return SerializedArray.from_array(x).model_dump()


def _coerce_immutable(x: np.ndarray):
x.setflags(write=False)
return x


def _coerce_shape(arr: np.ndarray, shape: tuple[int, ...]):
shape_to_coerce = []
for i in range(len(shape)):
n = shape[-i - 1]
if n < 0 and i < len(arr.shape):
n = arr.shape[-i - 1]
shape_to_coerce.insert(0, n)
return np.broadcast_to(arr, tuple(shape_to_coerce))


def _assert_shape(arr: np.ndarray, shape: tuple[int, ...]):
shape_to_assert = []
for i in range(len(shape)):
n = shape[-i - 1]
if n < 0 and i < len(arr.shape):
n = arr.shape[-i - 1]
shape_to_assert.insert(0, n)
shape = tuple(shape_to_assert)
if not arr.shape == shape:
raise ValueError(f"Expected an array of shape {shape}. Got {arr.shape}.")
return arr


def _coerce_dim(arr: np.ndarray, ndim: int):
if arr.ndim > ndim:
if arr.shape[0] < 2:
return _coerce_dim(arr[0], ndim)
else:
raise ValueError(
f"Can't coerce arr with shape {arr.shape} into an {ndim}D array."
)
elif arr.ndim < ndim:
return _coerce_dim(arr[None], ndim)
else:
return arr


def _assert_dim(arr: np.ndarray, ndim: int):
if not arr.ndim == ndim:
raise ValueError(f"Expected a {ndim}D array. Got a {arr.ndim}D array.")
return arr


def _coerce_dtype(arr: np.ndarray, dtype: str):
return np.asarray(arr, dtype=dtype)


def _assert_dtype(arr: np.ndarray, dtype: str):
if not str(arr.dtype).startswith(dtype):
raise ValueError(
f"Expected an array with dtype {dtype!r}. Got an array with dtype {str(arr.dtype)!r}."
)
return arr


def Dim(ndim: int, coerce: bool = True):
f = _coerce_dim if coerce else _assert_dim
return AfterValidator(partial(f, ndim=ndim))


def DType(dtype: str, coerce: bool = True):
f = _coerce_dtype if coerce else _assert_dtype
return AfterValidator(partial(f, dtype=dtype))


def Shape(*shape: int, coerce: bool = True):
f = _coerce_shape if coerce else _assert_shape
return AfterValidator(partial(f, shape=shape))


def _get_ndarray_core_schema(_t, h):
return h(InstanceOf[np.ndarray])


def _get_ndarray_json_schema(_t, _h):
return SerializedArray.model_json_schema()


ArraySchema = GetPydanticSchema(_get_ndarray_core_schema, _get_ndarray_json_schema)

NDArray = Annotated[
np.ndarray,
ArraySchema,
PlainSerializer(_serialize_ndarray),
BeforeValidator(_validate_ndarray),
AfterValidator(_coerce_immutable),
]

ComplexArray2D = Annotated[NDArray, Dim(2), DType("complex128")]

Complex = Annotated[NDArray, Dim(0), DType("complex128")]
Loading

0 comments on commit 010e51b

Please sign in to comment.