This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

v0.2 Roadmap #190

Open
4 of 6 tasks
BirkhoffG opened this issue Jun 7, 2023 · 3 comments

Comments


BirkhoffG commented Jun 7, 2023

TODOs

  • Show benchmark results in the docsite or wandb
  • Decouple API in recourse methods
  • Decouple API in module.py and trainer.py
  • Use keras_core (replaces the earlier idea: consider using equinox over haiku)
  • Save and load CFModule, MLModule, DataModule, and their corresponding configs.
  • Support multi-class setting

Directory Organization

- relax
  - methods
    - base.py
    - vanilla.py
    - ...
  - __init__.py
  - _modidx.py
  - explain.py
  - evaluate.py
  - data.py
  - strategy.py
  - metrics.py
  - ml_model.py
  - plots.py
  - utils.py

BirkhoffG commented

High-level One-Liner APIs

  • relax.generate_cf_explanations
  • relax.evaluate_cfs
def generate_cf_explanations(
    cf_module: BaseCFModule, # Recourse explanation module.
    datamodule: TabularDataModule, # Data module.
    pred_fn: Callable = None, # Predictive function. If None, a model is trained first.
    strategy: str | BaseGenerationStrategy = 'vmap', # Parallelism strategy for generating CFs.
    t_configs: TrainingConfigs = None, # Training configs for `BaseParametricCFModule`.
    pred_fn_args: dict = None # Auxiliary arguments for `pred_fn`.
) -> Explanation:
    ...

def evaluate_cfs(
    cf_exp: Explanation, # CF explanations.
    metrics: Iterable[Union[str, BaseEvalMetrics]] = None, # Metrics to compute. Each is a `str` name or a subclass of `BaseEvalMetrics`.
    return_dict: bool = True, # Return a dictionary (default: True).
    return_df: bool = False # Return a pandas DataFrame (default: False).
):
    ...
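
To illustrate how a `metrics` argument that mixes names and metric objects might be handled, here is a minimal, dependency-free sketch. Everything in it is hypothetical: `METRIC_REGISTRY`, `resolve_metrics`, `evaluate`, and the two toy metrics are stand-ins for illustration, not the actual `relax` implementation, and plain callables are used in place of `BaseEvalMetrics` subclasses.

```python
from typing import Callable, Dict, Iterable, Optional, Sequence, Union

# A metric takes (input, counterfactual, pred_fn) and returns a score.
Metric = Callable[[Sequence[float], Sequence[float], Callable], float]


def proximity(x, cf, pred_fn) -> float:
    """L1 distance between the input and its counterfactual."""
    return sum(abs(a - b) for a, b in zip(x, cf))


def validity(x, cf, pred_fn) -> float:
    """1.0 if the counterfactual changes the predicted label, else 0.0."""
    return float(pred_fn(x) != pred_fn(cf))


# Hypothetical registry mapping metric names to implementations.
METRIC_REGISTRY: Dict[str, Metric] = {"proximity": proximity, "validity": validity}


def resolve_metrics(
    metrics: Optional[Iterable[Union[str, Metric]]] = None,
) -> Dict[str, Metric]:
    """Turn a mixed list of names and callables into a name -> callable dict."""
    if metrics is None:  # Assumption: `None` means "use every registered metric".
        return dict(METRIC_REGISTRY)
    resolved = {}
    for m in metrics:
        if isinstance(m, str):
            if m not in METRIC_REGISTRY:
                raise ValueError(f"Unknown metric: {m!r}")
            resolved[m] = METRIC_REGISTRY[m]
        else:
            resolved[getattr(m, "__name__", type(m).__name__)] = m
    return resolved


def evaluate(x, cf, pred_fn, metrics=None) -> Dict[str, float]:
    """Compute each requested metric for a single (x, cf) pair."""
    return {name: fn(x, cf, pred_fn) for name, fn in resolve_metrics(metrics).items()}


# Toy classifier: label 1 when the feature sum exceeds 1.0.
pred_fn = lambda v: int(sum(v) > 1.0)
result = evaluate([0.25, 0.25], [0.75, 0.5], pred_fn, metrics=["validity", proximity])
print(result)
```

Resolving names up front keeps the evaluation loop uniform, so a caller can pass `"validity"` or a custom callable interchangeably.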


BirkhoffG commented Jul 12, 2023

Mid-level Modules

  • CFModule
class CFConfig:
  pass

class CFModule:
  def __init__(
    self,
    configs,
    *,
    name: str = None,
    apply_constraints_fn = None,
    apply_regularization_fn = None
  ):
    ...

  @property
  def name(self) -> str:
    ...

  def init_apply_fn(
    self,
    apply_constraints_fn = None,
    apply_regularization_fn = None
  ):
    ...

  def generate_cf(
    self,
    x: Array, # Input to be explained.
    pred_fn: Callable = None,
    pred_fn_args = None
  ) -> Array:
    ...

class ParametricCFModule(CFModule):
  def train(
    self,
    data_module: DataModule,
    pred_fn = None,
    **train_kwargs
  ):
    ...

  def is_trained(self) -> bool:
    ...

  def save(self):
    """Save the CFModule"""
    ...

def load_cf_module(
  name: str,
  return_config: bool = False, # Return config or not
  configs: dict = None # Config to override default configuration
) -> CFModule:
  """Load the CFModule"""
  ...
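
As a rough illustration of the `CFModule` interface above, the sketch below implements a toy non-parametric method in plain Python. `NudgeCF`, its `step` and `max_steps` config keys, and the list-based inputs are assumptions made for this example only. The real modules are expected to operate on arrays and accept the constraint and regularization hooks, which are omitted here.

```python
from typing import Callable, List, Optional


class CFModule:
    """Simplified stand-in for the proposed base class (hooks omitted)."""

    def __init__(self, configs: dict, *, name: Optional[str] = None):
        self.configs = configs
        # Fall back to the class name when no explicit name is given.
        self._name = name or type(self).__name__

    @property
    def name(self) -> str:
        return self._name

    def generate_cf(self, x: List[float], pred_fn: Callable) -> List[float]:
        raise NotImplementedError


class NudgeCF(CFModule):
    """Toy method: add a fixed step to every feature until the label flips."""

    def generate_cf(self, x: List[float], pred_fn: Callable) -> List[float]:
        step = self.configs["step"]
        max_steps = self.configs["max_steps"]
        original_label = pred_fn(x)
        cf = list(x)
        for _ in range(max_steps):
            if pred_fn(cf) != original_label:
                break  # The prediction changed, so `cf` is a valid counterfactual.
            cf = [v + step for v in cf]
        return cf


# Toy classifier: label 1 when the feature sum reaches 1.0.
pred_fn = lambda v: int(sum(v) >= 1.0)
module = NudgeCF({"step": 0.25, "max_steps": 10})
cf = module.generate_cf([0.0, 0.25], pred_fn)
print(module.name, cf)
```

A `ParametricCFModule` would add `train` and `is_trained` on top of this, with `generate_cf` presumably relying on the trained parameters.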

  • MLModule
class PredFnMixin:
  def pred_fn(self, x, **kwargs):
    ...

class MLModule(keras.Model, PredFnMixin):
  @property
  def name(self) -> str:
    ...
  def train(self, data_module: DataModule, **train_kwargs):
    ...

def load_ML_module(
  name: str,
  return_config: bool = False, # Return config or not
  configs: dict = None # Config to override default configuration
) -> MLModule:
  """Load the ML module"""
  ...

  • DataModule
class DataModule:
  def __init__(self, config):
    ...

  def from_pandas(self, df: pd.DataFrame):
    ...

  def save(self):
    """Save the data module to a directory to reduce pre-processing time on later loads."""
    ...

  def prepare(self):
    """Prepare and pre-process data"""
    ...

  def dataset(
    self, 
    name: str # Should be one of ['train', 'val', 'test'].
  ):
    ...

  def transform(
    self, 
    data: pd.DataFrame, # Data to be transformed to `numpy.ndarray`
  ) -> Tuple[np.ndarray, np.ndarray]: # Return `(X, y)`
    """Transform data into numerical representations."""
    ...

  def inverse_transform(
    self, 
    x: Array, # The transformed input to be scaled back.
    y: Array = None # The transformed label to be scaled back. If `None`, the target columns will not be scaled back.
  ) -> pd.DataFrame:
    """Transform back into `pd.DataFrame`."""
    ...

  def apply_constraints(self, x: Array, cf: Array, hard: bool = False, **kwargs) -> Array:
    ...

  def apply_regularization(self, x: Array, cf: Array, hard: bool = False, **kwargs) -> float:
    ...

def load_data_module(
  name: str,
  return_config: bool = False, # Return config or not
  configs: dict = None # Config to override default configuration
) -> DataModule:
  """Load the Data module"""
  ...
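
The stubs leave open what `apply_constraints` does and what `hard` means. A possible reading, sketched below under stated assumptions, is that constraints clip features to a normalized range, keep immutable features fixed, and treat one-hot categorical groups softly (normalized) or hard (rounded to a single category). The free-function form and the `immutable_idx` and `cat_groups` parameters are hypothetical and not part of the proposal.

```python
from typing import List, Sequence, Tuple


def apply_constraints(
    x: Sequence[float],
    cf: Sequence[float],
    *,
    immutable_idx: Sequence[int] = (),  # Hypothetical: features that must not change.
    cat_groups: Sequence[Tuple[int, int]] = (),  # Hypothetical: [start, end) one-hot spans.
    hard: bool = False,
) -> List[float]:
    # Clip every feature into the assumed normalized range [0, 1].
    out = [min(1.0, max(0.0, v)) for v in cf]
    # Reset immutable features to their original values.
    for i in immutable_idx:
        out[i] = x[i]
    for start, end in cat_groups:
        group = out[start:end]
        if hard:
            # Hard mode: keep only the largest entry as a one-hot vector.
            winner = group.index(max(group))
            out[start:end] = [1.0 if j == winner else 0.0 for j in range(len(group))]
        else:
            # Soft mode: rescale the group so that it sums to 1.
            total = sum(group)
            if total > 0:
                out[start:end] = [v / total for v in group]
    return out


x = [0.5, 0.25, 1.0, 0.0]
cf = [1.5, 0.75, 0.25, 0.75]
soft = apply_constraints(x, cf, immutable_idx=(1,), cat_groups=((2, 4),))
hard = apply_constraints(x, cf, immutable_idx=(1,), cat_groups=((2, 4),), hard=True)
print(soft, hard)
```

Under this reading, soft mode would suit gradient-based optimization, while hard mode would produce the final, valid counterfactual.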


BirkhoffG commented Jul 12, 2023

Other Modules

  • Strategy
  • Explanation
  • Metrics

Useful functional utils

  • plots
  • Callbacks
    • tqdm callback
