Skip to content

Commit

Permalink
rebase Mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianLeRoyKili committed Jul 22, 2022
1 parent c429c9e commit d04c963
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 158 deletions.
31 changes: 21 additions & 10 deletions mapper.py → commands/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,22 @@
from commands.common_args import Options, PredictOptions, TrainOptions
from commands.predict import predict_one_job
from kiliautoml.models import PyTorchVisionImageClassificationModel
from kiliautoml.utils.constants import ModelFrameworkT, ModelNameT, ModelRepositoryT
from kiliautoml.utils.helpers import get_assets, get_label, get_project, kili_print
from kiliautoml.utils.helpers import (
_get_label,
get_assets,
get_content_input_from_job,
get_project,
kili_print,
)
from kiliautoml.utils.mapper.create import MapperClassification
from kiliautoml.utils.type import AssetStatusT, LabelMergeStrategyT
from kiliautoml.utils.type import (
AssetStatusT,
LabelMergeStrategyT,
ModelFrameworkT,
ModelNameT,
ModelRepositoryT,
ProjectIdT,
)


@click.command()
Expand Down Expand Up @@ -53,7 +65,7 @@
def main(
api_endpoint: str,
api_key: str,
project_id: str,
project_id: ProjectIdT,
clear_dataset_cache: bool,
target_job: List[str],
model_framework: ModelFrameworkT,
Expand All @@ -68,7 +80,7 @@ def main(
epochs: int,
focus_class: Optional[List[str]],
from_model: Optional[ModelFrameworkT],
from_project: Optional[str],
from_project: Optional[ProjectIdT],
graph_name: str,
):
"""
Expand All @@ -87,7 +99,7 @@ def main(

kili_print(f"Create Mapper for job: {job_name}")

content_input = job.get("content", {}).get("input")
content_input = get_content_input_from_job(job)
ml_task = job.get("mlTask")
tools = job.get("tools")
if content_input == "radio" and ml_task == "CLASSIFICATION" and input_type == "IMAGE":
Expand All @@ -101,13 +113,13 @@ def main(
labeled_assets = []
labels = []
for asset in assets:
label = get_label(asset, label_merge_strategy)
label = _get_label(asset, job_name, label_merge_strategy)
if (label is None) or (job_name not in label["jsonResponse"]):
asset_id = asset["id"]
warnings.warn(f"${asset_id}: No annotation for job ${job_name}")
else:
labeled_assets.append(asset)
labels.append(label["jsonResponse"][job_name]["categories"][0]["name"])
labels.append(asset.get_annotations_classification(job_name))

if predictions_path is None:

Expand All @@ -122,7 +134,6 @@ def main(

training_loss = image_classification_model.train(
assets=labeled_assets,
label_merge_strategy=label_merge_strategy,
batch_size=batch_size,
epochs=epochs,
clear_dataset_cache=clear_dataset_cache,
Expand Down Expand Up @@ -155,7 +166,7 @@ def main(
clear_dataset_cache=clear_dataset_cache,
)

predictions = job_predictions.predictions_probability
predictions = job_predictions.predictions_probability # type: ignore
else:
with open("/content/predictions.csv", "r") as csv:
first_line = csv.readline()
Expand Down
144 changes: 0 additions & 144 deletions kiliautoml/utils/mapper/Tuto_Mapper_AutoML.ipynb

This file was deleted.

3 changes: 1 addition & 2 deletions kiliautoml/utils/mapper/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer # type: ignore

from kiliautoml.utils.constants import InputTypeT
from kiliautoml.utils.download_assets import (
download_project_images,
download_project_text,
Expand All @@ -29,7 +28,7 @@
gudhi_to_KM,
topic_score,
)
from kiliautoml.utils.type import JobT
from kiliautoml.utils.type import InputTypeT, JobT


def embeddings_text(list_text: List[str]):
Expand Down
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
sys.excepthook = ultratb.FormattedTB(mode="Verbose", color_scheme="Linux", call_pdb=False)

from commands.label_errors import main as label_errors
from commands.mapper import main as mapper
from commands.predict import main as predict
from commands.prioritize import main as prioritize
from commands.train import main as train
Expand All @@ -36,7 +37,7 @@ def kiliautoml():
kiliautoml.add_command(predict, name="predict")
kiliautoml.add_command(label_errors, name="label_errors")
kiliautoml.add_command(prioritize, name="prioritize")

kiliautoml.add_command(mapper, name="mapper")

if __name__ == "__main__":
kiliautoml()
2 changes: 1 addition & 1 deletion notebooks/Tuto_Mapper.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@
},
"outputs": [],
"source": [
"!python automl/mapper.py --project-id $project_id --assets-repository /content/assets --predictions-path /content/predictions.csv"
"!PYTHONPATH=$(pwd) kiliautoml --project-id $project_id --assets-repository /content/assets --predictions-path /content/predictions.csv"
]
}
],
Expand Down

0 comments on commit d04c963

Please sign in to comment.