From 4a8700a999fbc65a3f0cd096f729a356c5dd9d09 Mon Sep 17 00:00:00 2001 From: michalk8 Date: Wed, 16 Sep 2020 16:15:17 +0200 Subject: [PATCH 01/12] Add weights to input data --- batchglm/models/base/estimator.py | 4 +++ batchglm/models/base/input.py | 42 ++++++++++++++++++---- batchglm/models/base/model.py | 4 +++ batchglm/train/numpy/base_glm/estimator.py | 1 - batchglm/train/numpy/base_glm/vars.py | 3 +- 5 files changed, 45 insertions(+), 9 deletions(-) diff --git a/batchglm/models/base/estimator.py b/batchglm/models/base/estimator.py index dc1e3bf1..af1cae3b 100644 --- a/batchglm/models/base/estimator.py +++ b/batchglm/models/base/estimator.py @@ -76,6 +76,10 @@ def fisher_inv(self): def x(self) -> np.ndarray: return self.input_data.x + @property + def w(self) -> np.ndarray: + return self.input_data.w + @property def a_var(self): if isinstance(self.model.a_var, dask.array.core.Array): diff --git a/batchglm/models/base/input.py b/batchglm/models/base/input.py index 20fddd65..a4272a47 100644 --- a/batchglm/models/base/input.py +++ b/batchglm/models/base/input.py @@ -3,7 +3,7 @@ import numpy as np import scipy.sparse import sparse -from typing import List +from typing import List, Union, Optional, TypeAlias try: import anndata @@ -17,6 +17,9 @@ logger = logging.getLogger(__name__) +ArrayLike = TypeAlias(Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array]) +InputType = TypeAlias(Union[ArrayLike, anndata.AnnData, "InputDataBase"]) + class InputDataBase: """ @@ -29,13 +32,14 @@ class InputDataBase: def __init__( self, - data, - observation_names=None, - feature_names=None, + data: InputType, + weights: Optional[Union[ArrayLike, str]] = None, + observation_names: Optional[List[str]] = None, + feature_names: Optional[List[str]] = None, chunk_size_cells: int = 100000, chunk_size_genes: int = 100, as_dask: bool = True, - cast_dtype=None + cast_dtype: Optional[np.dtype] = None ): """ Create a new InputData object. @@ -53,22 +57,38 @@ def __init__( """ self.observations = observation_names self.features = feature_names + self.w = weights + if isinstance(data, np.ndarray) or \ isinstance(data, scipy.sparse.csr_matrix) or \ isinstance(data, dask.array.core.Array): self.x = data elif isinstance(data, anndata.AnnData) or isinstance(data, Raw): self.x = data.X + if isinstance(weights, str): + self.w = data.obs[weights].values elif isinstance(data, InputDataBase): self.x = data.x + self.w = data.w else: raise ValueError("type of data %s not recognized" % type(data)) + if self.w is None: + self.w = np.ones(self.x.shape[0], dtype=self.x.dtype) + + if scipy.sparse.issparse(self.w): + self.w = self.w.toarray() + if self.w.ndim == 2: + self.w = self.w.squeeze(1) + + assert self.w.shape == (self.x.shape[0],), "invalid weight shape %s" % self.w.shape + assert issubclass(self.w.dtype.type, np.floating) + if as_dask: if isinstance(self.x, dask.array.core.Array): self.x = self.x.compute() # Need to wrap dask around the COO matrix version of the sparse package if matrix is sparse. 
- if isinstance(self.x, scipy.sparse.spmatrix): + if scipy.sparse.issparse(self.x): self.x = dask.array.from_array( sparse.COO.from_scipy_sparse( self.x.astype(cast_dtype if cast_dtype is not None else self.x.dtype) @@ -81,11 +101,21 @@ def __init__( self.x.astype(cast_dtype if cast_dtype is not None else self.x.dtype), chunks=(chunk_size_cells, chunk_size_genes), ) + + if isinstance(self.w, dask.array.core.Array): + self.w = self.w.compute() + self.w = dask.array.from_array( + self.w.astype(cast_dtype if cast_dtype is not None else self.w.dtype), + chunks=(chunk_size_cells,), + ) else: if isinstance(self.x, dask.array.core.Array): self.x = self.x.compute() + if isinstance(self.w, dask.array.core.Array): + self.w = self.w.compute() if cast_dtype is not None: self.x = self.x.astype(cast_dtype) + self.w = self.w.astype(cast_dtype) self._feature_allzero = np.sum(self.x, axis=0) == 0 self.chunk_size_cells = chunk_size_cells diff --git a/batchglm/models/base/model.py b/batchglm/models/base/model.py index 92ab97bc..37d60801 100644 --- a/batchglm/models/base/model.py +++ b/batchglm/models/base/model.py @@ -26,6 +26,10 @@ def __init__( def x(self): return self.input_data.x + @property + def w(self): + return self.input_data.w + def get(self, key: Union[str, Iterable]) -> Union[Any, Dict[str, Any]]: """ Returns the values specified by key. diff --git a/batchglm/train/numpy/base_glm/estimator.py b/batchglm/train/numpy/base_glm/estimator.py index 1ad26326..d0476542 100644 --- a/batchglm/train/numpy/base_glm/estimator.py +++ b/batchglm/train/numpy/base_glm/estimator.py @@ -9,7 +9,6 @@ import sparse import sys import time -from typing import Tuple from .external import _EstimatorGLM, pkg_constants from .training_strategies import TrainingStrategies diff --git a/batchglm/train/numpy/base_glm/vars.py b/batchglm/train/numpy/base_glm/vars.py index afaca423..3645b938 100644 --- a/batchglm/train/numpy/base_glm/vars.py +++ b/batchglm/train/numpy/base_glm/vars.py @@ -1,12 +1,11 @@ import dask.array import numpy as np -import scipy.sparse import abc class ModelVarsGlm: """ - Build variables to be optimzed and their constraints. + Build variables to be optimized and their constraints. 
""" From 2ead7ccf52928c44e115bae4153ca3f36c87aeea Mon Sep 17 00:00:00 2001 From: michalk8 Date: Wed, 16 Sep 2020 19:35:19 +0200 Subject: [PATCH 02/12] Propagate weights --- batchglm/models/base/input.py | 12 +++++++---- batchglm/models/base_glm/external.py | 3 ++- batchglm/models/base_glm/input.py | 4 ++-- batchglm/train/numpy/base_glm/model.py | 14 ++++++------- batchglm/train/numpy/glm_nb/model.py | 28 ++++++++++++++------------ batchglm/types.py | 12 +++++++++++ 6 files changed, 46 insertions(+), 27 deletions(-) create mode 100644 batchglm/types.py diff --git a/batchglm/models/base/input.py b/batchglm/models/base/input.py index a4272a47..aafaf2b5 100644 --- a/batchglm/models/base/input.py +++ b/batchglm/models/base/input.py @@ -3,7 +3,8 @@ import numpy as np import scipy.sparse import sparse -from typing import List, Union, Optional, TypeAlias +from typing import List, Union, Optional +from .external.types import ArrayLike, InputType try: import anndata @@ -17,9 +18,6 @@ logger = logging.getLogger(__name__) -ArrayLike = TypeAlias(Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array]) -InputType = TypeAlias(Union[ArrayLike, anndata.AnnData, "InputDataBase"]) - class InputDataBase: """ @@ -81,9 +79,15 @@ def __init__( if self.w.ndim == 2: self.w = self.w.squeeze(1) + # sanity checks assert self.w.shape == (self.x.shape[0],), "invalid weight shape %s" % self.w.shape assert issubclass(self.w.dtype.type, np.floating) + if self.observations is not None: + assert len(self.observations) == self.x.shape[0] + if self.features is not None: + assert len(self.features) == self.x.shape[1] + if as_dask: if isinstance(self.x, dask.array.core.Array): self.x = self.x.compute() diff --git a/batchglm/models/base_glm/external.py b/batchglm/models/base_glm/external.py index 163181f5..ce6d3233 100644 --- a/batchglm/models/base_glm/external.py +++ b/batchglm/models/base_glm/external.py @@ -4,4 +4,5 @@ from batchglm.models.base import _SimulatorBase import batchglm.data as data_utils -from batchglm.utils.linalg import groupwise_solve_lm \ No newline at end of file +from batchglm.utils.linalg import groupwise_solve_lm +import batchglm.types as types diff --git a/batchglm/models/base_glm/input.py b/batchglm/models/base_glm/input.py index dd27be71..6532a069 100644 --- a/batchglm/models/base_glm/input.py +++ b/batchglm/models/base_glm/input.py @@ -7,11 +7,11 @@ import numpy as np import pandas as pd import patsy -import scipy.sparse from typing import Union from .utils import parse_constraints, parse_design from .external import InputDataBase +from .external.types import InputType class InputDataGLM(InputDataBase): @@ -25,7 +25,7 @@ class InputDataGLM(InputDataBase): def __init__( self, - data: Union[np.ndarray, anndata.AnnData, scipy.sparse.csr_matrix], + data: InputType, design_loc: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix] = None, design_loc_names: Union[list, np.ndarray] = None, design_scale: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix] = None, diff --git a/batchglm/train/numpy/base_glm/model.py b/batchglm/train/numpy/base_glm/model.py index 2f1f13ca..6f785425 100644 --- a/batchglm/train/numpy/base_glm/model.py +++ b/batchglm/train/numpy/base_glm/model.py @@ -263,10 +263,10 @@ def jac_b_j(self, j) -> np.ndarray: # Make sure that dimensionality of sliced array is kept: if isinstance(j, int) or isinstance(j, np.int32) or isinstance(j, np.int64): j = [j] - w = self.jac_weight_b_j(j=j) # (observations x features) - xh = np.matmul(self.design_scale, 
self.constraints_scale) # (observations x inferred param) - return np.einsum( - 'fob,of->fb', - np.einsum('ob,of->fob', xh, w), - xh - ) + w = self.jac_weight_b_j(j=j) # (observations x features) + xh = np.matmul(self.design_scale, self.constraints_scale) # (observations x inferred param) + return np.einsum( + 'fob,of->fb', + np.einsum('ob,of->fob', xh, w), + xh + ) diff --git a/batchglm/train/numpy/glm_nb/model.py b/batchglm/train/numpy/glm_nb/model.py index 3fd93fe1..a9a24133 100644 --- a/batchglm/train/numpy/glm_nb/model.py +++ b/batchglm/train/numpy/glm_nb/model.py @@ -3,7 +3,6 @@ import numpy as np import scipy.sparse import scipy.special -import sparse from .external import Model, ModelIwls, InputDataGLM from .processModel import ProcessModel @@ -41,7 +40,7 @@ def fim_weight_aa(self): :return: observations x features """ - return - self.location * self.scale / (self.scale + self.location) + return - self.w * self.location * self.scale / (self.scale + self.location) @property def ybar(self) -> np.ndarray: @@ -56,7 +55,7 @@ def fim_weight_aa_j(self, j): :return: observations x features """ - return - self.location_j(j=j) * self.scale_j(j=j) / (self.scale_j(j=j) + self.location_j(j=j)) + return - self.w * self.location_j(j=j) * self.scale_j(j=j) / (self.scale_j(j=j) + self.location_j(j=j)) def ybar_j(self, j) -> np.ndarray: """ @@ -89,7 +88,7 @@ def jac_weight_b(self): const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale) const2 = - scale_plus_x / r_plus_mu const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu) - return scale * (const1 + const2 + const3) + return self.w * scale * (const1 + const2 + const3) def jac_weight_b_j(self, j): """ @@ -111,7 +110,7 @@ def jac_weight_b_j(self, j): const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale) const2 = - scale_plus_x / r_plus_mu const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu) - return scale * (const1 + const2 + const3) + return self.w * scale * (const1 + const2 + const3) @property def fim_ab(self) -> np.ndarray: @@ -150,6 +149,7 @@ def hessian_weight_ab(self): scale = self.scale loc = self.location return np.multiply( + self.w, loc * scale, np.asarray(self.x - loc) / np.square(loc + scale) ) @@ -163,7 +163,7 @@ def hessian_weight_aa(self): else: x_by_scale_plus_one = np.asarray(self.x.divide(scale) + np.ones_like(scale)) - return - loc * x_by_scale_plus_one / np.square((loc / scale) + np.ones_like(loc)) + return - self.w * loc * x_by_scale_plus_one / np.square((loc / scale) + np.ones_like(loc)) @property def hessian_weight_bb(self): @@ -176,7 +176,7 @@ def hessian_weight_bb(self): const2 = - scipy.special.digamma(scale) + scale * scipy.special.polygamma(n=1, x=scale) const3 = - loc * scale_plus_x + np.ones_like(scale) * 2. * scale * scale_plus_loc / np.square(scale_plus_loc) const4 = np.log(scale) + np.ones_like(scale) * 2. 
- np.log(scale_plus_loc) - return scale * (const1 + const2 + const3 + const4) + return self.w * scale * (const1 + const2 + const3 + const4) @property def ll(self): @@ -198,7 +198,7 @@ def ll(self): np.asarray(self.x.multiply(self.eta_loc - log_r_plus_mu) + np.multiply(scale, self.eta_scale - log_r_plus_mu)) ll = np.asarray(ll) - return self.np_clip_param(ll, "ll") + return self.np_clip_param(self.w * ll, "ll") def ll_j(self, j): # Make sure that dimensionality of sliced array is kept: @@ -222,10 +222,11 @@ def ll_j(self, j): np.asarray(self.x[:, j].multiply(self.eta_loc_j(j=j) - log_r_plus_mu) + np.multiply(scale, self.eta_scale_j(j=j) - log_r_plus_mu)) ll = np.asarray(ll) - return self.np_clip_param(ll, "ll") + return self.np_clip_param(self.w * ll, "ll") + # TODO: not used def ll_handle(self): - def fun(x, eta_loc, b_var, xh_scale): + def fun(x, w, eta_loc, b_var, xh_scale): eta_scale = np.matmul(xh_scale, b_var) scale = np.exp(eta_scale) loc = np.exp(eta_loc) @@ -239,11 +240,12 @@ def fun(x, eta_loc, b_var, xh_scale): np.multiply(scale, eta_scale - log_r_plus_mu) else: raise ValueError("type x %s not supported" % type(x)) - return self.np_clip_param(ll, "ll") + return self.np_clip_param(w * ll, "ll") return fun + # TODO: not used def jac_b_handle(self): - def fun(x, eta_loc, b_var, xh_scale): + def fun(x, w, eta_loc, b_var, xh_scale): scale = np.exp(b_var) loc = np.exp(eta_loc) scale_plus_x = scale + x @@ -253,6 +255,6 @@ def fun(x, eta_loc, b_var, xh_scale): const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale) const2 = - scale_plus_x / r_plus_mu const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu) - return scale * (const1 + const2 + const3) + return w * scale * (const1 + const2 + const3) return fun diff --git a/batchglm/types.py b/batchglm/types.py new file mode 100644 index 00000000..b3d76e29 --- /dev/null +++ b/batchglm/types.py @@ -0,0 +1,12 @@ +from typing import TypeAlias, TypeVar, Union + +import scipy +import dask + +try: + from anndata import AnnData +except ImportError: + AnnData = TypeVar("AnnData") + +ArrayLike = TypeAlias(Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array]) +InputType = TypeAlias(Union[ArrayLike, AnnData, "InputDataBase"]) From 05a8c548fd032cb84643585b566323b9bcc49353 Mon Sep 17 00:00:00 2001 From: michalk8 Date: Wed, 16 Sep 2020 19:39:38 +0200 Subject: [PATCH 03/12] Add forgotten import --- batchglm/models/base/external.py | 1 + 1 file changed, 1 insertion(+) diff --git a/batchglm/models/base/external.py b/batchglm/models/base/external.py index 75015489..996552c8 100644 --- a/batchglm/models/base/external.py +++ b/batchglm/models/base/external.py @@ -1,2 +1,3 @@ import batchglm.pkg_constants as pkg_constants import batchglm.data as data_utils +import batchglm.types as types From 6227b77c75bf739b1a65e48ef927ebc973050f7d Mon Sep 17 00:00:00 2001 From: michalk8 Date: Wed, 16 Sep 2020 20:22:43 +0200 Subject: [PATCH 04/12] Fix missing import, remove TypeAlias --- batchglm/types.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/batchglm/types.py b/batchglm/types.py index b3d76e29..e3c8dd10 100644 --- a/batchglm/types.py +++ b/batchglm/types.py @@ -1,12 +1,13 @@ -from typing import TypeAlias, TypeVar, Union +from typing import TypeVar, Union import scipy import dask +import numpy as np try: from anndata import AnnData except ImportError: AnnData = TypeVar("AnnData") -ArrayLike = TypeAlias(Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array]) -InputType = 
TypeAlias(Union[ArrayLike, AnnData, "InputDataBase"]) +ArrayLike = Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array] +InputType = Union[ArrayLike, AnnData, "InputDataBase"] From 4edc75cebf92f4ec3cbdad9039995511518d7ef4 Mon Sep 17 00:00:00 2001 From: michalk8 Date: Wed, 16 Sep 2020 20:27:33 +0200 Subject: [PATCH 05/12] Fix types --- batchglm/models/base/input.py | 6 +++--- batchglm/models/base_glm/input.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/batchglm/models/base/input.py b/batchglm/models/base/input.py index aafaf2b5..a230d6a7 100644 --- a/batchglm/models/base/input.py +++ b/batchglm/models/base/input.py @@ -4,7 +4,7 @@ import scipy.sparse import sparse from typing import List, Union, Optional -from .external.types import ArrayLike, InputType +from .external import types as T try: import anndata @@ -30,8 +30,8 @@ class InputDataBase: def __init__( self, - data: InputType, - weights: Optional[Union[ArrayLike, str]] = None, + data: T.InputType, + weights: Optional[Union[T.ArrayLike, str]] = None, observation_names: Optional[List[str]] = None, feature_names: Optional[List[str]] = None, chunk_size_cells: int = 100000, diff --git a/batchglm/models/base_glm/input.py b/batchglm/models/base_glm/input.py index 6532a069..4f7134fb 100644 --- a/batchglm/models/base_glm/input.py +++ b/batchglm/models/base_glm/input.py @@ -11,7 +11,7 @@ from .utils import parse_constraints, parse_design from .external import InputDataBase -from .external.types import InputType +from .external import types as T class InputDataGLM(InputDataBase): @@ -25,7 +25,7 @@ class InputDataGLM(InputDataBase): def __init__( self, - data: InputType, + data: T.InputType, design_loc: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix] = None, design_loc_names: Union[list, np.ndarray] = None, design_scale: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix] = None, From d977f3ddc987976d0b85a75a884f75a8763ada8e Mon Sep 17 00:00:00 2001 From: michalk8 Date: Wed, 23 Sep 2020 21:07:34 +0200 Subject: [PATCH 06/12] Fix shape, dtype check --- batchglm/models/base/input.py | 11 ++++++----- batchglm/models/base_glm/input.py | 5 ++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/batchglm/models/base/input.py b/batchglm/models/base/input.py index a230d6a7..ebbb65c4 100644 --- a/batchglm/models/base/input.py +++ b/batchglm/models/base/input.py @@ -48,6 +48,7 @@ def __init__( - np.ndarray: NumPy array containing the raw data - anndata.AnnData: AnnData object containing the count data and optional the design models stored as data.obsm[design_loc] and data.obsm[design_scale] + :param weights: (optional) observation weights :param observation_names: (optional) names of the observations. :param feature_names: (optional) names of the features. 
:param cast_dtype: data type of all data; should be either float32 or float64 @@ -72,16 +73,16 @@ def __init__( raise ValueError("type of data %s not recognized" % type(data)) if self.w is None: - self.w = np.ones(self.x.shape[0], dtype=self.x.dtype) + self.w = np.ones(self.x.shape[0], dtype=np.float32 if not issubclass(self.x.dtype, np.floating) else self.x.dtype) if scipy.sparse.issparse(self.w): self.w = self.w.toarray() - if self.w.ndim == 2: - self.w = self.w.squeeze(1) + if self.w.ndim == 1: + self.w = self.w.reshape((-1, 1)) # sanity checks - assert self.w.shape == (self.x.shape[0],), "invalid weight shape %s" % self.w.shape - assert issubclass(self.w.dtype.type, np.floating) + assert self.w.shape == (self.x.shape[0], 1), "invalid weight shape %s" % self.w.shape + assert issubclass(self.w.dtype.type, np.floating), "invalid weight type %s" % self.w.dtype if self.observations is not None: assert len(self.observations) == self.x.shape[0] diff --git a/batchglm/models/base_glm/input.py b/batchglm/models/base_glm/input.py index 4f7134fb..bf497433 100644 --- a/batchglm/models/base_glm/input.py +++ b/batchglm/models/base_glm/input.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd import patsy -from typing import Union +from typing import Union, Optional from .utils import parse_constraints, parse_design from .external import InputDataBase @@ -26,6 +26,7 @@ class InputDataGLM(InputDataBase): def __init__( self, data: T.InputType, + weights: Optional[Union[T.ArrayLike, str]] = None, design_loc: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix] = None, design_loc_names: Union[list, np.ndarray] = None, design_scale: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix] = None, @@ -48,6 +49,7 @@ def __init__( - np.ndarray: NumPy array containing the raw data - anndata.AnnData: AnnData object containing the count data and optional the design models stored as data.obsm[design_loc] and data.obsm[design_scale] + :param weights: (optional) observation weights :param design_loc: Some matrix format (observations x mean model parameters) The location design model. 
Optional if already specified in `data` :param design_loc_names: (optional) @@ -83,6 +85,7 @@ def __init__( InputDataBase.__init__( self=self, data=data, + weights=weights, observation_names=observation_names, feature_names=feature_names, chunk_size_cells=chunk_size_cells, From dd98812c302fa535d77f945dc9183d0b1f04a009 Mon Sep 17 00:00:00 2001 From: michalk8 Date: Wed, 23 Sep 2020 22:04:02 +0200 Subject: [PATCH 07/12] Fix dtype, start fixing sparse implementation --- batchglm/models/base/input.py | 5 +++-- batchglm/models/base/model.py | 3 +++ batchglm/train/numpy/base_glm/vars.py | 19 ++++++++++--------- batchglm/train/numpy/glm_nb/model.py | 8 ++++++-- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/batchglm/models/base/input.py b/batchglm/models/base/input.py index ebbb65c4..39f3ef1a 100644 --- a/batchglm/models/base/input.py +++ b/batchglm/models/base/input.py @@ -73,7 +73,8 @@ def __init__( raise ValueError("type of data %s not recognized" % type(data)) if self.w is None: - self.w = np.ones(self.x.shape[0], dtype=np.float32 if not issubclass(self.x.dtype, np.floating) else self.x.dtype) + self.w = np.ones(self.x.shape[0], + dtype=np.float32 if not issubclass(type(self.x.dtype.type), np.floating) else self.x.dtype) if scipy.sparse.issparse(self.w): self.w = self.w.toarray() @@ -111,7 +112,7 @@ def __init__( self.w = self.w.compute() self.w = dask.array.from_array( self.w.astype(cast_dtype if cast_dtype is not None else self.w.dtype), - chunks=(chunk_size_cells,), + chunks=(chunk_size_cells, 1), ) else: if isinstance(self.x, dask.array.core.Array): diff --git a/batchglm/models/base/model.py b/batchglm/models/base/model.py index 37d60801..a154e99c 100644 --- a/batchglm/models/base/model.py +++ b/batchglm/models/base/model.py @@ -45,5 +45,8 @@ def get(self, key: Union[str, Iterable]) -> Union[Any, Dict[str, Any]]: def __getitem__(self, item): return self.get(item) + def __str__(self): + return self.__class__.__name__ + def __repr__(self): return self.__str__() diff --git a/batchglm/train/numpy/base_glm/vars.py b/batchglm/train/numpy/base_glm/vars.py index 3645b938..d4244ab2 100644 --- a/batchglm/train/numpy/base_glm/vars.py +++ b/batchglm/train/numpy/base_glm/vars.py @@ -1,3 +1,5 @@ +from typing import Union + import dask.array import numpy as np import abc @@ -23,8 +25,8 @@ def __init__( self, init_a: np.ndarray, init_b: np.ndarray, - constraints_loc: np.ndarray, - constraints_scale: np.ndarray, + constraints_loc: Union[np.ndarray, dask.array.core.Array], + constraints_scale: Union[np.ndarray, dask.array.core.Array], chunk_size_genes: int, dtype: str ): @@ -41,13 +43,12 @@ def __init__( init_a_clipped = self.np_clip_param(np.asarray(init_a, dtype=dtype), "a_var") init_b_clipped = self.np_clip_param(np.asarray(init_b, dtype=dtype), "b_var") - self.params = dask.array.from_array(np.concatenate( - [ - init_a_clipped, - init_b_clipped, - ], - axis=0 - ), chunks=(1000, chunk_size_genes)) + + self.params = np.concatenate([init_a_clipped, init_b_clipped], axis=0) + # TODO: being not a dask array breaks a lot of things + if True or isinstance(constraints_loc, dask.array.core.Array): + self.params = dask.array.from_array(self.params, chunks=(1000, chunk_size_genes)) + self.npar_a = init_a_clipped.shape[0] # Properties to follow gene-wise convergence. 
diff --git a/batchglm/train/numpy/glm_nb/model.py b/batchglm/train/numpy/glm_nb/model.py index a9a24133..9e5963c5 100644 --- a/batchglm/train/numpy/glm_nb/model.py +++ b/batchglm/train/numpy/glm_nb/model.py @@ -192,11 +192,15 @@ def ll(self): np.multiply(scale, self.eta_scale - log_r_plus_mu) else: # sparse scipy + # The inner np.asarray (self.x.multiply(...)) is there because dask + # does not yet support fancy nd indexing + # tocsr because TypeError: 'coo_matrix' object is not subscriptable + ll = scipy.special.gammaln(np.asarray(scale + self.x)) - \ scipy.special.gammaln(self.x + np.ones_like(scale)) - \ scipy.special.gammaln(scale) + \ - np.asarray(self.x.multiply(self.eta_loc - log_r_plus_mu) + - np.multiply(scale, self.eta_scale - log_r_plus_mu)) + self.x.multiply(np.asarray(self.eta_loc - log_r_plus_mu + + np.multiply(scale, self.eta_scale - log_r_plus_mu))).tocsr() ll = np.asarray(ll) return self.np_clip_param(self.w * ll, "ll") From 72fbd20f82c9bd40fe3e78e3958038dfe5237197 Mon Sep 17 00:00:00 2001 From: michalk8 Date: Wed, 23 Sep 2020 22:19:45 +0200 Subject: [PATCH 08/12] Revert previous solution, raise if not uisng dask --- batchglm/models/base/input.py | 3 +++ batchglm/train/numpy/glm_nb/model.py | 16 +++++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/batchglm/models/base/input.py b/batchglm/models/base/input.py index 39f3ef1a..9597b732 100644 --- a/batchglm/models/base/input.py +++ b/batchglm/models/base/input.py @@ -115,6 +115,9 @@ def __init__( chunks=(chunk_size_cells, 1), ) else: + if scipy.sparse.issparse(self.x): + raise TypeError(f"For sparse matrices, only dask is supported.") + if isinstance(self.x, dask.array.core.Array): self.x = self.x.compute() if isinstance(self.w, dask.array.core.Array): diff --git a/batchglm/train/numpy/glm_nb/model.py b/batchglm/train/numpy/glm_nb/model.py index 9e5963c5..02d04620 100644 --- a/batchglm/train/numpy/glm_nb/model.py +++ b/batchglm/train/numpy/glm_nb/model.py @@ -191,16 +191,22 @@ def ll(self): self.x * (self.eta_loc - log_r_plus_mu) + \ np.multiply(scale, self.eta_scale - log_r_plus_mu) else: - # sparse scipy - # The inner np.asarray (self.x.multiply(...)) is there because dask - # does not yet support fancy nd indexing + # sparse scipy, assuming self.x is also a dask array + + # if not, it can be fixed as follows: + # the inner np.asarray: ... 
+ (self.x.multiply(np.asarray(...))).tocsr() + # is there because dask does not yet support fancy nd indexing # tocsr because TypeError: 'coo_matrix' object is not subscriptable + # when computing the values in `ll = np.asarray(ll)` + # + # however, the estimator code will be broken, since it calls: self.model.ll_byfeature.compute() + # which requires this property, which will not be a dask array ll = scipy.special.gammaln(np.asarray(scale + self.x)) - \ scipy.special.gammaln(self.x + np.ones_like(scale)) - \ scipy.special.gammaln(scale) + \ - self.x.multiply(np.asarray(self.eta_loc - log_r_plus_mu + - np.multiply(scale, self.eta_scale - log_r_plus_mu))).tocsr() + np.asarray(self.x.multiply(self.eta_loc - log_r_plus_mu + + np.multiply(scale, self.eta_scale - log_r_plus_mu))) ll = np.asarray(ll) return self.np_clip_param(self.w * ll, "ll") From e8274273d5236e11e933ea1258773f82f3a0f70f Mon Sep 17 00:00:00 2001 From: michalk8 Date: Wed, 23 Sep 2020 22:25:13 +0200 Subject: [PATCH 09/12] Fix parantheses --- batchglm/train/numpy/glm_nb/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/batchglm/train/numpy/glm_nb/model.py b/batchglm/train/numpy/glm_nb/model.py index 02d04620..170a3262 100644 --- a/batchglm/train/numpy/glm_nb/model.py +++ b/batchglm/train/numpy/glm_nb/model.py @@ -205,8 +205,8 @@ def ll(self): ll = scipy.special.gammaln(np.asarray(scale + self.x)) - \ scipy.special.gammaln(self.x + np.ones_like(scale)) - \ scipy.special.gammaln(scale) + \ - np.asarray(self.x.multiply(self.eta_loc - log_r_plus_mu + - np.multiply(scale, self.eta_scale - log_r_plus_mu))) + np.asarray(self.x.multiply(self.eta_loc - log_r_plus_mu) + + np.multiply(scale, self.eta_scale - log_r_plus_mu)) ll = np.asarray(ll) return self.np_clip_param(self.w * ll, "ll") From 1f8295b83907fd06efdefea438282f87700fd311 Mon Sep 17 00:00:00 2001 From: michalk8 Date: Thu, 24 Sep 2020 22:47:38 +0200 Subject: [PATCH 10/12] Add utils, start sparse-no-dask --- batchglm/models/base/input.py | 26 ++++++++++----------- batchglm/models/base_glm/input.py | 28 ++++++++++++----------- batchglm/train/numpy/base_glm/external.py | 1 + batchglm/train/numpy/base_glm/vars.py | 5 ++-- batchglm/train/numpy/utils/__init__.py | 1 + batchglm/train/numpy/utils/external.py | 1 + batchglm/train/numpy/utils/utils.py | 19 +++++++++++++++ 7 files changed, 52 insertions(+), 29 deletions(-) create mode 100644 batchglm/train/numpy/utils/__init__.py create mode 100644 batchglm/train/numpy/utils/external.py create mode 100644 batchglm/train/numpy/utils/utils.py diff --git a/batchglm/models/base/input.py b/batchglm/models/base/input.py index 9597b732..1fe0711e 100644 --- a/batchglm/models/base/input.py +++ b/batchglm/models/base/input.py @@ -56,7 +56,7 @@ def __init__( """ self.observations = observation_names self.features = feature_names - self.w = weights + self.w = None if isinstance(data, np.ndarray) or \ isinstance(data, scipy.sparse.csr_matrix) or \ @@ -72,9 +72,11 @@ def __init__( else: raise ValueError("type of data %s not recognized" % type(data)) + self._cast_type = cast_dtype if cast_dtype is not None else self.x.dtype + self._cast_type_float = np.float32 if not issubclass(type(self._cast_type), np.floating) else self._cast_type + if self.w is None: - self.w = np.ones(self.x.shape[0], - dtype=np.float32 if not issubclass(type(self.x.dtype.type), np.floating) else self.x.dtype) + self.w = np.ones(self.x.shape[0], dtype=self._cast_type_float) if scipy.sparse.issparse(self.w): self.w = self.w.toarray() @@ -96,25 +98,21 @@ def 
__init__( # Need to wrap dask around the COO matrix version of the sparse package if matrix is sparse. if scipy.sparse.issparse(self.x): self.x = dask.array.from_array( - sparse.COO.from_scipy_sparse( - self.x.astype(cast_dtype if cast_dtype is not None else self.x.dtype) - ), + sparse.COO.from_scipy_sparse(self.x.astype(self._cast_type)), chunks=(chunk_size_cells, chunk_size_genes), asarray=False ) else: self.x = dask.array.from_array( - self.x.astype(cast_dtype if cast_dtype is not None else self.x.dtype), + self.x.astype(self._cast_type), chunks=(chunk_size_cells, chunk_size_genes), ) if isinstance(self.w, dask.array.core.Array): self.w = self.w.compute() - self.w = dask.array.from_array( - self.w.astype(cast_dtype if cast_dtype is not None else self.w.dtype), - chunks=(chunk_size_cells, 1), - ) + self.w = dask.array.from_array(self.w.astype(self._cast_type_float), chunks=(chunk_size_cells, 1)) else: + # TODO: fix this if scipy.sparse.issparse(self.x): raise TypeError(f"For sparse matrices, only dask is supported.") @@ -122,9 +120,9 @@ def __init__( self.x = self.x.compute() if isinstance(self.w, dask.array.core.Array): self.w = self.w.compute() - if cast_dtype is not None: - self.x = self.x.astype(cast_dtype) - self.w = self.w.astype(cast_dtype) + + self.x = self.x.astype(self._cast_type) + self.w = self.w.astype(self._cast_type) self._feature_allzero = np.sum(self.x, axis=0) == 0 self.chunk_size_cells = chunk_size_cells diff --git a/batchglm/models/base_glm/input.py b/batchglm/models/base_glm/input.py index bf497433..0fcf14c4 100644 --- a/batchglm/models/base_glm/input.py +++ b/batchglm/models/base_glm/input.py @@ -145,21 +145,23 @@ def __init__( self._loc_names = loc_names self._scale_names = scale_names + self.size_factors = None + if size_factors is not None: - if len(size_factors.shape) == 1: - size_factors = np.expand_dims(np.asarray(size_factors), axis=-1) - elif len(size_factors.shape) == 2: - pass + size_factors = np.asarray(size_factors) + + if size_factors.ndim == 1: + size_factors = size_factors.reshape((-1, 1)) + if size_factors.ndim != 2: + raise ValueError("received size factors with dimension=%i" % size_factors.ndim) + + if as_dask: + self.size_factors = dask.array.from_array( + size_factors.astype(cast_dtype if cast_dtype is not None else self.x.dtype), + chunks=(chunk_size_cells, 1), + ) else: - raise ValueError("received size factors with dimension=%i" % len(size_factors.shape)) - if as_dask: - self.size_factors = dask.array.from_array( - size_factors.astype(cast_dtype if cast_dtype is not None else self.x.dtype), - chunks=(chunk_size_cells, 1), - ) if size_factors is not None else None - else: - self.size_factors = size_factors.astype(cast_dtype if cast_dtype is not None else self.x.dtype) \ - if size_factors is not None else None + self.size_factors = size_factors.astype(cast_dtype if cast_dtype is not None else self.x.dtype) @property def design_loc_names(self): diff --git a/batchglm/train/numpy/base_glm/external.py b/batchglm/train/numpy/base_glm/external.py index 0b57b5a4..182dee30 100644 --- a/batchglm/train/numpy/base_glm/external.py +++ b/batchglm/train/numpy/base_glm/external.py @@ -1,4 +1,5 @@ from batchglm.models.base_glm import InputDataGLM, _ModelGLM, _EstimatorGLM from batchglm.utils.linalg import groupwise_solve_lm +from batchglm.train.numpy.utils import maybe_compute, isdask from batchglm import pkg_constants \ No newline at end of file diff --git a/batchglm/train/numpy/base_glm/vars.py b/batchglm/train/numpy/base_glm/vars.py index d4244ab2..ad8f19a3 
100644 --- a/batchglm/train/numpy/base_glm/vars.py +++ b/batchglm/train/numpy/base_glm/vars.py @@ -4,6 +4,8 @@ import numpy as np import abc +from .external import isdask + class ModelVarsGlm: """ @@ -45,8 +47,7 @@ def __init__( init_b_clipped = self.np_clip_param(np.asarray(init_b, dtype=dtype), "b_var") self.params = np.concatenate([init_a_clipped, init_b_clipped], axis=0) - # TODO: being not a dask array breaks a lot of things - if True or isinstance(constraints_loc, dask.array.core.Array): + if isdask(constraints_loc): self.params = dask.array.from_array(self.params, chunks=(1000, chunk_size_genes)) self.npar_a = init_a_clipped.shape[0] diff --git a/batchglm/train/numpy/utils/__init__.py b/batchglm/train/numpy/utils/__init__.py new file mode 100644 index 00000000..879e3422 --- /dev/null +++ b/batchglm/train/numpy/utils/__init__.py @@ -0,0 +1 @@ +from .utils import maybe_compute, isdask \ No newline at end of file diff --git a/batchglm/train/numpy/utils/external.py b/batchglm/train/numpy/utils/external.py new file mode 100644 index 00000000..0ab57039 --- /dev/null +++ b/batchglm/train/numpy/utils/external.py @@ -0,0 +1 @@ +from batchglm.types import ArrayLike \ No newline at end of file diff --git a/batchglm/train/numpy/utils/utils.py b/batchglm/train/numpy/utils/utils.py new file mode 100644 index 00000000..f28a21ef --- /dev/null +++ b/batchglm/train/numpy/utils/utils.py @@ -0,0 +1,19 @@ +from typing import Union, Optional + +import numpy as np +from scipy.sparse import spmatrix +import dask.array + +from .external import ArrayLike + + +def maybe_compute(array: Optional[Union[ArrayLike]]) -> Optional[Union[np.ndarray, spmatrix]]: + if array is None: + return None + if isdask(array): + return array.compute() + return array + + +def isdask(array: Optional[ArrayLike]) -> bool: + return isinstance(array, dask.array.core.Array) From 570244b422f8738cb8b816b41e2aad62f7a84ffe Mon Sep 17 00:00:00 2001 From: michalk8 Date: Thu, 24 Sep 2020 23:12:35 +0200 Subject: [PATCH 11/12] Start typing --- batchglm/models/base/estimator.py | 29 +++++++++------------ batchglm/models/base/external.py | 1 + batchglm/models/base/input.py | 30 ++++++++++------------ batchglm/models/base/simulator.py | 8 +++--- batchglm/train/numpy/base_glm/estimator.py | 2 +- batchglm/train/numpy/utils/utils.py | 4 +-- batchglm/types.py | 1 + 7 files changed, 33 insertions(+), 42 deletions(-) diff --git a/batchglm/models/base/estimator.py b/batchglm/models/base/estimator.py index af1cae3b..5107b042 100644 --- a/batchglm/models/base/estimator.py +++ b/batchglm/models/base/estimator.py @@ -1,4 +1,6 @@ import abc +from abc import ABC + import dask from enum import Enum import logging @@ -7,25 +9,25 @@ import pprint import sys -try: - import anndata -except ImportError: - anndata = None - from .input import InputDataBase from .model import _ModelBase +from .external import maybe_compute logger = logging.getLogger(__name__) +# TODO: why abc.Meta instead of abc.ABC class _EstimatorBase(metaclass=abc.ABCMeta): r""" Estimator base class """ + # TODO: why are these here? 
model: _ModelBase _loss: np.ndarray _jacobian: np.ndarray + # TODO: better enums (use the pretty-printing like in CellRank) + # TODO: don't use nested classes class TrainingStrategy(Enum): AUTO = None @@ -81,18 +83,12 @@ def w(self) -> np.ndarray: return self.input_data.w @property - def a_var(self): - if isinstance(self.model.a_var, dask.array.core.Array): - return self.model.a_var.compute() - else: - return self.model.a_var + def a_var(self) -> np.ndarray: + return maybe_compute(self.model.a_var) @property def b_var(self) -> np.ndarray: - if isinstance(self.model.b_var, dask.array.core.Array): - return self.model.b_var.compute() - else: - return self.model.b_var + return maybe_compute(self.model.b_var) @abc.abstractmethod def initialize(self, **kwargs): @@ -106,6 +102,7 @@ def train_sequence( training_strategy, **kwargs ): + # TODO: better enums (use the pretty-printing like in CellRank) if isinstance(training_strategy, Enum): training_strategy = training_strategy.value elif isinstance(training_strategy, str): @@ -250,8 +247,6 @@ def _plot_coef_vs_ref( if return_axs: return axs - else: - return def _plot_deviation( self, @@ -317,7 +312,7 @@ def _plot_deviation( return -class EstimatorBaseTyping(_EstimatorBase): +class EstimatorBaseTyping(_EstimatorBase, ABC): r""" Estimator base class used for typing in other packages. """ diff --git a/batchglm/models/base/external.py b/batchglm/models/base/external.py index 996552c8..bfa5bb4e 100644 --- a/batchglm/models/base/external.py +++ b/batchglm/models/base/external.py @@ -1,3 +1,4 @@ import batchglm.pkg_constants as pkg_constants import batchglm.data as data_utils import batchglm.types as types +from batchglm.train.numpy.utils import maybe_compute diff --git a/batchglm/models/base/input.py b/batchglm/models/base/input.py index 1fe0711e..f7ea05d2 100644 --- a/batchglm/models/base/input.py +++ b/batchglm/models/base/input.py @@ -3,9 +3,10 @@ import numpy as np import scipy.sparse import sparse -from typing import List, Union, Optional -from .external import types as T +from typing import List, Union, Optional, Tuple +from .external import types as T, maybe_compute +# TODO: make this nicer try: import anndata try: @@ -84,6 +85,7 @@ def __init__( self.w = self.w.reshape((-1, 1)) # sanity checks + # TODO: don't use asserts assert self.w.shape == (self.x.shape[0], 1), "invalid weight shape %s" % self.w.shape assert issubclass(self.w.dtype.type, np.floating), "invalid weight type %s" % self.w.dtype @@ -92,9 +94,10 @@ def __init__( if self.features is not None: assert len(self.features) == self.x.shape[1] + self.x = maybe_compute(self.x) + self.w = maybe_compute(self.w) + if as_dask: - if isinstance(self.x, dask.array.core.Array): - self.x = self.x.compute() # Need to wrap dask around the COO matrix version of the sparse package if matrix is sparse. 
if scipy.sparse.issparse(self.x): self.x = dask.array.from_array( @@ -108,19 +111,12 @@ def __init__( chunks=(chunk_size_cells, chunk_size_genes), ) - if isinstance(self.w, dask.array.core.Array): - self.w = self.w.compute() self.w = dask.array.from_array(self.w.astype(self._cast_type_float), chunks=(chunk_size_cells, 1)) else: # TODO: fix this if scipy.sparse.issparse(self.x): raise TypeError(f"For sparse matrices, only dask is supported.") - if isinstance(self.x, dask.array.core.Array): - self.x = self.x.compute() - if isinstance(self.w, dask.array.core.Array): - self.w = self.w.compute() - self.x = self.x.astype(self._cast_type) self.w = self.w.astype(self._cast_type) @@ -129,27 +125,27 @@ def __init__( self.chunk_size_genes = chunk_size_genes @property - def num_observations(self): + def num_observations(self) -> int: return self.x.shape[0] @property - def num_features(self): + def num_features(self) -> int: return self.x.shape[1] @property - def feature_isnonzero(self): + def feature_isnonzero(self) -> np.ndarray: return ~self._feature_allzero @property - def feature_isallzero(self): + def feature_isallzero(self) -> np.ndarray: return self._feature_allzero - def fetch_x_dense(self, idx): + def fetch_x_dense(self, idx: T.IndexLike) -> np.ndarray: assert isinstance(self.x, np.ndarray), "tried to fetch dense from non ndarray" return self.x[idx, :] - def fetch_x_sparse(self, idx): + def fetch_x_sparse(self, idx: T.IndexLike) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: assert isinstance(self.x, scipy.sparse.csr_matrix), "tried to fetch sparse from non csr_matrix" data = self.x[idx, :] diff --git a/batchglm/models/base/simulator.py b/batchglm/models/base/simulator.py index 095aced8..2ad7ed6e 100644 --- a/batchglm/models/base/simulator.py +++ b/batchglm/models/base/simulator.py @@ -1,6 +1,5 @@ import abc import dask.array -import os import logging import numpy as np @@ -11,6 +10,7 @@ from .input import InputDataBase from .model import _ModelBase +from .external import maybe_compute logger = logging.getLogger(__name__) @@ -63,9 +63,7 @@ def generate_params(self, *args, **kwargs): """ pass + # TODO: computed property (self.input_data.x should not change)? @property def x(self) -> np.ndarray: - if isinstance(self.input_data.x, dask.array.core.Array): - return self.input_data.x.compute() - else: - return self.input_data.x + return maybe_compute(self.input_data.x) diff --git a/batchglm/train/numpy/base_glm/estimator.py b/batchglm/train/numpy/base_glm/estimator.py index d0476542..2d581c9b 100644 --- a/batchglm/train/numpy/base_glm/estimator.py +++ b/batchglm/train/numpy/base_glm/estimator.py @@ -64,7 +64,7 @@ def train( of the location model is tracked with self.model.converged. This is re-set after a scale model update, as this convergence only holds conditioned on a particular scale model value. Full convergence of a feature wise model is evaluated after each scale model update: If the loss function based - convergence criterium holds across the cumulative updates of the sequence of location updates and last scale + convergence criterion holds across the cumulative updates of the sequence of location updates and last scale model update, the feature is considered converged. For this, the loss value at the last scale model update is save in ll_last_b_update. Full convergence is saved in fully_converged. 
diff --git a/batchglm/train/numpy/utils/utils.py b/batchglm/train/numpy/utils/utils.py index f28a21ef..c95b41a5 100644 --- a/batchglm/train/numpy/utils/utils.py +++ b/batchglm/train/numpy/utils/utils.py @@ -7,12 +7,12 @@ from .external import ArrayLike -def maybe_compute(array: Optional[Union[ArrayLike]]) -> Optional[Union[np.ndarray, spmatrix]]: +def maybe_compute(array: Optional[Union[ArrayLike]], copy: bool = False) -> Optional[Union[np.ndarray, spmatrix]]: if array is None: return None if isdask(array): return array.compute() - return array + return array.copy() if copy else array def isdask(array: Optional[ArrayLike]) -> bool: diff --git a/batchglm/types.py b/batchglm/types.py index e3c8dd10..31d231f7 100644 --- a/batchglm/types.py +++ b/batchglm/types.py @@ -10,4 +10,5 @@ AnnData = TypeVar("AnnData") ArrayLike = Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array] +IndexLike = Union[np.ndarray, tuple, list, int] # not exhaustive InputType = Union[ArrayLike, AnnData, "InputDataBase"] From ff81b681b6ae7b9463d9bdb9fc35b01ee8466747 Mon Sep 17 00:00:00 2001 From: michalk8 Date: Thu, 24 Sep 2020 23:27:19 +0200 Subject: [PATCH 12/12] Finish base typing (mostly) --- batchglm/models/base/estimator.py | 69 +++++++++++++++++-------------- batchglm/models/base/input.py | 2 +- batchglm/models/base/model.py | 19 ++++----- batchglm/models/base/simulator.py | 15 +++---- batchglm/types.py | 4 +- 5 files changed, 53 insertions(+), 56 deletions(-) diff --git a/batchglm/models/base/estimator.py b/batchglm/models/base/estimator.py index 5107b042..f085d8a2 100644 --- a/batchglm/models/base/estimator.py +++ b/batchglm/models/base/estimator.py @@ -1,8 +1,10 @@ import abc from abc import ABC +from enum import Enum +from pathlib import Path +from typing import Optional, Union import dask -from enum import Enum import logging import numpy as np import pandas as pd @@ -11,7 +13,7 @@ from .input import InputDataBase from .model import _ModelBase -from .external import maybe_compute +from .external import maybe_compute, types as T logger = logging.getLogger(__name__) @@ -23,7 +25,7 @@ class _EstimatorBase(metaclass=abc.ABCMeta): """ # TODO: why are these here? model: _ModelBase - _loss: np.ndarray + _loss: np.ndarray # is this really an array? 
_jacobian: np.ndarray # TODO: better enums (use the pretty-printing like in CellRank) @@ -46,16 +48,17 @@ def __init__( self._error_codes = None self._niter = None + # TODO: type @property def error_codes(self): return self._error_codes @property - def niter(self): + def niter(self) -> int: return self._niter @property - def loss(self): + def loss(self) -> np.ndarray: return self._loss @property @@ -63,13 +66,14 @@ def log_likelihood(self): return self._log_likelihood @property - def jacobian(self): + def jacobian(self) -> np.ndarray: return self._jacobian @property - def hessian(self): + def hessian(self) -> np.ndarray: return self._hessian + # TODO: type @property def fisher_inv(self): return self._fisher_inv @@ -97,6 +101,8 @@ def initialize(self, **kwargs): """ pass + # TODO: type training strategy + # TODO: docs def train_sequence( self, training_strategy, @@ -118,6 +124,7 @@ def train_sequence( if np.any([x in list(d.keys()) for x in list(kwargs.keys())]): d = dict([(x, y) for x, y in d.items() if x not in list(kwargs.keys())]) for x in [xx for xx in list(d.keys()) if xx in list(kwargs.keys())]: + # TODO: don't use sys sys.stdout.write( "overrding %s from training strategy with value %s with new value %s\n" % (x, str(d[x]), str(kwargs[x])) @@ -140,20 +147,21 @@ def finalize(self, **kwargs): """ pass + # TODO: make static, not use model? def _plot_coef_vs_ref( self, - true_values: np.ndarray, - estim_values: np.ndarray, - size=1, - log=False, - save=None, - show=True, - ncols=5, - row_gap=0.3, - col_gap=0.25, - title=None, - return_axs=False - ): + true_values: T.ArrayLike, + estim_values: T.ArrayLike, + size: float = 1, + log: bool = False, + save: Optional[Union[str, Path]] = None, + show: bool = True, + ncols: int = 5, + row_gap: float = 0.3, + col_gap: float = 0.25, + title: Optional[str] = None, + return_axs: bool = False + ) -> Optional['matplotlib.axes.Axes']: """ Plot estimated coefficients against reference (true) coefficients. @@ -175,10 +183,8 @@ def _plot_coef_vs_ref( from matplotlib import gridspec from matplotlib import rcParams - if isinstance(true_values, dask.array.core.Array): - true_values = true_values.compute() - if isinstance(estim_values, dask.array.core.Array): - estim_values = estim_values.compute() + true_values = maybe_compute(true_values) + estim_values = maybe_compute(estim_values) plt.ioff() @@ -248,15 +254,16 @@ def _plot_coef_vs_ref( if return_axs: return axs + # TODO: make static, not use model? def _plot_deviation( self, - true_values: np.ndarray, - estim_values: np.ndarray, - save=None, - show=True, - title=None, - return_axs=False - ): + true_values: T.ArrayLike, + estim_values: T.ArrayLike, + save: Optional[Union[str, Path]] = None, + show: bool = True, + title: Optional[str] = None, + return_axs: bool = False + ) -> Optional['matplotlib.axes.Axes']: """ Plot estimated coefficients against reference (true) coefficients. @@ -308,8 +315,6 @@ def _plot_deviation( if return_axs: return ax - else: - return class EstimatorBaseTyping(_EstimatorBase, ABC): diff --git a/batchglm/models/base/input.py b/batchglm/models/base/input.py index f7ea05d2..f300f0af 100644 --- a/batchglm/models/base/input.py +++ b/batchglm/models/base/input.py @@ -6,7 +6,7 @@ from typing import List, Union, Optional, Tuple from .external import types as T, maybe_compute -# TODO: make this nicer +# TODO: make this nicer (i.e. 
top-level import + function + warning if None) try: import anndata try: diff --git a/batchglm/models/base/model.py b/batchglm/models/base/model.py index a154e99c..243da054 100644 --- a/batchglm/models/base/model.py +++ b/batchglm/models/base/model.py @@ -2,11 +2,8 @@ from typing import Union, Any, Dict, Iterable import logging -try: - import anndata -except ImportError: - anndata = None - +from .external import types as T +from .input import InputDataBase logger = logging.getLogger(__name__) @@ -18,16 +15,16 @@ class _ModelBase(metaclass=abc.ABCMeta): def __init__( self, - input_data + input_data: InputDataBase ): self.input_data = input_data @property - def x(self): + def x(self) -> T.ArrayLike: return self.input_data.x @property - def w(self): + def w(self) -> T.ArrayLike: return self.input_data.w def get(self, key: Union[str, Iterable]) -> Union[Any, Dict[str, Any]]: @@ -42,11 +39,11 @@ def get(self, key: Union[str, Iterable]) -> Union[Any, Dict[str, Any]]: elif isinstance(key, Iterable): return {s: self.__getattribute__(s) for s in key} - def __getitem__(self, item): + def __getitem__(self, item) -> Union[Any, Dict[str, Any]]: return self.get(item) - def __str__(self): + def __str__(self) -> str: return self.__class__.__name__ - def __repr__(self): + def __repr__(self) -> str: return self.__str__() diff --git a/batchglm/models/base/simulator.py b/batchglm/models/base/simulator.py index 2ad7ed6e..e504bbd6 100644 --- a/batchglm/models/base/simulator.py +++ b/batchglm/models/base/simulator.py @@ -1,13 +1,7 @@ import abc -import dask.array import logging import numpy as np -try: - import anndata -except ImportError: - anndata = None - from .input import InputDataBase from .model import _ModelBase from .external import maybe_compute @@ -24,6 +18,7 @@ class _SimulatorBase(metaclass=abc.ABCMeta): convention: N features with M observations each => (M, N) matrix """ + # TODO: why? nobs: int nfeatures: int @@ -32,9 +27,9 @@ class _SimulatorBase(metaclass=abc.ABCMeta): def __init__( self, - model, - num_observations, - num_features + model: _ModelBase, + num_observations: int, + num_features: int ): self.nobs = num_observations self.nfeatures = num_features @@ -42,7 +37,7 @@ def __init__( self.input_data = None self.model = model - def generate(self): + def generate(self) -> None: """ First generates the parameter set, then observations random data using these parameters """ diff --git a/batchglm/types.py b/batchglm/types.py index 31d231f7..78b273fc 100644 --- a/batchglm/types.py +++ b/batchglm/types.py @@ -1,14 +1,14 @@ from typing import TypeVar, Union -import scipy import dask import numpy as np +from scipy.sparse import spmatrix try: from anndata import AnnData except ImportError: AnnData = TypeVar("AnnData") -ArrayLike = Union[np.ndarray, scipy.sparse.csr_matrix, dask.array.core.Array] +ArrayLike = Union[np.ndarray, spmatrix, dask.array.core.Array] IndexLike = Union[np.ndarray, tuple, list, int] # not exhaustive InputType = Union[ArrayLike, AnnData, "InputDataBase"]
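
A minimal usage sketch of the `weights` argument introduced by these patches (illustrative only, not part of the patch series itself; it assumes the module path `batchglm/models/base/input.py` shown in the diffs above and uses simulated count data):

    import numpy as np

    from batchglm.models.base.input import InputDataBase

    # simulated counts: 200 observations x 10 features
    x = np.random.negative_binomial(n=5, p=0.5, size=(200, 10)).astype(np.float64)
    # one floating-point weight per observation; the patches reshape this to (n_obs, 1)
    # and fall back to all-ones weights when the argument is omitted
    w = np.full(200, 0.5, dtype=np.float64)

    input_data = InputDataBase(data=x, weights=w, as_dask=True)
    # with as_dask=True both arrays are wrapped in dask arrays, chunked over cells/genes
    print(input_data.x.shape, input_data.w.shape)  # (200, 10) (200, 1)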