Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return data from the backend fit method #835

Merged
merged 3 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion alibi_detect/od/_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def fit(
verbose
Verbosity level used to fit the detector. Used for both ``'sklearn'`` and ``'pytorch'`` backends. Defaults to ``0``.
"""
self.backend.fit(
return self.backend.fit(
self.backend._to_backend_dtype(x_ref),
**self.backend.format_fit_kwargs(locals())
)
Expand Down
2 changes: 1 addition & 1 deletion alibi_detect/od/_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def fit(
Verbosity level during training. ``0`` is silent, ``1`` prints fit status. If using `bgd`, fit displays a
progress bar. Otherwise, if using `sgd` then we output the Sklearn `SGDOneClassSVM.fit()` logs.
"""
self.backend.fit(
return self.backend.fit(
self.backend._to_backend_dtype(x_ref),
**self.backend.format_fit_kwargs(locals())
)
Expand Down
10 changes: 9 additions & 1 deletion alibi_detect/od/pytorch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def to_frontend_dtype(self):
return result


def _tensor_to_frontend_dtype(x: Union[torch.Tensor, np.ndarray, float]) -> Union[np.ndarray, float]:
if isinstance(x, torch.Tensor):
x = x.cpu().detach().numpy()
if isinstance(x, np.ndarray) and x.ndim == 0:
x = x.item()
return x


def _raise_type_error(x):
raise TypeError(f'x is type={type(x)} but must be one of TorchOutlierDetectorOutput or a torch Tensor')

Expand All @@ -52,7 +60,7 @@ def to_frontend_dtype(x: Union[torch.Tensor, TorchOutlierDetectorOutput]) -> Uni

return {
'TorchOutlierDetectorOutput': lambda x: x.to_frontend_dtype(),
'Tensor': lambda x: x.cpu().detach().numpy()
'Tensor': _tensor_to_frontend_dtype
}.get(
x.__class__.__name__,
_raise_type_error
Expand Down
2 changes: 1 addition & 1 deletion alibi_detect/od/pytorch/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def fit( # type: ignore[override]
self._set_fitted()
return {
'converged': converged,
'lower_bound': min_loss,
'lower_bound': self._to_frontend_dtype(min_loss),
'n_epochs': epoch
}

Expand Down
2 changes: 1 addition & 1 deletion alibi_detect/od/pytorch/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def fit( # type: ignore[override]
self._set_fitted()
return {
'converged': converged,
'lower_bound': min_loss,
'lower_bound': self._to_frontend_dtype(min_loss),
'n_iter': iter
}

Expand Down
18 changes: 17 additions & 1 deletion alibi_detect/od/tests/test__gmm/test__gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_gmm_integration(backend):
gmm_detector = GMM(n_components=8, backend=backend)
X_ref, _ = make_moons(1001, shuffle=True, noise=0.05, random_state=None)
X_ref, x_inlier = X_ref[0:1000], X_ref[1000][None]
gmm_detector.fit(X_ref)
fit_logs = gmm_detector.fit(X_ref)
gmm_detector.infer_threshold(X_ref, 0.1)
result = gmm_detector.predict(x_inlier)
result = result['data']['is_outlier'][0]
Expand Down Expand Up @@ -117,3 +117,19 @@ def test_gmm_torchscript(tmp_path):
ts_gmm = torch.load(tmp_path / 'gmm.pt')
y = ts_gmm(x)
assert torch.all(y == torch.tensor([False, True]))


@pytest.mark.parametrize('backend', ['pytorch', 'sklearn'])
def test_gmm_fit(backend):
    """Test the values returned by the GMM detector `fit` method.

    Fits a single-component GMM to 1000 samples drawn from a 2D Gaussian and checks that the
    returned fit results report convergence and expose `lower_bound` as a Python float below 1.
    """
    gmm = GMM(n_components=1, backend=backend)
    mean = [8, 8]
    cov = [[2., 0.], [0., 1.]]
    x_ref = torch.tensor(np.random.multivariate_normal(mean, cov, 1000))
    fit_results = gmm.fit(x_ref, tol=0.01, batch_size=32)
    # The lower bound must come back as a native float, not a backend tensor.
    assert isinstance(fit_results['lower_bound'], float)
    assert fit_results['converged']
    assert fit_results['lower_bound'] < 1
27 changes: 26 additions & 1 deletion alibi_detect/od/tests/test__svm/test__svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_fitted_svm_score(optimization):
nu=0.1
)
x_ref = np.random.randn(100, 2)
svm_detector.fit(x_ref)
fit_logs = svm_detector.fit(x_ref)
x = np.array([[0, 10], [0.1, 0]])
scores = svm_detector.score(x)

Expand Down Expand Up @@ -207,3 +207,28 @@ def test_svm_torchscript(tmp_path):
ts_svm = torch.load(tmp_path / 'svm.pt')
y = ts_svm(x)
assert torch.all(y == torch.tensor([False, True]))


@pytest.mark.parametrize('optimization', ['sgd', 'bgd'])
def test_svm_fit(optimization):
    """Test the values returned by the SVM detector `fit` method.

    Fits the detector to 1000 samples drawn from a 2D Gaussian and checks that the returned fit
    results report convergence in fewer than 100 iterations. For the `bgd` optimization the
    `lower_bound` entry must additionally be a Python float.
    """
    kernel = GaussianRBF(torch.tensor(1.))
    svm = SVM(
        n_components=10,
        kernel=kernel,
        nu=0.01,
        optimization=optimization,
    )
    mean = [8, 8]
    cov = [[2., 0.], [0., 1.]]
    x_ref = torch.tensor(np.random.multivariate_normal(mean, cov, 1000))
    fit_results = svm.fit(x_ref, tol=0.01)
    assert fit_results['converged']
    assert fit_results['n_iter'] < 100
    # Default to 0 so the bound check passes trivially when no lower bound is returned.
    assert fit_results.get('lower_bound', 0) < 1
    # 'sgd' optimization does not return lower bound
    if optimization == 'bgd':
        assert isinstance(fit_results['lower_bound'], float)
Loading