Skip to content

Commit

Permalink
Add aliases to typing for clearer params
Browse files Browse the repository at this point in the history
  • Loading branch information
Setsugennoao committed Dec 4, 2023
1 parent 8944f67 commit 63e85b2
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 33 deletions.
19 changes: 10 additions & 9 deletions vskernels/kernels/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)

from ..exceptions import UnknownDescalerError, UnknownKernelError, UnknownResamplerError, UnknownScalerError
from ..types import LeftShift, TopShift

__all__ = [
'Scaler', 'ScalerT',
Expand Down Expand Up @@ -173,14 +174,14 @@ class Scaler(BaseScaler):
@inject_self.cached
@inject_kwargs_params
def scale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (0, 0), **kwargs: Any
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0), **kwargs: Any
) -> vs.VideoNode:
check_correct_subsampling(clip, width, height)
return self.scale_function(clip, **self.get_scale_args(clip, shift, width, height, **kwargs))

@inject_self.cached
def multi(
self, clip: vs.VideoNode, multi: float = 2, shift: tuple[float, float] = (0, 0), **kwargs: Any
self, clip: vs.VideoNode, multi: float = 2, shift: tuple[TopShift, LeftShift] = (0, 0), **kwargs: Any
) -> vs.VideoNode:
assert check_variable_resolution(clip, self.multi)

Expand All @@ -194,7 +195,7 @@ def multi(
return self.scale(clip, dst_width, dst_height, shift, **kwargs)

def get_scale_args(
self, clip: vs.VideoNode, shift: tuple[float, float] = (0, 0),
self, clip: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0),
width: int | None = None, height: int | None = None,
*funcs: Callable[..., Any], **kwargs: Any
) -> KwargsT:
Expand Down Expand Up @@ -225,7 +226,7 @@ class Descaler(BaseScaler):
@inject_self.cached
@inject_kwargs_params
def descale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (0, 0), **kwargs: Any
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0), **kwargs: Any
) -> vs.VideoNode:
check_correct_subsampling(clip, width, height)

Expand Down Expand Up @@ -255,7 +256,7 @@ def descale( # type: ignore[override]
return depth(descaled, bits)

def get_descale_args(
self, clip: vs.VideoNode, shift: tuple[float, float] = (0, 0),
self, clip: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0),
width: int | None = None, height: int | None = None,
*funcs: Callable[..., Any], **kwargs: Any
) -> KwargsT:
Expand Down Expand Up @@ -320,7 +321,7 @@ class Kernel(Scaler, Descaler, Resampler): # type: ignore
@overload # type: ignore
@inject_self.cached
@inject_kwargs_params
def shift(self, clip: vs.VideoNode, shift: tuple[float, float] = (0, 0), **kwargs: Any) -> vs.VideoNode:
def shift(self, clip: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0), **kwargs: Any) -> vs.VideoNode:
...

@overload # type: ignore
Expand All @@ -343,7 +344,7 @@ def shift(

n_planes = clip.format.num_planes

def _shift(src: vs.VideoNode, shift: tuple[float, float] = (0, 0)) -> vs.VideoNode:
def _shift(src: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0)) -> vs.VideoNode:
return self.scale_function(src, **self.get_scale_args(src, shift, **kwargs))

if not shifts_or_top and not shift_left:
Expand Down Expand Up @@ -469,7 +470,7 @@ def get_params_args(
return dict(width=width, height=height) | kwargs

def get_scale_args(
self, clip: vs.VideoNode, shift: tuple[float, float] = (0, 0),
self, clip: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0),
width: int | None = None, height: int | None = None,
*funcs: Callable[..., Any], **kwargs: Any
) -> KwargsT:
Expand All @@ -480,7 +481,7 @@ def get_scale_args(
)

def get_descale_args(
self, clip: vs.VideoNode, shift: tuple[float, float] = (0, 0),
self, clip: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0),
width: int | None = None, height: int | None = None,
*funcs: Callable[..., Any], **kwargs: Any
) -> KwargsT:
Expand Down
23 changes: 12 additions & 11 deletions vskernels/kernels/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from stgpytools import inject_kwargs_params
from vstools import Dar, KwargsT, Resolution, Sar, VSFunctionAllArgs, check_correct_subsampling, inject_self, vs

from ..types import Center, LeftShift, Slope, TopShift
from .abstract import Descaler, Kernel, Resampler, Scaler

