Skip to content

Commit

Permalink
Force concrete subclasses of BaseScaler implement kernel_radius, add …
Browse files Browse the repository at this point in the history
…API to add abstract classes from other packages
  • Loading branch information
Setsugennoao committed Dec 14, 2023
1 parent 87f93fd commit 3ddc52b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 9 deletions.
52 changes: 43 additions & 9 deletions vskernels/kernels/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

from stgpytools import inject_kwargs_params
from vstools import (
CustomIndexError, CustomValueError, FieldBased, FuncExceptT, GenericVSFunction, HoldsVideoFormatT, KwargsT, Matrix,
MatrixT, T, VideoFormatT, check_correct_subsampling, check_variable_resolution, core, depth, expect_bits,
get_subclasses, get_video_format, inject_self, vs, vs_object
CustomIndexError, CustomRuntimeError, CustomValueError, FieldBased, FuncExceptT, GenericVSFunction,
HoldsVideoFormatT, KwargsT, Matrix, MatrixT, T, VideoFormatT, check_correct_subsampling, check_variable_resolution,
core, depth, expect_bits, get_subclasses, get_video_format, inject_self, vs, vs_object
)

from ..exceptions import UnknownDescalerError, UnknownKernelError, UnknownResamplerError, UnknownScalerError
Expand All @@ -22,17 +22,14 @@
'Kernel', 'KernelT'
]

_finished_loading_abstract = False


def _default_kernel_radius(cls: type[T], self: T) -> int:
if hasattr(self, '_static_kernel_radius'):
return ceil(self._static_kernel_radius) # type: ignore

try:
return super(cls, self).kernel_radius # type: ignore
except AttributeError:
...

raise NotImplementedError
return super(cls, self).kernel_radius # type: ignore


@lru_cache
Expand Down Expand Up @@ -132,6 +129,43 @@ class BaseScaler(vs_object):
def __init__(self, **kwargs: Any) -> None:
self.kwargs = kwargs

def __init_subclass__(cls) -> None:
if not _finished_loading_abstract:
return

from .zimg import ZimgComplexKernel
from ..util import abstract_kernels

if cls in abstract_kernels:
return

import sys

module = sys.modules[cls.__module__]

if hasattr(module, '__abstract__'):
if cls.__name__ in module.__abstract__:
abstract_kernels.append(cls) # type: ignore
return

if 'kernel_radius' in cls.__dict__.keys():
return

mro = [cls, *({*cls.mro()} - {*ZimgComplexKernel.mro()})]

for sub_cls in mro:
if hasattr(sub_cls, '_static_kernel_radius'):
break

try:
if hasattr(sub_cls, 'kernel_radius'):
break
except Exception:
...
else:
if mro:
raise CustomRuntimeError('You must implement kernel_radius when inheriting BaseScaler!', reason=cls)

@classmethod
def from_param(
cls: type[BaseScalerT], scaler: str | type[BaseScalerT] | BaseScalerT | None = None, /,
Expand Down
5 changes: 5 additions & 0 deletions vskernels/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,8 @@ def resample_to(
return Point.resample(clip, out_fmt, matrix)

return resampler.resample(clip, out_fmt, matrix)


if True:
from .kernels import abstract
abstract._finished_loading_abstract = True

0 comments on commit 3ddc52b

Please sign in to comment.