From 07b04d2673ed57c670b157f27c3a32835944cbdf Mon Sep 17 00:00:00 2001 From: Hendrik Schreiber Date: Mon, 12 Oct 2020 14:33:04 +0200 Subject: [PATCH] Cache models locally. Load models from GitHub (as fallback). --- CHANGES.rst | 2 ++ setup.py | 68 ++++++++++++++++++++++-------------------- tempocnn/classifier.py | 48 +++++++++++++++++++++++------ tempocnn/version.py | 2 +- 4 files changed, 77 insertions(+), 43 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index a9094e4..f49dbb2 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,8 @@ Changes - Officially support Python 3.7. - Enabled GitHub actions for packaging and testing. - Added Pypi workflow. + - Cache models locally. + - Load models from GitHub. 0.0.4: - Added support for DeepTemp, DeepSquare, and ShallowTemp models. diff --git a/setup.py b/setup.py index 86cd431..fc4afbd 100755 --- a/setup.py +++ b/setup.py @@ -19,39 +19,41 @@ scripts = glob.glob('bin/*') # define the models to be included in the PyPI package -package_data = ['models/cnn.h5', - 'models/fcn.h5', - 'models/ismir2018.h5', - 'models/fma2018.h5', - 'models/fma2018-meter.h5', - 'models/dt_maz_m_fold0.h5', - 'models/dt_maz_m_fold1.h5', - 'models/dt_maz_m_fold2.h5', - 'models/dt_maz_m_fold3.h5', - 'models/dt_maz_m_fold4.h5', - 'models/dt_maz_v_fold0.h5', - 'models/dt_maz_v_fold1.h5', - 'models/dt_maz_v_fold2.h5', - 'models/dt_maz_v_fold3.h5', - 'models/dt_maz_v_fold4.h5', - 'models/deepsquare_k1.h5', - 'models/deepsquare_k2.h5', - 'models/deepsquare_k4.h5', - 'models/deepsquare_k8.h5', - 'models/deepsquare_k16.h5', - 'models/deepsquare_k24.h5', - 'models/deeptemp_k2.h5', - 'models/deeptemp_k4.h5', - 'models/deeptemp_k8.h5', - 'models/deeptemp_k16.h5', - 'models/deeptemp_k24.h5', - 'models/shallowtemp_k1.h5', - 'models/shallowtemp_k2.h5', - 'models/shallowtemp_k4.h5', - 'models/shallowtemp_k6.h5', - 'models/shallowtemp_k8.h5', - 'models/shallowtemp_k12.h5', - ] +# do not package some large models, to stay below PyPI 100mb threshold +package_data = [ + 'models/cnn.h5', + 'models/fcn.h5', + 'models/ismir2018.h5', +# 'models/fma2018.h5', +# 'models/fma2018-meter.h5', + 'models/dt_maz_m_fold0.h5', + 'models/dt_maz_m_fold1.h5', + 'models/dt_maz_m_fold2.h5', + 'models/dt_maz_m_fold3.h5', + 'models/dt_maz_m_fold4.h5', + 'models/dt_maz_v_fold0.h5', + 'models/dt_maz_v_fold1.h5', + 'models/dt_maz_v_fold2.h5', + 'models/dt_maz_v_fold3.h5', + 'models/dt_maz_v_fold4.h5', + 'models/deepsquare_k1.h5', + 'models/deepsquare_k2.h5', + 'models/deepsquare_k4.h5', + 'models/deepsquare_k8.h5', +# 'models/deepsquare_k16.h5', +# 'models/deepsquare_k24.h5', + 'models/deeptemp_k2.h5', + 'models/deeptemp_k4.h5', + 'models/deeptemp_k8.h5', +# 'models/deeptemp_k16.h5', +# 'models/deeptemp_k24.h5', + 'models/shallowtemp_k1.h5', + 'models/shallowtemp_k2.h5', + 'models/shallowtemp_k4.h5', + 'models/shallowtemp_k6.h5', +# 'models/shallowtemp_k8.h5', +# 'models/shallowtemp_k12.h5', +] # requirements with open('requirements.txt', 'r') as fh: diff --git a/tempocnn/classifier.py b/tempocnn/classifier.py index 90c360c..52e8a13 100644 --- a/tempocnn/classifier.py +++ b/tempocnn/classifier.py @@ -1,9 +1,11 @@ # encoding: utf-8 - +import logging import os import pkgutil import sys -import tempfile +import urllib.request +from pathlib import Path +from urllib.error import HTTPError import numpy as np from tensorflow.python.keras.models import load_model @@ -83,10 +85,7 @@ def __init__(self, model_name='fcn'): print('Failed to find a model named \'{}\'. Please check the model name.'.format(model_name), file=sys.stderr) raise e - try: - self.model = load_model(file) - finally: - os.remove(file) + self.model = load_model(file) def estimate(self, data): """ @@ -273,8 +272,39 @@ def _to_model_resource(model_name): def _extract_from_package(resource): + # check local cache + cache_path = Path(Path.home(), '.tempocnn', resource) + if cache_path.exists(): + return str(cache_path) + + # ensure cache path exists + cache_path.parent.mkdir(parents=True, exist_ok=True) + data = pkgutil.get_data('tempocnn', resource) - with tempfile.NamedTemporaryFile(prefix='model', suffix='.h5', delete=False) as f: + if not data: + data = _load_model_from_github(resource) + + # write to cache + with open(cache_path, 'wb') as f: f.write(data) - name = f.name - return name + + return str(cache_path) + + +def _load_model_from_github(resource): + url = f"https://raw.githubusercontent.com/hendriks73/tempo-cnn/main/tempocnn/{resource}" + logging.info(f"Attempting to download model file from main branch {url}") + try: + response = urllib.request.urlopen(url) + return response.read() + except HTTPError as e: + # fall back to dev branch + try: + url = f"https://raw.githubusercontent.com/hendriks73/tempo-cnn/dev/tempocnn/{resource}" + logging.info(f"Attempting to download model file from dev branch {url}") + response = urllib.request.urlopen(url) + return response.read() + except Exception: + pass + + raise FileNotFoundError(f"Failed to download model from {url}: {type(e).__name__}: {e}") diff --git a/tempocnn/version.py b/tempocnn/version.py index 81f0fde..afb6fc6 100644 --- a/tempocnn/version.py +++ b/tempocnn/version.py @@ -1 +1 @@ -__version__ = "0.0.4" +__version__ = "0.0.5.dev0"