__all__ = [
Expand Down Expand Up @@ -41,8 +42,8 @@ def _linear_op(op_name: str) -> Any:
@inject_kwargs_params
def func(
self: _BaseLinearOperation, clip: vs.VideoNode, width: int, height: int,
shift: tuple[float, float] = (0, 0), *,
linear: bool = False, sigmoid: bool | tuple[float, float] = False, **kwargs: Any
shift: tuple[TopShift, LeftShift] = (0, 0), *,
linear: bool = False, sigmoid: bool | tuple[Slope, Center] = False, **kwargs: Any
) -> vs.VideoNode:
from ..util import LinearLight

Expand Down Expand Up @@ -73,8 +74,8 @@ class LinearScaler(_BaseLinearOperation, Scaler):
@inject_self.cached
@inject_kwargs_params
def scale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (0, 0),
*, linear: bool = False, sigmoid: bool | tuple[float, float] = False, **kwargs: Any
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0),
*, linear: bool = False, sigmoid: bool | tuple[Slope, Center] = False, **kwargs: Any
) -> vs.VideoNode:
...
else:
Expand All @@ -86,8 +87,8 @@ class LinearDescaler(_BaseLinearOperation, Descaler):
@inject_self.cached
@inject_kwargs_params
def descale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (0, 0),
*, linear: bool = False, sigmoid: bool | tuple[float, float] = False, **kwargs: Any
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0),
*, linear: bool = False, sigmoid: bool | tuple[Slope, Center] = False, **kwargs: Any
) -> vs.VideoNode:
...
else:
Expand Down Expand Up @@ -116,9 +117,9 @@ def _get_kwargs_keep_ar(
return kwargs

def _handle_crop_resize_kwargs( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float],
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift],
sar: Sar | bool | float | None, dar: Dar | bool | float | None, **kwargs: Any
) -> tuple[KwargsT, tuple[float, float], Sar | None]:
) -> tuple[KwargsT, tuple[TopShift, LeftShift], Sar | None]:
kwargs.setdefault('src_top', kwargs.pop('sy', shift[0]))
kwargs.setdefault('src_left', kwargs.pop('sx', shift[1]))
kwargs.setdefault('src_width', kwargs.pop('sw', clip.width))
Expand Down Expand Up @@ -162,7 +163,7 @@ def _handle_crop_resize_kwargs( # type: ignore[override]
@inject_self.cached
@inject_kwargs_params
def scale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (0, 0), *,
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0), *,
sar: Sar | float | bool | None = None, dar: Dar | float | bool | None = None, keep_ar: bool = False,
**kwargs: Any
) -> vs.VideoNode:
Expand All @@ -189,10 +190,10 @@ class ComplexScaler(LinearScaler, KeepArScaler):
@inject_self.cached
@inject_kwargs_params
def scale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (0, 0),
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0),
*,
sar: Sar | bool | float | None = None, dar: Dar | bool | float | None = None, keep_ar: bool = False,
linear: bool = False, sigmoid: bool | tuple[float, float] = False,
linear: bool = False, sigmoid: bool | tuple[Slope, Center] = False,
**kwargs: Any
) -> vs.VideoNode:
return super().scale(
Expand Down
5 changes: 3 additions & 2 deletions vskernels/kernels/fmtconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from stgpytools import inject_kwargs_params
from vstools import VideoFormatT, VSFunction, core, inject_self, vs

from ..types import LeftShift, TopShift
from .abstract import Resampler
from .bicubic import Bicubic
from .complex import ComplexScaler
Expand Down Expand Up @@ -102,7 +103,7 @@ def _clean_args(self, **kwargs: Any) -> dict[str, Any]:
return kwargs

def get_scale_args(
self, clip: vs.VideoNode, shift: tuple[float, float] = (0, 0),
self, clip: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0),
width: int | None = None, height: int | None = None,
*funcs: Callable[..., Any], **kwargs: Any
) -> dict[str, Any]:
Expand All @@ -124,7 +125,7 @@ def get_params_args(
@overload # type: ignore
@inject_self.cached
@inject_kwargs_params
def shift(self, clip: vs.VideoNode, shift: tuple[float, float] = (0, 0), **kwargs: Any) -> vs.VideoNode:
def shift(self, clip: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0), **kwargs: Any) -> vs.VideoNode:
...

@overload # type: ignore
Expand Down
3 changes: 2 additions & 1 deletion vskernels/kernels/impulse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from stgpytools import inject_kwargs_params
from vstools import inject_self, vs

from ..types import LeftShift, TopShift
from .fmtconv import FmtConv

