From df2f7554ccc63424f644dd4dd956ea0b80482be6 Mon Sep 17 00:00:00 2001 From: Douglas Blank Date: Fri, 14 Apr 2023 11:23:35 -0700 Subject: [PATCH] Embedding projections additions (#92) * Allow variable projections * Testing t-SNE; Projections are currently auto-scaling * Updates to secure unpickle; needed an np.array... weird * Cache the embedding traces, rather than asset_ids; much faster * For now, trust all imported, but only builtins by name * Prevent TSNE race condition with pickle; fix key bug; allow Plotly to scale as needed * Remove scale * Version 2.2.9, removed tornado, added opentsne --- backend/kangas/_version.py | 2 +- backend/kangas/datatypes/embedding.py | 73 ++++-- backend/kangas/server/flask_server.py | 10 +- backend/kangas/server/queries.py | 209 +++++++++++------- backend/kangas/server/tasks.py | 6 +- backend/kangas/server/utils.py | 71 ++++++ backend/setup.py | 2 +- .../cells/embedding/EmbeddingCellClient.js | 2 - 8 files changed, 271 insertions(+), 104 deletions(-) diff --git a/backend/kangas/_version.py b/backend/kangas/_version.py index c78e686a..8de81591 100644 --- a/backend/kangas/_version.py +++ b/backend/kangas/_version.py @@ -11,5 +11,5 @@ # All rights reserved # ###################################################### -version_info = (2, 2, 8) +version_info = (2, 2, 9) __version__ = ".".join(map(str, version_info)) diff --git a/backend/kangas/datatypes/embedding.py b/backend/kangas/datatypes/embedding.py index 34140b74..a02ed07b 100644 --- a/backend/kangas/datatypes/embedding.py +++ b/backend/kangas/datatypes/embedding.py @@ -13,6 +13,7 @@ import json +from ..server.utils import pickle_dumps from .base import Asset from .utils import flatten, get_color, get_file_extension, is_valid_file_path @@ -28,6 +29,7 @@ def __init__( self, embedding=None, label=None, + projection="pca", file_name=None, metadata=None, source=None, @@ -53,6 +55,7 @@ def __init__( self.metadata["label"] = label self.metadata["color"] = color + self.metadata["projection"] = projection if file_name: if is_valid_file_path(file_name): @@ -73,7 +76,7 @@ def __init__( @classmethod def get_statistics(cls, datagrid, col_name, field_name): - from sklearn.decomposition import IncrementalPCA + import numpy as np # FIXME: compute min and max of eigenspace minimum = None @@ -86,29 +89,63 @@ def get_statistics(cls, datagrid, col_name, field_name): name = col_name kwargs = {} - - pca = IncrementalPCA(**kwargs) + projection = None batch = [] for row in datagrid.conn.execute( - """SELECT {field_name} as assetId, asset_data from datagrid JOIN assets ON assetId = assets.asset_id;""".format( + """SELECT {field_name} as assetId, asset_data, json_extract(asset_metadata, '$.projection') from datagrid JOIN assets ON assetId = assets.asset_id;""".format( field_name=field_name ) ): embedding = json.loads(row[1]) vectors = embedding["vector"] vector = flatten(vectors) - # FIXME: could scale them here; leave to user for now - batch.append(vector) - if len(batch) == 10: - pca.partial_fit(batch) - batch = [] - if len(batch) > 0: - pca.partial_fit(batch) - - other = json.dumps( - { - "pca_eigen_vectors": pca.components_.tolist(), - "pca_mean": pca.mean_.tolist(), - } - ) + + if row[2] is None or row[2] == "pca": + if projection is None: + from sklearn.decomposition import IncrementalPCA + + projection = IncrementalPCA(**kwargs) + projection_name = "pca" + batch.append(vector) + if len(batch) == 10: + projection.partial_fit(batch) + batch = [] + elif row[2] == "t-sne": + if projection is None: + from openTSNE import TSNE + + projection = TSNE(perplexity=30, learning_rate=10, n_iter=500) + projection_name = "t-sne" + batch.append(vector) + elif row[2] == "umap": + if projection is None: + projection_name = "umap" + else: + raise Exception( + "unknown projection %r; should be 'pca', 't-sne', or 'umap'" + % row[2] + ) + + if projection_name == "pca": + if len(batch) > 0: + projection.partial_fit(batch) + other = json.dumps( + { + "pca_eigen_vectors": projection.components_.tolist(), + "pca_mean": projection.mean_.tolist(), + "projection": projection_name, + } + ) + elif projection_name == "t-sne": + embedding = projection.fit(np.array(batch)) + other = json.dumps( + {"projection": projection_name, "embedding": pickle_dumps(embedding)} + ) + elif projection_name == "umap": + other = json.dumps( + { + "projection": projection_name, + } + ) + return [minimum, maximum, avg, variance, total, stddev, other, name] diff --git a/backend/kangas/server/flask_server.py b/backend/kangas/server/flask_server.py index fd379f77..0145aca7 100644 --- a/backend/kangas/server/flask_server.py +++ b/backend/kangas/server/flask_server.py @@ -47,7 +47,7 @@ select_asset_task, select_category_task, select_histogram_task, - select_pca_data_task, + select_projection_data_task, ) from .translogger import TransLogger from .utils import get_node_version @@ -627,7 +627,7 @@ def get_datagrid_about_handler(): @application.route("/datagrid/embeddings-as-pca", methods=["GET"]) @auth_wrapper -def get_embeddings_as_pca(): +def get_embeddings_as_projection(): dgid = request.args.get("dgid") timestamp = request.args.get("timestamp") # if one asset: @@ -643,7 +643,7 @@ def get_embeddings_as_pca(): width = int(request.args.get("width", "0")) if ensure_datagrid_path(dgid): - pca_data = select_pca_data_task.apply( + projection_data = select_projection_data_task.apply( args=( dgid, timestamp, @@ -656,14 +656,14 @@ def get_embeddings_as_pca(): ).get() if thumbnail: image = generate_chart_image_task.apply( - args=("scatter", pca_data, width, height) + args=("scatter", projection_data, width, height) ).get() response = make_response(image) response.headers.add("Cache-Control", "max-age=604800") response.headers.add("Content-type", "image/png") return response else: - return pca_data + return projection_data else: return error(404) diff --git a/backend/kangas/server/queries.py b/backend/kangas/server/queries.py index 695f8076..3938cad1 100644 --- a/backend/kangas/server/queries.py +++ b/backend/kangas/server/queries.py @@ -37,7 +37,7 @@ pytype_to_dgtype, ) from .computed_columns import update_state -from .utils import process_about, safe_compile, safe_env +from .utils import pickle_loads_embedding, process_about, safe_compile, safe_env LOGGER = logging.getLogger(__name__) KANGAS_ROOT = os.environ.get("KANGAS_ROOT", ".") @@ -56,19 +56,21 @@ VALID_CHARS = string.ascii_letters + string.digits + "_" -PCA_SAMPLES_CACHE = {} -PCA_MAX_SIZE = 1024 +PROJECTION_SAMPLES_CACHE = {} +PROJECTION_MAX_SIZE = 100 -def update_pca_cache(key, value): - if key not in PCA_SAMPLES_CACHE: +def get_projection_cache(key, value): + # Returns a copy of the value list + if key not in PROJECTION_SAMPLES_CACHE: # Check size: - if len(PCA_SAMPLES_CACHE) >= PCA_MAX_SIZE: + if len(PROJECTION_SAMPLES_CACHE) >= PROJECTION_MAX_SIZE: # too many - first_in_key = list(PCA_SAMPLES_CACHE.keys())[0] - # pop the first in - PCA_SAMPLES_CACHE.pop(first_in_key) - PCA_SAMPLES_CACHE[key] = value + first_in_key = list(PROJECTION_SAMPLES_CACHE.keys())[0] + # del the first-in + del PROJECTION_SAMPLES_CACHE[first_in_key] + PROJECTION_SAMPLES_CACHE[key] = value + return PROJECTION_SAMPLES_CACHE[key][:] def sqlite_query_explain( @@ -2192,8 +2194,16 @@ def get_fields(dgid, metadata=None, computed_columns=None): return fields -def process_pca_asset_ids( - name, cur, asset_ids, pca, traces, size, default_color, color_override=None +def process_projection_asset_ids( + name, + cur, + asset_ids, + projection_name, + projection, + traces, + size, + default_color, + color_override=None, ): # asset_ids is a list of str # side-effect: adds to traces @@ -2207,8 +2217,7 @@ def process_pca_asset_ids( values=values, ) - xs = [] - ys = [] + vectors = [] colors = [] for asset_data_row in cur.execute(sql): asset_data_raw = asset_data_row[0] @@ -2221,12 +2230,13 @@ def process_pca_asset_ids( else: color = default_color - # FIXME: can transform all at once - eigen_vector = pca.transform([vector]) - xs.append(round(eigen_vector[0][0], 3)) - ys.append(round(eigen_vector[0][1], 3)) + vectors.append(vector) colors.append(color) + eigen_vector = projection.transform(np.array(vectors)) + xs = eigen_vector[:, 0].tolist() + ys = eigen_vector[:, 1].tolist() + traces.append( { "x": xs, @@ -2239,65 +2249,78 @@ def process_pca_asset_ids( ) -def select_pca_data( +def select_projection_data( dgid, timestamp, asset_id, column_name, column_value, group_by, where_expr ): - from sklearn.decomposition import PCA - conn = get_database_connection(dgid) cur = conn.cursor() metadata = get_metadata(conn) column_limit = None column_offset = 0 - pca_eigen_vectors = metadata[column_name]["other"]["pca_eigen_vectors"] - pca_mean = metadata[column_name]["other"]["pca_mean"] - default_color = get_color(column_name) + if "projection" in metadata[column_name]["other"]: + projection_name = metadata[column_name]["other"]["projection"] + else: + projection_name = "pca" + + if projection_name == "pca": + from sklearn.decomposition import PCA + + pca_eigen_vectors = metadata[column_name]["other"]["pca_eigen_vectors"] + pca_mean = metadata[column_name]["other"]["pca_mean"] + projection = PCA() + projection.components_ = np.array(pca_eigen_vectors) + projection.mean_ = np.array(pca_mean) + elif projection_name == "t-sne": + # FIXME: Trying to prevent an error on first load; race condition? + from openTSNE import TSNE # noqa - pca = PCA() - pca.components_ = np.array(pca_eigen_vectors) - pca.mean_ = np.array(pca_mean) + ascii_string = metadata[column_name]["other"]["embedding"] + projection = pickle_loads_embedding(ascii_string) + + elif projection_name == "umap": + pass + else: + return + + default_color = get_color(column_name) traces = [] if asset_id: # First, add some points to provide context: - key = (dgid, timestamp, column_name, where_expr) - if key not in PCA_SAMPLES_CACHE: - # FIXME: make an LRU so as not to fill up memory - update_pca_cache( - key, - list( - select_query_raw( - cur, - metadata, - [column_name], - offset="0", - sort_by="RANDOM()", - sort_desc=None, - where=None, - limit=200, - computed_columns=None, - where_expr=where_expr, - debug=False, - ) - ), + key = ("sampled", dgid, timestamp, column_name, where_expr) + if key not in PROJECTION_SAMPLES_CACHE: + rows = select_query_raw( + cur, + metadata, + [column_name], + offset="0", + sort_by="RANDOM()", + sort_desc=None, + where=None, + limit=200, + computed_columns=None, + where_expr=where_expr, + debug=False, ) - rows = PCA_SAMPLES_CACHE[key] - process_pca_asset_ids( - "Sampled Data", - cur, - [row[0] for row in rows], - pca, - traces, - 3, - default_color, - "gray", - ) + process_projection_asset_ids( + "Sampled Data", + cur, + [row[0] for row in rows], + projection_name, + projection, + traces, + 3, + default_color, + "gray", + ) + # Traces contains projection data: + traces = get_projection_cache(key, traces) # Next, add the selected asset: asset_data_raw = select_asset(dgid, asset_id) asset_data = json.loads(asset_data_raw) - vector = pca.transform([asset_data["vector"]]) + vector = projection.transform(np.array([asset_data["vector"]])) if asset_data["color"]: color = asset_data["color"] else: @@ -2314,18 +2337,38 @@ def select_pca_data( } ) else: - rows = select_group_by_rows( - column_name, column_value, group_by, where_expr, metadata, cur + key = ( + "selection", + dgid, + timestamp, + column_name, + column_value, + group_by, + where_expr, ) - if rows: - row = rows[0] - if row and row[0]: - values = row[0].split(",") - if column_limit is not None: - values = values[column_offset : column_offset + column_limit] - process_pca_asset_ids( - column_name, cur, values, pca, traces, 3, default_color - ) + if key not in PROJECTION_SAMPLES_CACHE: + rows = select_group_by_rows( + column_name, column_value, group_by, where_expr, metadata, cur + ) + if rows: + row = rows[0] + if row and row[0]: + values = row[0].split(",") + if column_limit is not None: + values = values[column_offset : column_offset + column_limit] + + process_projection_asset_ids( + column_name, + cur, + values, + projection_name, + projection, + traces, + 3, + default_color, + ) + # Traces contains projection data: + traces = get_projection_cache(key, traces) return traces @@ -2557,6 +2600,8 @@ def generate_chart_image(chart_type, data, width, height): image = PIL.Image.new("RGBA", (width, height)) drawing = PIL.ImageDraw.Draw(image) + max_x, min_x = None, None + max_y, min_y = None, None for trace in data: if chart_type == "category": @@ -2610,8 +2655,10 @@ def generate_chart_image(chart_type, data, width, height): if "y" not in trace or len(trace["y"]) == 0: continue - min_x, max_x = -3, 3 - min_y, max_y = -3, 3 + if max_x is None: + min_x, max_x = min(trace["x"]), max(trace["x"]) + min_y, max_y = min(trace["y"]), max(trace["y"]) + span_x = max_x - min_x span_y = max_y - min_y @@ -2625,10 +2672,24 @@ def generate_chart_image(chart_type, data, width, height): span_y = max_y - min_y drawing.line( - [width / 2, margin, width / 2, height - margin], fill="black", width=1 + [ + margin + (total_width * (-100 - min_x) / span_x), + margin + (total_height - total_height * (0 - min_y) / span_y), + margin + (total_width * (100 - min_x) / span_x), + margin + (total_height - total_height * (0 - min_y) / span_y), + ], + fill="black", + width=1, ) drawing.line( - [margin, height / 2, width - margin, height / 2], fill="black", width=1 + [ + margin + (total_width * (0 - min_x) / span_x), + margin + (total_height - total_height * (-100 - min_y) / span_y), + margin + (total_width * (0 - min_x) / span_x), + margin + (total_height - total_height * (100 - min_y) / span_y), + ], + fill="black", + width=1, ) for count, [x, y] in enumerate(zip(trace["x"], trace["y"])): diff --git a/backend/kangas/server/tasks.py b/backend/kangas/server/tasks.py index 844eb3a1..4afe83d2 100644 --- a/backend/kangas/server/tasks.py +++ b/backend/kangas/server/tasks.py @@ -30,7 +30,7 @@ select_asset_metadata, select_category, select_histogram, - select_pca_data, + select_projection_data, ) # from .utils import get_bool_from_env @@ -124,11 +124,11 @@ def select_asset_metadata_task(self, dgid, asset_id): @app.task(bind=True) -def select_pca_data_task( +def select_projection_data_task( self, dgid, timestamp, asset_id, column_name, column_value, group_by, where_expr ): try: - result = select_pca_data( + result = select_projection_data( dgid, timestamp, asset_id, column_name, column_value, group_by, where_expr ) return result diff --git a/backend/kangas/server/utils.py b/backend/kangas/server/utils.py index 6a888c08..b75982c9 100644 --- a/backend/kangas/server/utils.py +++ b/backend/kangas/server/utils.py @@ -11,9 +11,13 @@ # All rights reserved # ###################################################### +import base64 import inspect +import io +import pickle import re import subprocess +import sys import urllib try: @@ -218,3 +222,70 @@ def get_node_version(): return output.decode("utf-8").strip() return "unknown" + + +class RestrictedUnpickler(pickle.Unpickler): + def __init__(self, safe, *args, **kwargs): + self.safe = safe + super().__init__(*args, **kwargs) + + def find_class(self, module, name): + if (module, name) in self.safe: + if module != "builtins": + return getattr(sys.modules[module], name) + + raise pickle.UnpicklingError( + "global module '%s', name '%s' is forbidden" % (module, name) + ) + + +def pickle_dumps(obj): + """ + Helper function analogous to pickle.dumps(). + """ + return base64.b64encode(pickle.dumps(obj)).decode("ascii") + + +def pickle_loads(safe, ascii_string): + """ + Helper function analogous to pickle.loads(). + """ + return RestrictedUnpickler( + safe, + io.BytesIO(base64.b64decode(ascii_string)), + ).load() + + +def pickle_loads_embedding_unsafe(ascii_string): + return pickle.Unpickler( + io.BytesIO(base64.b64decode(ascii_string)), + ).load() + + +def pickle_loads_embedding(ascii_string): + safe = { + ("numpy", "dtype"), + ("numpy", "ndarray"), + ("numpy.core.multiarray", "_reconstruct"), + ("numpy.core.multiarray", "scalar"), + ("openTSNE.affinity", "MultiscaleMixture"), + ("openTSNE.nearest_neighbors", "Sklearn"), + ("openTSNE.tsne", "TSNEEmbedding"), + ("openTSNE.tsne", "gradient_descent"), + ("scipy.sparse._csr", "csr_matrix"), + ("sklearn.base", "clone"), + ("sklearn.metrics._dist_metrics", "EuclideanDistance"), + ("sklearn.metrics._dist_metrics", "newObj"), + ("sklearn.neighbors._kd_tree", "KDTree"), + ("sklearn.neighbors._kd_tree", "newObj"), + ("sklearn.neighbors._unsupervised", "NearestNeighbors"), + ("openTSNE.nearest_neighbors", "Annoy"), + } + import numpy # noqa + import openTSNE # noqa + import openTSNE.tsne # noqa + import scipy # noqa + import sklearn # noqa + import sklearn.decomposition # noqa + + return pickle_loads(safe, ascii_string) diff --git a/backend/setup.py b/backend/setup.py index b5cd82ec..6d591d0e 100644 --- a/backend/setup.py +++ b/backend/setup.py @@ -60,8 +60,8 @@ def get_version(file, name="__version__"): "requests", "scikit-learn", "scipy", - "tornado", "waitress", + "opentsne", ], packages=[ "kangas", diff --git a/frontend/app/cells/embedding/EmbeddingCellClient.js b/frontend/app/cells/embedding/EmbeddingCellClient.js index 8be11985..4e10274c 100644 --- a/frontend/app/cells/embedding/EmbeddingCellClient.js +++ b/frontend/app/cells/embedding/EmbeddingCellClient.js @@ -75,14 +75,12 @@ const EmbeddingClient = ({ value, expanded, query, columnName, ssrData }) => { size: 13, color: '#3D4355', }, - range: [ -3, 3 ] }, yaxis: { font: { size: 13, color: '#3D4355', }, - range: [ -3, 3 ] } }; }, [columnName]);