diff --git a/projit/projit.py b/projit/projit.py index 1d40ce4..1c38b53 100644 --- a/projit/projit.py +++ b/projit/projit.py @@ -28,6 +28,7 @@ def __init__(self, datasets={}, results={}, params={}, + hyperparams={}, dataresults={}): """ Initialise a projit project object. @@ -57,6 +58,10 @@ def __init__(self, For example: target variable name, identifier column. :type params: Dictionary, optional + :param hyperparams: A dictionary of hyper parameters for experiments. + Structure: {'experiment':{'param':'value', etc}} + :type hyperparams: Dictionary, optional + :param dataresults: The dictionary of results on specific data sets. These are used when you want your experimental results broken down by the datasets. @@ -73,6 +78,7 @@ def __init__(self, self.datasets = datasets self.results = results self.params = params + self.hyperparams = hyperparams self.dataresults = dataresults @@ -103,6 +109,13 @@ def add_experiment(self, name, path): self.experiments.append( (name, path) ) self.save() + + def experiment_exists(self, name): + for elem in self.experiments: + if elem[0] == name: + return True + return False + def clean_experimental_results(self, name): """ Remove all results for a given experiment @@ -137,6 +150,21 @@ def add_param(self, name, value): self.params[name] = value self.save() + def add_hyperparam(self, name, value): + """ + Add a set of hyper parameters to the project. + + :param name: The experiment name + :type name: string, required + + :param value: The Dictionary of hyperparameters + :type value: Dictionary + """ + if self.experiment_exists(name): + self.hyperparams[name] = value + self.save() + else: + raise Exception("ERROR: No experiment called: '%s' -- Register your experiment first." % name) def add_result(self, experiment, metric, value, dataset=None): """ @@ -233,6 +261,12 @@ def get_param(self, name): else: raise Exception("ERROR: Named parameter '%s' is not available:" % name) + def get_hyperparam(self, name): + if name in self.hyperparams: + return self.hyperparams[name] + else: + raise Exception("ERROR: Hyper parameters for experiemnt '%s' are not available:" % name) + def get_path_to_dataset(self, name): ds = self.get_dataset(name) if self.is_complete_path(ds): diff --git a/tests/test_functions.py b/tests/test_functions.py index 3c91696..30bb9ce 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -130,6 +130,21 @@ def test_project_params(): os.chdir("../") shutil.rmtree(testdir) +################################################################# +def test_project_hyperparams(): + testdir = "temp_test_dir_xyz" + os.mkdir(testdir) + os.chdir(testdir) + project = proj.init("default", "test params", "param test") + with pytest.raises(Exception) as e_info: + project.add_hyperparam("myexp", "myval") + + project.add_experiment("myexp", "mypath") + project.add_hyperparam("myexp", "myval") + results = project.get_hyperparam("myexp") + assert results == "myval" + os.chdir("../") + shutil.rmtree(testdir) #################################################################