__all__ = [
Expand Down Expand Up @@ -43,7 +44,7 @@ def __init__(self, impulse: Sequence[float], oversample: int = 8, taps: float =
@inject_self.cached
@inject_kwargs_params
def scale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (-0.125, -0.125), **kwargs: Any
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (-0.125, -0.125), **kwargs: Any
) -> vs.VideoNode:
return super().scale(clip, width, height, shift, **kwargs)

Expand Down
7 changes: 4 additions & 3 deletions vskernels/kernels/placebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from stgpytools import inject_kwargs_params
from vstools import Transfer, TransferT, core, fallback, inject_self, vs

from ..types import Center, LeftShift, Slope, TopShift
from .complex import LinearScaler

__all__ = [
Expand Down Expand Up @@ -56,8 +57,8 @@ def __init__(
@inject_self.cached
@inject_kwargs_params
def scale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (0, 0),
*, linear: bool = True, sigmoid: bool | tuple[float, float] = True, curve: TransferT | None = None,
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0),
*, linear: bool = True, sigmoid: bool | tuple[Slope, Center] = True, curve: TransferT | None = None,
**kwargs: Any
) -> vs.VideoNode:
return super().scale(
Expand All @@ -66,7 +67,7 @@ def scale( # type: ignore[override]
)

def get_scale_args(
self, clip: vs.VideoNode, shift: tuple[float, float] = (0, 0),
self, clip: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0),
width: int | None = None, height: int | None = None,
*funcs: Callable[..., Any], **kwargs: Any
) -> dict[str, Any]:
Expand Down
7 changes: 4 additions & 3 deletions vskernels/kernels/zimg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from stgpytools import inject_kwargs_params
from vstools import CustomIntEnum, inject_self, vs

from ..types import Center, LeftShift, Slope, TopShift
from .abstract import Descaler
from .complex import ComplexKernel

Expand All @@ -26,7 +27,7 @@ class ZimgDescaler(Descaler):
@inject_self.cached
@inject_kwargs_params
def descale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (0, 0),
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0),
*, blur: float = 1.0, border_handling: BorderHandlingT = BorderHandling.MIRROR, **kwargs: Any
) -> vs.VideoNode:
...
Expand All @@ -37,9 +38,9 @@ class ZimgComplexKernel(ComplexKernel, ZimgDescaler): # type: ignore
@inject_self.cached
@inject_kwargs_params
def descale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (0, 0),
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0),
*, blur: float = 1.0, border_handling: BorderHandlingT, ignore_mask: vs.VideoNode | None = None,
linear: bool = False, sigmoid: bool | tuple[float, float] = False, **kwargs: Any
linear: bool = False, sigmoid: bool | tuple[Slope, Center] = False, **kwargs: Any
) -> vs.VideoNode:
...

Expand Down
13 changes: 13 additions & 0 deletions vskernels/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

from typing import TypeAlias

__all__ = [
'TopShift',
'LeftShift'
]

TopShift: TypeAlias = float
LeftShift: TypeAlias = float
Slope: TypeAlias = float
Center: TypeAlias = float
9 changes: 5 additions & 4 deletions vskernels/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Point, Resampler, ResamplerT, Scaler, ZimgComplexKernel, ZimgDescaler
)
from .kernels.bicubic import MemeKernel
from .types import Center, LeftShift, Slope, TopShift

__all__ = [
'abstract_kernels', 'excluded_kernels',
Expand Down Expand Up @@ -85,7 +86,7 @@ class NoScaleBase(Scaler):
@inject_self.cached
@inject_kwargs_params
def scale( # type: ignore
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[float, float] = (0, 0), **kwargs: Any
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0), **kwargs: Any
) -> vs.VideoNode:
try:
return super().scale(clip, clip.width, clip.height, shift, **kwargs) # type: ignore
Expand Down Expand Up @@ -144,7 +145,7 @@ class LinearLight:
clip: vs.VideoNode

linear: bool = True
sigmoid: bool | tuple[float, float] = False
sigmoid: bool | tuple[Slope, Center] = False

resampler: ResamplerT | None = Catrom

Expand Down Expand Up @@ -176,7 +177,7 @@ def linear(self) -> vs.VideoNode:

return wclip

@linear.setter
@linear.setter # type: ignore
def linear(self, processed: vs.VideoNode) -> None:
if self.ll._exited:
raise CustomRuntimeError('You can\'t set .linear after going out of the context manager!')
Expand All @@ -190,7 +191,7 @@ def out(self) -> vs.VideoNode:
if not hasattr(self, '_linear'):
raise CustomValueError('You need to set .linear before getting .out!', self.__class__)

processed = self._linear
processed = self._linear # type: ignore

if self.ll.sigmoid:
processed = processed.std.Expr(
Expand Down

0 comments on commit 63e85b2

Please sign in to comment.