Weighted model #113

Open. Wants to merge 12 commits into base: master.
100 changes: 52 additions & 48 deletions batchglm/models/base/estimator.py
@@ -1,31 +1,35 @@
 import abc
-import dask
+from abc import ABC
 from enum import Enum
+from pathlib import Path
+from typing import Optional, Union

+import dask
 import logging
 import numpy as np
 import pandas as pd
 import pprint
 import sys

 try:
     import anndata
 except ImportError:
     anndata = None

 from .input import InputDataBase
 from .model import _ModelBase
+from .external import maybe_compute, types as T

 logger = logging.getLogger(__name__)


+# TODO: why abc.Meta instead of abc.ABC
 class _EstimatorBase(metaclass=abc.ABCMeta):
     r"""
     Estimator base class
     """
+    # TODO: why are these here?
     model: _ModelBase
-    _loss: np.ndarray
+    _loss: np.ndarray  # is this really an array?
     _jacobian: np.ndarray

+    # TODO: better enums (use the pretty-printing like in CellRank)
+    # TODO: don't use nested classes
     class TrainingStrategy(Enum):
         AUTO = None
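Note on the enum TODOs above: a CellRank-style string enum gives readable printing for free. A minimal sketch, assuming string-valued members (the member set and values here are illustrative, not part of this diff):

```python
# Illustrative sketch of a "pretty" enum in the CellRank style the TODO mentions.
# A str mixin makes members format as plain strings; member values are assumed.
from enum import Enum

class TrainingStrategy(str, Enum):
    AUTO = "auto"
    DEFAULT = "default"

    def __str__(self) -> str:
        return self.value  # prints "auto" instead of "TrainingStrategy.AUTO"
```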
@@ -44,30 +48,32 @@ def __init__(
         self._error_codes = None
         self._niter = None

+    # TODO: type
     @property
     def error_codes(self):
         return self._error_codes

     @property
-    def niter(self):
+    def niter(self) -> int:
         return self._niter

     @property
-    def loss(self):
+    def loss(self) -> np.ndarray:
         return self._loss

     @property
     def log_likelihood(self):
         return self._log_likelihood

     @property
-    def jacobian(self):
+    def jacobian(self) -> np.ndarray:
         return self._jacobian

     @property
-    def hessian(self):
+    def hessian(self) -> np.ndarray:
         return self._hessian

+    # TODO: type
     @property
     def fisher_inv(self):
         return self._fisher_inv
@@ -77,18 +83,16 @@ def x(self) -> np.ndarray:
         return self.input_data.x

     @property
-    def a_var(self):
-        if isinstance(self.model.a_var, dask.array.core.Array):
-            return self.model.a_var.compute()
-        else:
-            return self.model.a_var
+    def w(self) -> np.ndarray:
+        return self.input_data.w
+
+    @property
+    def a_var(self) -> np.ndarray:
+        return maybe_compute(self.model.a_var)

     @property
     def b_var(self) -> np.ndarray:
-        if isinstance(self.model.b_var, dask.array.core.Array):
-            return self.model.b_var.compute()
-        else:
-            return self.model.b_var
+        return maybe_compute(self.model.b_var)

     @abc.abstractmethod
     def initialize(self, **kwargs):
@@ -97,11 +101,14 @@ def initialize(self, **kwargs):
"""
pass

# TODO: type training strategy
# TODO: docs
def train_sequence(
self,
training_strategy,
**kwargs
):
# TODO: better enums (use the pretty-printing like in CellRank)
if isinstance(training_strategy, Enum):
training_strategy = training_strategy.value
elif isinstance(training_strategy, str):
@@ -117,6 +124,7 @@ def train_sequence(
         if np.any([x in list(d.keys()) for x in list(kwargs.keys())]):
             d = dict([(x, y) for x, y in d.items() if x not in list(kwargs.keys())])
             for x in [xx for xx in list(d.keys()) if xx in list(kwargs.keys())]:
+                # TODO: don't use sys
                 sys.stdout.write(
                     "overriding %s from training strategy with value %s with new value %s\n" %
                     (x, str(d[x]), str(kwargs[x]))
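Note on the "don't use sys" TODO: the message could go through the module-level logger declared at the top of this file. A sketch of a drop-in replacement for the block above, which also reads the overridden values before `d` is filtered:

```python
# Sketch: warn about overridden training-strategy keys via the module logger;
# `d` and `kwargs` are the dicts from the surrounding train_sequence code.
overridden = set(d) & set(kwargs)
for key in overridden:
    logger.warning(
        "overriding %s from training strategy with value %s with new value %s",
        key, d[key], kwargs[key]
    )
d = {k: v for k, v in d.items() if k not in overridden}
```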
@@ -139,20 +147,21 @@ def finalize(self, **kwargs):
"""
pass

# TODO: make static, not use model?
def _plot_coef_vs_ref(
self,
true_values: np.ndarray,
estim_values: np.ndarray,
size=1,
log=False,
save=None,
show=True,
ncols=5,
row_gap=0.3,
col_gap=0.25,
title=None,
return_axs=False
):
true_values: T.ArrayLike,
estim_values: T.ArrayLike,
size: float = 1,
log: bool = False,
save: Optional[Union[str, Path]] = None,
show: bool = True,
ncols: int = 5,
row_gap: float = 0.3,
col_gap: float = 0.25,
title: Optional[str] = None,
return_axs: bool = False
) -> Optional['matplotlib.axes.Axes']:
"""
Plot estimated coefficients against reference (true) coefficients.

@@ -174,10 +183,8 @@ def _plot_coef_vs_ref(
         from matplotlib import gridspec
         from matplotlib import rcParams

-        if isinstance(true_values, dask.array.core.Array):
-            true_values = true_values.compute()
-        if isinstance(estim_values, dask.array.core.Array):
-            estim_values = estim_values.compute()
+        true_values = maybe_compute(true_values)
+        estim_values = maybe_compute(estim_values)

         plt.ioff()

@@ -246,18 +253,17 @@ def _plot_coef_vs_ref(

         if return_axs:
             return axs
-        else:
-            return

+    # TODO: make static, not use model?
     def _plot_deviation(
             self,
-            true_values: np.ndarray,
-            estim_values: np.ndarray,
-            save=None,
-            show=True,
-            title=None,
-            return_axs=False
-    ):
+            true_values: T.ArrayLike,
+            estim_values: T.ArrayLike,
+            save: Optional[Union[str, Path]] = None,
+            show: bool = True,
+            title: Optional[str] = None,
+            return_axs: bool = False
+    ) -> Optional['matplotlib.axes.Axes']:
         """
         Plot deviation of estimated coefficients from reference (true) coefficients.

@@ -309,11 +315,9 @@ def _plot_deviation(

         if return_axs:
             return ax
-        else:
-            return


-class EstimatorBaseTyping(_EstimatorBase):
+class EstimatorBaseTyping(_EstimatorBase, ABC):
     r"""
     Estimator base class used for typing in other packages.
     """
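For context, a hypothetical call to the retyped _plot_coef_vs_ref above; the estimator instance and reference values are assumed for illustration:

```python
# Hypothetical usage; `estim` is a fitted estimator, `true_a` a known reference.
axs = estim._plot_coef_vs_ref(
    true_values=true_a,       # T.ArrayLike; dask inputs are materialized internally
    estim_values=estim.a_var,
    log=True,
    save="coef_vs_ref",       # accepts str or pathlib.Path after this change
    return_axs=True,          # return the matplotlib axes instead of None
)
```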
2 changes: 2 additions & 0 deletions batchglm/models/base/external.py
@@ -1,2 +1,4 @@
 import batchglm.pkg_constants as pkg_constants
 import batchglm.data as data_utils
+import batchglm.types as types
+from batchglm.train.numpy.utils import maybe_compute
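maybe_compute itself is not shown in this diff; judging from the isinstance/compute branches it replaces in estimator.py above, it presumably behaves like this sketch:

```python
# Presumed behavior of maybe_compute (the real one lives in
# batchglm.train.numpy.utils): materialize dask arrays, pass everything else through.
import dask.array as da

def maybe_compute(arr):
    if isinstance(arr, da.core.Array):
        return arr.compute()
    return arr
```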
77 changes: 55 additions & 22 deletions batchglm/models/base/input.py
@@ -3,8 +3,10 @@
 import numpy as np
 import scipy.sparse
 import sparse
-from typing import List
+from typing import List, Union, Optional, Tuple
+from .external import types as T, maybe_compute

+# TODO: make this nicer (i.e. top-level import + function + warning if None)
 try:
     import anndata
     try:
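Note on the import TODO above: one possible "nicer" shape is a single top-level try/except plus a small guard that warns when anndata is missing. A sketch with an assumed helper name:

```python
# Sketch only; _require_anndata is an invented name, not part of this diff.
import warnings

try:
    import anndata
except ImportError:
    anndata = None

def _require_anndata() -> None:
    if anndata is None:
        warnings.warn("anndata is not installed; AnnData inputs are unavailable")
```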
@@ -29,13 +31,14 @@ class InputDataBase:

     def __init__(
             self,
-            data,
-            observation_names=None,
-            feature_names=None,
+            data: T.InputType,
+            weights: Optional[Union[T.ArrayLike, str]] = None,
+            observation_names: Optional[List[str]] = None,
+            feature_names: Optional[List[str]] = None,
             chunk_size_cells: int = 100000,
             chunk_size_genes: int = 100,
             as_dask: bool = True,
-            cast_dtype=None
+            cast_dtype: Optional[np.dtype] = None
     ):
         """
         Create a new InputData object.
@@ -46,73 +49,103 @@ def __init__(
             - np.ndarray: NumPy array containing the raw data
             - anndata.AnnData: AnnData object containing the count data and optionally the design models
               stored as data.obsm[design_loc] and data.obsm[design_scale]
+        :param weights: (optional) observation weights
         :param observation_names: (optional) names of the observations.
         :param feature_names: (optional) names of the features.
         :param cast_dtype: data type of all data; should be either float32 or float64
         :return: InputData object
         """
         self.observations = observation_names
         self.features = feature_names
+        self.w = None

         if isinstance(data, np.ndarray) or \
                 isinstance(data, scipy.sparse.csr_matrix) or \
                 isinstance(data, dask.array.core.Array):
             self.x = data
         elif isinstance(data, anndata.AnnData) or isinstance(data, Raw):
             self.x = data.X
+            if isinstance(weights, str):
+                self.w = data.obs[weights].values
         elif isinstance(data, InputDataBase):
             self.x = data.x
+            self.w = data.w
         else:
             raise ValueError("type of data %s not recognized" % type(data))

+        self._cast_type = cast_dtype if cast_dtype is not None else self.x.dtype
+        self._cast_type_float = np.float32 if not issubclass(type(self._cast_type), np.floating) else self._cast_type
+
+        if self.w is None:
+            self.w = np.ones(self.x.shape[0], dtype=self._cast_type_float)
+
+        if scipy.sparse.issparse(self.w):
+            self.w = self.w.toarray()
+        if self.w.ndim == 1:
+            self.w = self.w.reshape((-1, 1))
+
+        # sanity checks
+        # TODO: don't use asserts
+        assert self.w.shape == (self.x.shape[0], 1), "invalid weight shape %s" % self.w.shape
+        assert issubclass(self.w.dtype.type, np.floating), "invalid weight type %s" % self.w.dtype
+
         if self.observations is not None:
             assert len(self.observations) == self.x.shape[0]
         if self.features is not None:
             assert len(self.features) == self.x.shape[1]

+        self.x = maybe_compute(self.x)
+        self.w = maybe_compute(self.w)
+
         if as_dask:
-            if isinstance(self.x, dask.array.core.Array):
-                self.x = self.x.compute()
             # Need to wrap dask around the COO matrix version of the sparse package if matrix is sparse.
-            if isinstance(self.x, scipy.sparse.spmatrix):
+            if scipy.sparse.issparse(self.x):
                 self.x = dask.array.from_array(
-                    sparse.COO.from_scipy_sparse(
-                        self.x.astype(cast_dtype if cast_dtype is not None else self.x.dtype)
-                    ),
+                    sparse.COO.from_scipy_sparse(self.x.astype(self._cast_type)),
                     chunks=(chunk_size_cells, chunk_size_genes),
                     asarray=False
                 )
             else:
                 self.x = dask.array.from_array(
-                    self.x.astype(cast_dtype if cast_dtype is not None else self.x.dtype),
+                    self.x.astype(self._cast_type),
                     chunks=(chunk_size_cells, chunk_size_genes),
                 )

+            self.w = dask.array.from_array(self.w.astype(self._cast_type_float), chunks=(chunk_size_cells, 1))
         else:
-            if isinstance(self.x, dask.array.core.Array):
-                self.x = self.x.compute()
-            if cast_dtype is not None:
-                self.x = self.x.astype(cast_dtype)
+            # TODO: fix this
+            if scipy.sparse.issparse(self.x):
+                raise TypeError(f"For sparse matrices, only dask is supported.")
+
+            self.x = self.x.astype(self._cast_type)
+            self.w = self.w.astype(self._cast_type)

         self._feature_allzero = np.sum(self.x, axis=0) == 0
         self.chunk_size_cells = chunk_size_cells
         self.chunk_size_genes = chunk_size_genes
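Note on the "don't use asserts" TODO: asserts vanish under python -O, so the weight checks would need real exceptions. A sketch of a drop-in replacement for the two asserts above:

```python
# Sketch: explicit validation instead of asserts for the weight sanity checks.
if self.w.shape != (self.x.shape[0], 1):
    raise ValueError("invalid weight shape %s, expected (%d, 1)" % (self.w.shape, self.x.shape[0]))
if not issubclass(self.w.dtype.type, np.floating):
    raise TypeError("invalid weight dtype %s, expected a floating type" % self.w.dtype)
```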

     @property
-    def num_observations(self):
+    def num_observations(self) -> int:
         return self.x.shape[0]

     @property
-    def num_features(self):
+    def num_features(self) -> int:
         return self.x.shape[1]

     @property
-    def feature_isnonzero(self):
+    def feature_isnonzero(self) -> np.ndarray:
         return ~self._feature_allzero

     @property
-    def feature_isallzero(self):
+    def feature_isallzero(self) -> np.ndarray:
         return self._feature_allzero

-    def fetch_x_dense(self, idx):
+    def fetch_x_dense(self, idx: T.IndexLike) -> np.ndarray:
         assert isinstance(self.x, np.ndarray), "tried to fetch dense from non ndarray"

         return self.x[idx, :]

-    def fetch_x_sparse(self, idx):
+    def fetch_x_sparse(self, idx: T.IndexLike) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         assert isinstance(self.x, scipy.sparse.csr_matrix), "tried to fetch sparse from non csr_matrix"

         data = self.x[idx, :]
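Finally, a hypothetical construction of a weighted input object against the new signature; the counts and weights below are invented for illustration:

```python
# Hypothetical usage of the new `weights` argument; all data invented.
import numpy as np

x = np.random.poisson(lam=5.0, size=(200, 50)).astype(np.float64)  # cells x genes
w = np.random.uniform(0.5, 1.0, size=200)                          # one weight per cell

input_data = InputDataBase(
    data=x,
    weights=w,     # AnnData inputs may instead pass a str key into data.obs
    as_dask=True,
)
# weights end up as a float column vector chunked alongside x
assert input_data.w.shape == (200, 1)
```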