Skip to content

Commit

Permalink
sklearn compat
Browse files Browse the repository at this point in the history
  • Loading branch information
basaks committed Aug 15, 2023
1 parent 535e4cb commit a9a6966
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
3 changes: 2 additions & 1 deletion uncoverml/optimise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ def __init__(self, target_transform='identity',
self.target_transform = target_transform
# loss = 'quantile' # use quantile loss for median
# alpha = 0.5 # median
self.median_quantile_params ={'objective': 'quantile', "metric": "quantile", 'alpha': 0.5}
self.median_quantile_params ={'objective': 'quantile', "metric": "quantile", 'alpha': alpha}
self.upper_quantile_params = {'objective': 'quantile', "metric": "quantile", 'alpha': upper_alpha}
self.lower_quantile_params = {'objective': 'quantile', "metric": "quantile", 'alpha': lower_alpha}

Expand All @@ -639,6 +639,7 @@ def __init__(self, target_transform='identity',
**kwargs,
**self.lower_quantile_params
)
self.alpha = alpha
self.upper_alpha = upper_alpha
self.lower_alpha = lower_alpha

Expand Down
1 change: 0 additions & 1 deletion uncoverml/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ def local_crossval(x_all, targets_all: targ.Targets, config: Config):
return

log.info("Validating with {} folds".format(config.folds))
print(config.algorithm_args)
model = modelmaps[config.algorithm](**config.algorithm_args)
classification = hasattr(model, 'predict_proba')
groups = targets_all.groups
Expand Down

0 comments on commit a9a6966

Please sign in to comment.