Skip to content

Commit

Permalink
add logistic regression support back
Browse files Browse the repository at this point in the history
  • Loading branch information
basaks committed Aug 7, 2023
1 parent 34fcceb commit 21a4b76
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 33 deletions.
8 changes: 7 additions & 1 deletion configs/classification.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,15 @@ targets:
property: Na_cats

learning:
# algorithm: transformedlogistic
# arguments:
# max_iter: 20000
# random_state: 1
algorithm: transformedforestclassifier
arguments:
n_estimators: 20
random_state: 1


prediction:
quantiles: 0.95
Expand All @@ -48,4 +54,4 @@ validation:
random_seed: 1

output:
directory: scratch/
directory: logistic/
82 changes: 50 additions & 32 deletions uncoverml/optimise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from sklearn.gaussian_process.kernels import RBF, Matern, RationalQuadratic
from sklearn.linear_model import (HuberRegressor,
LinearRegression,
ElasticNet, SGDRegressor)
ElasticNet, SGDRegressor,
LogisticRegression)
from sklearn.linear_model._stochastic_gradient import DEFAULT_EPSILON
from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVR
Expand Down Expand Up @@ -856,60 +857,77 @@ def predict_dist(self, X, interval=0.95, **kwargs):
return Ey, Vy, ql, qu


no_test_support = {
'xgboost': XGBoost,
'xgbquantileregressor': XGBQuantileRegressor,
'xgbquantile': QuantileXGB,
'quantilegb': QuantileGradientBoosting,
'gradientboost': GBMReg,
'catboost': CatBoostWrapper,
'lgbm': LGBMReg,
'quantilelgbm': QuantileLGBM,
}
class EncodedClassifierMixin():

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.le = LabelEncoder()

def fit(self, X, y, **kwargs):
y_t = self.le.fit_transform(y)
super().fit(X, y_t, sample_weight=kwargs['sample_weight'])

def predict_proba(self, X, **kwargs):
p = super().predict_proba(X)
y = np.argmax(p, axis=1) # Also return hard labels
return y, p

def get_classes(self):
tags = ["most_likely"]
tags += ["{}_{}".format(c, i)
for i, c in enumerate(self.le.classes_)]
return tags

class RandomForestClassifier(RandomForestClassifier, TagsMixin):

class RandomForestClassifier(EncodedClassifierMixin, RandomForestClassifier, TagsMixin):
"""
Random Forest for muli-class classification.
http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
"""
def __init__(self,
target_transform='identity',
n_estimators=10,
**kwargs
):
super(RandomForestClassifier, self).__init__(
super(EncodedClassifierMixin, self).__init__()
super().__init__(
n_estimators=n_estimators,
**kwargs
)

# training uses str
if isinstance(target_transform, str):
target_transform = transforms.transforms[target_transform]()

# used during optimisation
self.target_transform = target_transform
self.le = LabelEncoder()

def fit(self, X, y, **kwargs):
y_t = self.le.fit_transform(y)
super().fit(X, y_t, sample_weight=kwargs['sample_weight'])
class LogisticClassifier(EncodedClassifierMixin, LogisticRegression, TagsMixin):
"""
Logistic Regression for muli-class classification.
def predict_proba(self, X, **kwargs):
p = super().predict_proba(X)
y = np.argmax(p, axis=1) # Also return hard labels
return y, p
http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
"""
def __init__(self,
max_iter=1000,
**kwargs
):
super(EncodedClassifierMixin, self).__init__()
super().__init__(
max_iter=max_iter,
**kwargs
)

def get_classes(self):
tags = ["most_likely"]
tags += ["{}_{}".format(c, i)
for i, c in enumerate(self.le.classes_)]
return tags

no_test_support = {
'xgboost': XGBoost,
'xgbquantileregressor': XGBQuantileRegressor,
'xgbquantile': QuantileXGB,
'quantilegb': QuantileGradientBoosting,
'gradientboost': GBMReg,
'catboost': CatBoostWrapper,
'lgbm': LGBMReg,
'quantilelgbm': QuantileLGBM,
}

no_test_support_classifiers = {
'transformedforestclassifier': RandomForestClassifier,
'transformedlogistic': LogisticClassifier,
}


Expand Down
1 change: 1 addition & 0 deletions uncoverml/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ def local_crossval(x_all, targets_all: targ.Targets, config: Config):
return

log.info("Validating with {} folds".format(config.folds))
print(config.algorithm_args)
model = modelmaps[config.algorithm](**config.algorithm_args)
classification = hasattr(model, 'predict_proba')
groups = targets_all.groups
Expand Down

0 comments on commit 21a4b76

Please sign in to comment.