diff --git a/batchglm/models/base/estimator.py b/batchglm/models/base/estimator.py index dc1e3bf1..f085d8a2 100644 --- a/batchglm/models/base/estimator.py +++ b/batchglm/models/base/estimator.py @@ -1,31 +1,35 @@ import abc -import dask +from abc import ABC from enum import Enum +from pathlib import Path +from typing import Optional, Union + +import dask import logging import numpy as np import pandas as pd import pprint import sys -try: - import anndata -except ImportError: - anndata = None - from .input import InputDataBase from .model import _ModelBase +from .external import maybe_compute, types as T logger = logging.getLogger(__name__) +# TODO: why abc.Meta instead of abc.ABC class _EstimatorBase(metaclass=abc.ABCMeta): r""" Estimator base class """ + # TODO: why are these here? model: _ModelBase - _loss: np.ndarray + _loss: np.ndarray # is this really an array? _jacobian: np.ndarray + # TODO: better enums (use the pretty-printing like in CellRank) + # TODO: don't use nested classes class TrainingStrategy(Enum): AUTO = None @@ -44,16 +48,17 @@ def __init__( self._error_codes = None self._niter = None + # TODO: type @property def error_codes(self): return self._error_codes @property - def niter(self): + def niter(self) -> int: return self._niter @property - def loss(self): + def loss(self) -> np.ndarray: return self._loss @property @@ -61,13 +66,14 @@ def log_likelihood(self): return self._log_likelihood @property - def jacobian(self): + def jacobian(self) -> np.ndarray: return self._jacobian @property - def hessian(self): + def hessian(self) -> np.ndarray: return self._hessian + # TODO: type @property def fisher_inv(self): return self._fisher_inv @@ -77,18 +83,16 @@ def x(self) -> np.ndarray: return self.input_data.x @property - def a_var(self): - if isinstance(self.model.a_var, dask.array.core.Array): - return self.model.a_var.compute() - else: - return self.model.a_var + def w(self) -> np.ndarray: + return self.input_data.w + + @property + def a_var(self) -> np.ndarray: + return maybe_compute(self.model.a_var) @property def b_var(self) -> np.ndarray: - if isinstance(self.model.b_var, dask.array.core.Array): - return self.model.b_var.compute() - else: - return self.model.b_var + return maybe_compute(self.model.b_var) @abc.abstractmethod def initialize(self, **kwargs): @@ -97,11 +101,14 @@ def initialize(self, **kwargs): """ pass + # TODO: type training strategy + # TODO: docs def train_sequence( self, training_strategy, **kwargs ): + # TODO: better enums (use the pretty-printing like in CellRank) if isinstance(training_strategy, Enum): training_strategy = training_strategy.value elif isinstance(training_strategy, str): @@ -117,6 +124,7 @@ def train_sequence( if np.any([x in list(d.keys()) for x in list(kwargs.keys())]): d = dict([(x, y) for x, y in d.items() if x not in list(kwargs.keys())]) for x in [xx for xx in list(d.keys()) if xx in list(kwargs.keys())]: + # TODO: don't use sys sys.stdout.write( "overrding %s from training strategy with value %s with new value %s\n" % (x, str(d[x]), str(kwargs[x])) @@ -139,20 +147,21 @@ def finalize(self, **kwargs): """ pass + # TODO: make static, not use model? 
def _plot_coef_vs_ref( self, - true_values: np.ndarray, - estim_values: np.ndarray, - size=1, - log=False, - save=None, - show=True, - ncols=5, - row_gap=0.3, - col_gap=0.25, - title=None, - return_axs=False - ): + true_values: T.ArrayLike, + estim_values: T.ArrayLike, + size: float = 1, + log: bool = False, + save: Optional[Union[str, Path]] = None, + show: bool = True, + ncols: int = 5, + row_gap: float = 0.3, + col_gap: float = 0.25, + title: Optional[str] = None, + return_axs: bool = False + ) -> Optional['matplotlib.axes.Axes']: """ Plot estimated coefficients against reference (true) coefficients. @@ -174,10 +183,8 @@ def _plot_coef_vs_ref( from matplotlib import gridspec from matplotlib import rcParams - if isinstance(true_values, dask.array.core.Array): - true_values = true_values.compute() - if isinstance(estim_values, dask.array.core.Array): - estim_values = estim_values.compute() + true_values = maybe_compute(true_values) + estim_values = maybe_compute(estim_values) plt.ioff() @@ -246,18 +253,17 @@ def _plot_coef_vs_ref( if return_axs: return axs - else: - return + # TODO: make static, not use model? def _plot_deviation( self, - true_values: np.ndarray, - estim_values: np.ndarray, - save=None, - show=True, - title=None, - return_axs=False - ): + true_values: T.ArrayLike, + estim_values: T.ArrayLike, + save: Optional[Union[str, Path]] = None, + show: bool = True, + title: Optional[str] = None, + return_axs: bool = False + ) -> Optional['matplotlib.axes.Axes']: """ Plot estimated coefficients against reference (true) coefficients. @@ -309,11 +315,9 @@ def _plot_deviation( if return_axs: return ax - else: - return -class EstimatorBaseTyping(_EstimatorBase): +class EstimatorBaseTyping(_EstimatorBase, ABC): r""" Estimator base class used for typing in other packages. """ diff --git a/batchglm/models/base/external.py b/batchglm/models/base/external.py index 75015489..bfa5bb4e 100644 --- a/batchglm/models/base/external.py +++ b/batchglm/models/base/external.py @@ -1,2 +1,4 @@ import batchglm.pkg_constants as pkg_constants import batchglm.data as data_utils +import batchglm.types as types +from batchglm.train.numpy.utils import maybe_compute diff --git a/batchglm/models/base/input.py b/batchglm/models/base/input.py index 20fddd65..f300f0af 100644 --- a/batchglm/models/base/input.py +++ b/batchglm/models/base/input.py @@ -3,8 +3,10 @@ import numpy as np import scipy.sparse import sparse -from typing import List +from typing import List, Union, Optional, Tuple +from .external import types as T, maybe_compute +# TODO: make this nicer (i.e. top-level import + function + warning if None) try: import anndata try: @@ -29,13 +31,14 @@ class InputDataBase: def __init__( self, - data, - observation_names=None, - feature_names=None, + data: T.InputType, + weights: Optional[Union[T.ArrayLike, str]] = None, + observation_names: Optional[List[str]] = None, + feature_names: Optional[List[str]] = None, chunk_size_cells: int = 100000, chunk_size_genes: int = 100, as_dask: bool = True, - cast_dtype=None + cast_dtype: Optional[np.dtype] = None ): """ Create a new InputData object. @@ -46,6 +49,7 @@ def __init__( - np.ndarray: NumPy array containing the raw data - anndata.AnnData: AnnData object containing the count data and optional the design models stored as data.obsm[design_loc] and data.obsm[design_scale] + :param weights: (optional) observation weights :param observation_names: (optional) names of the observations. :param feature_names: (optional) names of the features. 
:param cast_dtype: data type of all data; should be either float32 or float64 @@ -53,66 +57,95 @@ def __init__( """ self.observations = observation_names self.features = feature_names + self.w = None + if isinstance(data, np.ndarray) or \ isinstance(data, scipy.sparse.csr_matrix) or \ isinstance(data, dask.array.core.Array): self.x = data elif isinstance(data, anndata.AnnData) or isinstance(data, Raw): self.x = data.X + if isinstance(weights, str): + self.w = data.obs[weights].values elif isinstance(data, InputDataBase): self.x = data.x + self.w = data.w else: raise ValueError("type of data %s not recognized" % type(data)) + self._cast_type = cast_dtype if cast_dtype is not None else self.x.dtype + self._cast_type_float = np.float32 if not issubclass(type(self._cast_type), np.floating) else self._cast_type + + if self.w is None: + self.w = np.ones(self.x.shape[0], dtype=self._cast_type_float) + + if scipy.sparse.issparse(self.w): + self.w = self.w.toarray() + if self.w.ndim == 1: + self.w = self.w.reshape((-1, 1)) + + # sanity checks + # TODO: don't use asserts + assert self.w.shape == (self.x.shape[0], 1), "invalid weight shape %s" % self.w.shape + assert issubclass(self.w.dtype.type, np.floating), "invalid weight type %s" % self.w.dtype + + if self.observations is not None: + assert len(self.observations) == self.x.shape[0] + if self.features is not None: + assert len(self.features) == self.x.shape[1] + + self.x = maybe_compute(self.x) + self.w = maybe_compute(self.w) + if as_dask: - if isinstance(self.x, dask.array.core.Array): - self.x = self.x.compute() # Need to wrap dask around the COO matrix version of the sparse package if matrix is sparse. - if isinstance(self.x, scipy.sparse.spmatrix): + if scipy.sparse.issparse(self.x): self.x = dask.array.from_array( - sparse.COO.from_scipy_sparse( - self.x.astype(cast_dtype if cast_dtype is not None else self.x.dtype) - ), + sparse.COO.from_scipy_sparse(self.x.astype(self._cast_type)), chunks=(chunk_size_cells, chunk_size_genes), asarray=False ) else: self.x = dask.array.from_array( - self.x.astype(cast_dtype if cast_dtype is not None else self.x.dtype), + self.x.astype(self._cast_type), chunks=(chunk_size_cells, chunk_size_genes), ) + + self.w = dask.array.from_array(self.w.astype(self._cast_type_float), chunks=(chunk_size_cells, 1)) else: - if isinstance(self.x, dask.array.core.Array): - self.x = self.x.compute() - if cast_dtype is not None: - self.x = self.x.astype(cast_dtype) + # TODO: fix this + if scipy.sparse.issparse(self.x): + raise TypeError(f"For sparse matrices, only dask is supported.") + + self.x = self.x.astype(self._cast_type) + self.w = self.w.astype(self._cast_type) self._feature_allzero = np.sum(self.x, axis=0) == 0 self.chunk_size_cells = chunk_size_cells self.chunk_size_genes = chunk_size_genes @property - def num_observations(self): + def num_observations(self) -> int: return self.x.shape[0] @property - def num_features(self): + def num_features(self) -> int: return self.x.shape[1] @property - def feature_isnonzero(self): + def feature_isnonzero(self) -> np.ndarray: return ~self._feature_allzero @property - def feature_isallzero(self): + def feature_isallzero(self) -> np.ndarray: return self._feature_allzero - def fetch_x_dense(self, idx): + def fetch_x_dense(self, idx: T.IndexLike) -> np.ndarray: assert isinstance(self.x, np.ndarray), "tried to fetch dense from non ndarray" return self.x[idx, :] - def fetch_x_sparse(self, idx): + def fetch_x_sparse(self, idx: T.IndexLike) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 
assert isinstance(self.x, scipy.sparse.csr_matrix), "tried to fetch sparse from non csr_matrix" data = self.x[idx, :] diff --git a/batchglm/models/base/model.py b/batchglm/models/base/model.py index 92ab97bc..243da054 100644 --- a/batchglm/models/base/model.py +++ b/batchglm/models/base/model.py @@ -2,11 +2,8 @@ from typing import Union, Any, Dict, Iterable import logging -try: - import anndata -except ImportError: - anndata = None - +from .external import types as T +from .input import InputDataBase logger = logging.getLogger(__name__) @@ -18,14 +15,18 @@ class _ModelBase(metaclass=abc.ABCMeta): def __init__( self, - input_data + input_data: InputDataBase ): self.input_data = input_data @property - def x(self): + def x(self) -> T.ArrayLike: return self.input_data.x + @property + def w(self) -> T.ArrayLike: + return self.input_data.w + def get(self, key: Union[str, Iterable]) -> Union[Any, Dict[str, Any]]: """ Returns the values specified by key. @@ -38,8 +39,11 @@ def get(self, key: Union[str, Iterable]) -> Union[Any, Dict[str, Any]]: elif isinstance(key, Iterable): return {s: self.__getattribute__(s) for s in key} - def __getitem__(self, item): + def __getitem__(self, item) -> Union[Any, Dict[str, Any]]: return self.get(item) - def __repr__(self): + def __str__(self) -> str: + return self.__class__.__name__ + + def __repr__(self) -> str: return self.__str__() diff --git a/batchglm/models/base/simulator.py b/batchglm/models/base/simulator.py index 095aced8..e504bbd6 100644 --- a/batchglm/models/base/simulator.py +++ b/batchglm/models/base/simulator.py @@ -1,16 +1,10 @@ import abc -import dask.array -import os import logging import numpy as np -try: - import anndata -except ImportError: - anndata = None - from .input import InputDataBase from .model import _ModelBase +from .external import maybe_compute logger = logging.getLogger(__name__) @@ -24,6 +18,7 @@ class _SimulatorBase(metaclass=abc.ABCMeta): convention: N features with M observations each => (M, N) matrix """ + # TODO: why? nobs: int nfeatures: int @@ -32,9 +27,9 @@ class _SimulatorBase(metaclass=abc.ABCMeta): def __init__( self, - model, - num_observations, - num_features + model: _ModelBase, + num_observations: int, + num_features: int ): self.nobs = num_observations self.nfeatures = num_features @@ -42,7 +37,7 @@ def __init__( self.input_data = None self.model = model - def generate(self): + def generate(self) -> None: """ First generates the parameter set, then observations random data using these parameters """ @@ -63,9 +58,7 @@ def generate_params(self, *args, **kwargs): """ pass + # TODO: computed property (self.input_data.x should not change)? 
@property def x(self) -> np.ndarray: - if isinstance(self.input_data.x, dask.array.core.Array): - return self.input_data.x.compute() - else: - return self.input_data.x + return maybe_compute(self.input_data.x) diff --git a/batchglm/models/base_glm/external.py b/batchglm/models/base_glm/external.py index 163181f5..ce6d3233 100644 --- a/batchglm/models/base_glm/external.py +++ b/batchglm/models/base_glm/external.py @@ -4,4 +4,5 @@ from batchglm.models.base import _SimulatorBase import batchglm.data as data_utils -from batchglm.utils.linalg import groupwise_solve_lm \ No newline at end of file +from batchglm.utils.linalg import groupwise_solve_lm +import batchglm.types as types diff --git a/batchglm/models/base_glm/input.py b/batchglm/models/base_glm/input.py index dd27be71..0fcf14c4 100644 --- a/batchglm/models/base_glm/input.py +++ b/batchglm/models/base_glm/input.py @@ -7,11 +7,11 @@ import numpy as np import pandas as pd import patsy -import scipy.sparse -from typing import Union +from typing import Union, Optional from .utils import parse_constraints, parse_design from .external import InputDataBase +from .external import types as T class InputDataGLM(InputDataBase): @@ -25,7 +25,8 @@ class InputDataGLM(InputDataBase): def __init__( self, - data: Union[np.ndarray, anndata.AnnData, scipy.sparse.csr_matrix], + data: T.InputType, + weights: Optional[Union[T.ArrayLike, str]] = None, design_loc: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix] = None, design_loc_names: Union[list, np.ndarray] = None, design_scale: Union[np.ndarray, pd.DataFrame, patsy.design_info.DesignMatrix] = None, @@ -48,6 +49,7 @@ def __init__( - np.ndarray: NumPy array containing the raw data - anndata.AnnData: AnnData object containing the count data and optional the design models stored as data.obsm[design_loc] and data.obsm[design_scale] + :param weights: (optional) observation weights :param design_loc: Some matrix format (observations x mean model parameters) The location design model. 
Optional if already specified in `data` :param design_loc_names: (optional) @@ -83,6 +85,7 @@ def __init__( InputDataBase.__init__( self=self, data=data, + weights=weights, observation_names=observation_names, feature_names=feature_names, chunk_size_cells=chunk_size_cells, @@ -142,21 +145,23 @@ def __init__( self._loc_names = loc_names self._scale_names = scale_names + self.size_factors = None + if size_factors is not None: - if len(size_factors.shape) == 1: - size_factors = np.expand_dims(np.asarray(size_factors), axis=-1) - elif len(size_factors.shape) == 2: - pass + size_factors = np.asarray(size_factors) + + if size_factors.ndim == 1: + size_factors = size_factors.reshape((-1, 1)) + if size_factors.ndim != 2: + raise ValueError("received size factors with dimension=%i" % size_factors.ndim) + + if as_dask: + self.size_factors = dask.array.from_array( + size_factors.astype(cast_dtype if cast_dtype is not None else self.x.dtype), + chunks=(chunk_size_cells, 1), + ) else: - raise ValueError("received size factors with dimension=%i" % len(size_factors.shape)) - if as_dask: - self.size_factors = dask.array.from_array( - size_factors.astype(cast_dtype if cast_dtype is not None else self.x.dtype), - chunks=(chunk_size_cells, 1), - ) if size_factors is not None else None - else: - self.size_factors = size_factors.astype(cast_dtype if cast_dtype is not None else self.x.dtype) \ - if size_factors is not None else None + self.size_factors = size_factors.astype(cast_dtype if cast_dtype is not None else self.x.dtype) @property def design_loc_names(self): diff --git a/batchglm/train/numpy/base_glm/estimator.py b/batchglm/train/numpy/base_glm/estimator.py index 1ad26326..2d581c9b 100644 --- a/batchglm/train/numpy/base_glm/estimator.py +++ b/batchglm/train/numpy/base_glm/estimator.py @@ -9,7 +9,6 @@ import sparse import sys import time -from typing import Tuple from .external import _EstimatorGLM, pkg_constants from .training_strategies import TrainingStrategies @@ -65,7 +64,7 @@ def train( of the location model is tracked with self.model.converged. This is re-set after a scale model update, as this convergence only holds conditioned on a particular scale model value. Full convergence of a feature wise model is evaluated after each scale model update: If the loss function based - convergence criterium holds across the cumulative updates of the sequence of location updates and last scale + convergence criterion holds across the cumulative updates of the sequence of location updates and last scale model update, the feature is considered converged. For this, the loss value at the last scale model update is save in ll_last_b_update. Full convergence is saved in fully_converged. 
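The reworked size-factor branch in `InputDataGLM.__init__` above is easier to follow in isolation. Below is a minimal sketch of the same normalization logic, assuming a standalone helper; `_normalize_size_factors` is a hypothetical name and not part of this diff:

```python
from typing import Optional, Union

import dask.array
import numpy as np


def _normalize_size_factors(
    size_factors: Optional[np.ndarray],
    dtype: np.dtype,
    chunk_size_cells: int,
    as_dask: bool = True,
) -> Optional[Union[np.ndarray, dask.array.core.Array]]:
    """Hypothetical helper mirroring the size-factor handling in InputDataGLM.__init__."""
    if size_factors is None:
        return None
    size_factors = np.asarray(size_factors)
    if size_factors.ndim == 1:
        # one factor per observation -> column vector (observations x 1)
        size_factors = size_factors.reshape((-1, 1))
    if size_factors.ndim != 2:
        raise ValueError("received size factors with dimension=%i" % size_factors.ndim)
    size_factors = size_factors.astype(dtype)
    if as_dask:
        # chunk along observations only, matching the (chunk_size_cells, 1) layout used for weights
        return dask.array.from_array(size_factors, chunks=(chunk_size_cells, 1))
    return size_factors
```

Validating `ndim` up front is what lets the dask branch drop the repeated `if size_factors is not None` guards the old code carried.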
diff --git a/batchglm/train/numpy/base_glm/external.py b/batchglm/train/numpy/base_glm/external.py index 0b57b5a4..182dee30 100644 --- a/batchglm/train/numpy/base_glm/external.py +++ b/batchglm/train/numpy/base_glm/external.py @@ -1,4 +1,5 @@ from batchglm.models.base_glm import InputDataGLM, _ModelGLM, _EstimatorGLM from batchglm.utils.linalg import groupwise_solve_lm +from batchglm.train.numpy.utils import maybe_compute, isdask from batchglm import pkg_constants \ No newline at end of file diff --git a/batchglm/train/numpy/base_glm/model.py b/batchglm/train/numpy/base_glm/model.py index 2f1f13ca..6f785425 100644 --- a/batchglm/train/numpy/base_glm/model.py +++ b/batchglm/train/numpy/base_glm/model.py @@ -263,10 +263,10 @@ def jac_b_j(self, j) -> np.ndarray: # Make sure that dimensionality of sliced array is kept: if isinstance(j, int) or isinstance(j, np.int32) or isinstance(j, np.int64): j = [j] - w = self.jac_weight_b_j(j=j) # (observations x features) - xh = np.matmul(self.design_scale, self.constraints_scale) # (observations x inferred param) - return np.einsum( - 'fob,of->fb', - np.einsum('ob,of->fob', xh, w), - xh - ) + w = self.jac_weight_b_j(j=j) # (observations x features) + xh = np.matmul(self.design_scale, self.constraints_scale) # (observations x inferred param) + return np.einsum( + 'fob,of->fb', + np.einsum('ob,of->fob', xh, w), + xh + ) diff --git a/batchglm/train/numpy/base_glm/vars.py b/batchglm/train/numpy/base_glm/vars.py index afaca423..ad8f19a3 100644 --- a/batchglm/train/numpy/base_glm/vars.py +++ b/batchglm/train/numpy/base_glm/vars.py @@ -1,12 +1,15 @@ +from typing import Union + import dask.array import numpy as np -import scipy.sparse import abc +from .external import isdask + class ModelVarsGlm: """ - Build variables to be optimzed and their constraints. + Build variables to be optimized and their constraints. """ @@ -24,8 +27,8 @@ def __init__( self, init_a: np.ndarray, init_b: np.ndarray, - constraints_loc: np.ndarray, - constraints_scale: np.ndarray, + constraints_loc: Union[np.ndarray, dask.array.core.Array], + constraints_scale: Union[np.ndarray, dask.array.core.Array], chunk_size_genes: int, dtype: str ): @@ -42,13 +45,11 @@ def __init__( init_a_clipped = self.np_clip_param(np.asarray(init_a, dtype=dtype), "a_var") init_b_clipped = self.np_clip_param(np.asarray(init_b, dtype=dtype), "b_var") - self.params = dask.array.from_array(np.concatenate( - [ - init_a_clipped, - init_b_clipped, - ], - axis=0 - ), chunks=(1000, chunk_size_genes)) + + self.params = np.concatenate([init_a_clipped, init_b_clipped], axis=0) + if isdask(constraints_loc): + self.params = dask.array.from_array(self.params, chunks=(1000, chunk_size_genes)) + self.npar_a = init_a_clipped.shape[0] # Properties to follow gene-wise convergence. 
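The `ModelVarsGlm` change just above only wraps the parameter matrix in a dask array when the constraint matrices are themselves dask-backed. A self-contained sketch of that pattern follows; `init_params` is an illustrative free function, the class keeps this logic inline:

```python
from typing import Union

import dask.array
import numpy as np


def isdask(array) -> bool:
    # the check performed by the new batchglm.train.numpy.utils helper
    return isinstance(array, dask.array.core.Array)


def init_params(
    init_a: np.ndarray,
    init_b: np.ndarray,
    constraints_loc: Union[np.ndarray, dask.array.core.Array],
    chunk_size_genes: int,
) -> Union[np.ndarray, dask.array.core.Array]:
    """Stack location and scale initializations; go lazy only if the inputs already are."""
    params = np.concatenate([init_a, init_b], axis=0)
    if isdask(constraints_loc):
        # coarse chunks over parameters, chunk_size_genes over features, as in ModelVarsGlm
        params = dask.array.from_array(params, chunks=(1000, chunk_size_genes))
    return params
```

Keying the decision on `constraints_loc` spares numpy-only pipelines the dask overhead that the previous unconditional `dask.array.from_array` call imposed.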
diff --git a/batchglm/train/numpy/glm_nb/model.py b/batchglm/train/numpy/glm_nb/model.py index 3fd93fe1..170a3262 100644 --- a/batchglm/train/numpy/glm_nb/model.py +++ b/batchglm/train/numpy/glm_nb/model.py @@ -3,7 +3,6 @@ import numpy as np import scipy.sparse import scipy.special -import sparse from .external import Model, ModelIwls, InputDataGLM from .processModel import ProcessModel @@ -41,7 +40,7 @@ def fim_weight_aa(self): :return: observations x features """ - return - self.location * self.scale / (self.scale + self.location) + return - self.w * self.location * self.scale / (self.scale + self.location) @property def ybar(self) -> np.ndarray: @@ -56,7 +55,7 @@ def fim_weight_aa_j(self, j): :return: observations x features """ - return - self.location_j(j=j) * self.scale_j(j=j) / (self.scale_j(j=j) + self.location_j(j=j)) + return - self.w * self.location_j(j=j) * self.scale_j(j=j) / (self.scale_j(j=j) + self.location_j(j=j)) def ybar_j(self, j) -> np.ndarray: """ @@ -89,7 +88,7 @@ def jac_weight_b(self): const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale) const2 = - scale_plus_x / r_plus_mu const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu) - return scale * (const1 + const2 + const3) + return self.w * scale * (const1 + const2 + const3) def jac_weight_b_j(self, j): """ @@ -111,7 +110,7 @@ def jac_weight_b_j(self, j): const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale) const2 = - scale_plus_x / r_plus_mu const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu) - return scale * (const1 + const2 + const3) + return self.w * scale * (const1 + const2 + const3) @property def fim_ab(self) -> np.ndarray: @@ -150,6 +149,7 @@ def hessian_weight_ab(self): scale = self.scale loc = self.location return np.multiply( + self.w, loc * scale, np.asarray(self.x - loc) / np.square(loc + scale) ) @@ -163,7 +163,7 @@ def hessian_weight_aa(self): else: x_by_scale_plus_one = np.asarray(self.x.divide(scale) + np.ones_like(scale)) - return - loc * x_by_scale_plus_one / np.square((loc / scale) + np.ones_like(loc)) + return - self.w * loc * x_by_scale_plus_one / np.square((loc / scale) + np.ones_like(loc)) @property def hessian_weight_bb(self): @@ -176,7 +176,7 @@ def hessian_weight_bb(self): const2 = - scipy.special.digamma(scale) + scale * scipy.special.polygamma(n=1, x=scale) const3 = - loc * scale_plus_x + np.ones_like(scale) * 2. * scale * scale_plus_loc / np.square(scale_plus_loc) const4 = np.log(scale) + np.ones_like(scale) * 2. - np.log(scale_plus_loc) - return scale * (const1 + const2 + const3 + const4) + return self.w * scale * (const1 + const2 + const3 + const4) @property def ll(self): @@ -191,14 +191,24 @@ def ll(self): self.x * (self.eta_loc - log_r_plus_mu) + \ np.multiply(scale, self.eta_scale - log_r_plus_mu) else: - # sparse scipy + # sparse scipy, assuming self.x is also a dask array + + # if not, it can be fixed as follows: + # the inner np.asarray: ... 
+ (self.x.multiply(np.asarray(...))).tocsr() + # is there because dask does not yet support fancy nd indexing + # tocsr because TypeError: 'coo_matrix' object is not subscriptable + # when computing the values in `ll = np.asarray(ll)` + # + # however, the estimator code will be broken, since it calls: self.model.ll_byfeature.compute() + # which requires this property, which will not be a dask array + ll = scipy.special.gammaln(np.asarray(scale + self.x)) - \ scipy.special.gammaln(self.x + np.ones_like(scale)) - \ scipy.special.gammaln(scale) + \ np.asarray(self.x.multiply(self.eta_loc - log_r_plus_mu) + np.multiply(scale, self.eta_scale - log_r_plus_mu)) ll = np.asarray(ll) - return self.np_clip_param(ll, "ll") + return self.np_clip_param(self.w * ll, "ll") def ll_j(self, j): # Make sure that dimensionality of sliced array is kept: @@ -222,10 +232,11 @@ def ll_j(self, j): np.asarray(self.x[:, j].multiply(self.eta_loc_j(j=j) - log_r_plus_mu) + np.multiply(scale, self.eta_scale_j(j=j) - log_r_plus_mu)) ll = np.asarray(ll) - return self.np_clip_param(ll, "ll") + return self.np_clip_param(self.w * ll, "ll") + # TODO: not used def ll_handle(self): - def fun(x, eta_loc, b_var, xh_scale): + def fun(x, w, eta_loc, b_var, xh_scale): eta_scale = np.matmul(xh_scale, b_var) scale = np.exp(eta_scale) loc = np.exp(eta_loc) @@ -239,11 +250,12 @@ def fun(x, eta_loc, b_var, xh_scale): np.multiply(scale, eta_scale - log_r_plus_mu) else: raise ValueError("type x %s not supported" % type(x)) - return self.np_clip_param(ll, "ll") + return self.np_clip_param(w * ll, "ll") return fun + # TODO: not used def jac_b_handle(self): - def fun(x, eta_loc, b_var, xh_scale): + def fun(x, w, eta_loc, b_var, xh_scale): scale = np.exp(b_var) loc = np.exp(eta_loc) scale_plus_x = scale + x @@ -253,6 +265,6 @@ def fun(x, eta_loc, b_var, xh_scale): const1 = scipy.special.digamma(scale_plus_x) - scipy.special.digamma(scale) const2 = - scale_plus_x / r_plus_mu const3 = np.log(scale) + np.ones_like(scale) - np.log(r_plus_mu) - return scale * (const1 + const2 + const3) + return w * scale * (const1 + const2 + const3) return fun diff --git a/batchglm/train/numpy/utils/__init__.py b/batchglm/train/numpy/utils/__init__.py new file mode 100644 index 00000000..879e3422 --- /dev/null +++ b/batchglm/train/numpy/utils/__init__.py @@ -0,0 +1 @@ +from .utils import maybe_compute, isdask \ No newline at end of file diff --git a/batchglm/train/numpy/utils/external.py b/batchglm/train/numpy/utils/external.py new file mode 100644 index 00000000..0ab57039 --- /dev/null +++ b/batchglm/train/numpy/utils/external.py @@ -0,0 +1 @@ +from batchglm.types import ArrayLike \ No newline at end of file diff --git a/batchglm/train/numpy/utils/utils.py b/batchglm/train/numpy/utils/utils.py new file mode 100644 index 00000000..c95b41a5 --- /dev/null +++ b/batchglm/train/numpy/utils/utils.py @@ -0,0 +1,19 @@ +from typing import Union, Optional + +import numpy as np +from scipy.sparse import spmatrix +import dask.array + +from .external import ArrayLike + + +def maybe_compute(array: Optional[Union[ArrayLike]], copy: bool = False) -> Optional[Union[np.ndarray, spmatrix]]: + if array is None: + return None + if isdask(array): + return array.compute() + return array.copy() if copy else array + + +def isdask(array: Optional[ArrayLike]) -> bool: + return isinstance(array, dask.array.core.Array) diff --git a/batchglm/types.py b/batchglm/types.py new file mode 100644 index 00000000..78b273fc --- /dev/null +++ b/batchglm/types.py @@ -0,0 +1,14 @@ +from typing import 
TypeVar, Union
+
+import dask
+import numpy as np
+from scipy.sparse import spmatrix
+
+try:
+    from anndata import AnnData
+except ImportError:
+    AnnData = TypeVar("AnnData")
+
+ArrayLike = Union[np.ndarray, spmatrix, dask.array.core.Array]
+IndexLike = Union[np.ndarray, tuple, list, int] # not exhaustive
+InputType = Union[ArrayLike, AnnData, "InputDataBase"]
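Across `batchglm/train/numpy/glm_nb/model.py` the recurring edit is that every per-observation term (FIM and Jacobian weights, Hessian weights, log-likelihood) is multiplied by the new weight column `w`. A minimal numpy sketch of the weighted dense-branch log-likelihood; the standalone signature is illustrative, as the model computes this from its own properties:

```python
import numpy as np
import scipy.special


def weighted_nb_ll(
    x: np.ndarray,          # counts, observations x features
    eta_loc: np.ndarray,    # log mean, observations x features
    eta_scale: np.ndarray,  # log dispersion, observations x features
    w: np.ndarray,          # observation weights, observations x 1
) -> np.ndarray:
    """Observation-weighted negative binomial log-likelihood (dense case)."""
    loc = np.exp(eta_loc)
    scale = np.exp(eta_scale)
    log_r_plus_mu = np.log(scale + loc)
    ll = (
        scipy.special.gammaln(scale + x)
        - scipy.special.gammaln(x + 1.0)
        - scipy.special.gammaln(scale)
        + x * (eta_loc - log_r_plus_mu)
        + scale * (eta_scale - log_r_plus_mu)
    )
    # setting all weights to 1 recovers the unweighted likelihood the old code computed
    return w * ll
```

Because `w` has shape (observations, 1) it broadcasts over the feature axis, so the weighted variants reduce to the previous behaviour when `InputDataBase` falls back to the all-ones weight vector it now constructs by default.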