Skip to content

Commit

Permalink
Normalize prop enums
Browse files Browse the repository at this point in the history
FIx #20
  • Loading branch information
Setsugennoao committed May 15, 2024
1 parent 73419b2 commit 491c754
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
24 changes: 17 additions & 7 deletions vskernels/kernels/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from stgpytools import inject_kwargs_params
from vstools import (
CustomIndexError, CustomRuntimeError, CustomValueError, FieldBased, FuncExceptT, GenericVSFunction,
HoldsVideoFormatT, KwargsT, Matrix, MatrixT, T, VideoFormatT, check_correct_subsampling, check_variable_resolution,
core, depth, expect_bits, get_subclasses, get_video_format, inject_self, vs, vs_object
HoldsVideoFormatT, KwargsT, Matrix, MatrixT, PropEnum, T, VideoFormatT, check_correct_subsampling,
check_variable_resolution, core, depth, expect_bits, get_subclasses, get_video_format, inject_self, vs, vs_object
)

from ..exceptions import UnknownDescalerError, UnknownKernelError, UnknownResamplerError, UnknownScalerError
Expand All @@ -32,6 +32,15 @@ def _default_kernel_radius(cls: type[T], self: T) -> int:
return super(cls, self).kernel_radius # type: ignore


def _norm_props_enums(kwargs: KwargsT) -> KwargsT:
return {
key: (
(value.value_zimg if hasattr(value, 'value_zimg') else int(value))
if isinstance(value, PropEnum) else value
) for key, value in kwargs.items()
}


@lru_cache
def _get_keywords(_methods: tuple[Callable[..., Any] | None, ...], self: Any) -> set[str]:
methods_list = list(_methods)
Expand Down Expand Up @@ -211,7 +220,7 @@ def scale( # type: ignore[override]
self, clip: vs.VideoNode, width: int, height: int, shift: tuple[TopShift, LeftShift] = (0, 0), **kwargs: Any
) -> vs.VideoNode:
check_correct_subsampling(clip, width, height)
return self.scale_function(clip, **self.get_scale_args(clip, shift, width, height, **kwargs))
return self.scale_function(clip, **_norm_props_enums(self.get_scale_args(clip, shift, width, height, **kwargs)))

@inject_self.cached
def multi(
Expand Down Expand Up @@ -280,13 +289,14 @@ def descale( # type: ignore[override]
fields = clip.std.SeparateFields(field_based.is_tff)

interleaved = core.std.Interleave([
self.descale_function(fields[offset::2], **(de_kwargs | dict(src_top=top_shift + (field_shift * mult))))
self.descale_function(fields[offset::2], **_norm_props_enums(
de_kwargs | dict(src_top=top_shift + (field_shift * mult))))
for offset, mult in [(0, 1), (1, -1)]
])

descaled = interleaved.std.DoubleWeave(field_based.is_tff)[::2]
else:
descaled = self.descale_function(clip, **de_kwargs)
descaled = self.descale_function(clip, **_norm_props_enums(de_kwargs))

return depth(descaled, bits)

Expand Down Expand Up @@ -326,7 +336,7 @@ def resample(
self, clip: vs.VideoNode, format: int | VideoFormatT | HoldsVideoFormatT,
matrix: MatrixT | None = None, matrix_in: MatrixT | None = None, **kwargs: Any
) -> vs.VideoNode:
return self.resample_function(clip, **self.get_resample_args(clip, format, matrix, matrix_in, **kwargs))
return self.resample_function(clip, **_norm_props_enums(self.get_resample_args(clip, format, matrix, matrix_in, **kwargs)))

def get_resample_args(
self, clip: vs.VideoNode, format: int | VideoFormatT | HoldsVideoFormatT,
Expand Down Expand Up @@ -381,7 +391,7 @@ def shift(
n_planes = clip.format.num_planes

def _shift(src: vs.VideoNode, shift: tuple[TopShift, LeftShift] = (0, 0)) -> vs.VideoNode:
return self.scale_function(src, **self.get_scale_args(src, shift, **kwargs))
return Scaler.scale(self, clip, **kwargs)

if not shifts_or_top and not shift_left:
return _shift(clip)
Expand Down
3 changes: 1 addition & 2 deletions vskernels/kernels/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,8 @@ def scale( # type: ignore[override]
s + ((p - c) // 2) for s, c, p in zip(shift, *((x.width, x.height) for x in (clip, padded)))
), padded

kwargs = self.get_scale_args(clip, shift, width, height, **kwargs)

clip = self.scale_function(clip, **kwargs)
clip = Scaler.scale(self, clip, width, height, shift, **kwargs)

if const_size and out_sar:
clip = out_sar.apply(clip)
Expand Down
3 changes: 1 addition & 2 deletions vskernels/kernels/fmtconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ def shift(

def _shift(shift_top: float | list[float] = 0.0, shift_left: float | list[float] = 0.0) -> vs.VideoNode:
return self.scale_function(
clip, sy=shift_top, sx=shift_left, kernel=self._kernel,
**self.get_clean_kwargs(), **kwargs
clip, sy=shift_top, sx=shift_left, kernel=self._kernel, **self.get_clean_kwargs(), **kwargs
)

if not shifts_or_top and not shift_left:
Expand Down

0 comments on commit 491c754

Please sign in to comment.