Embedding projections additions (#92)
* Allow variable projections

* Testing t-SNE; Projections are currently auto-scaling

* Updates to secure unpickle; needed an np.array... weird

* Cache the embedding traces, rather than asset_ids; much faster

* For now, trust all imported, but only builtins by name

* Prevent TSNE race condition with pickle; fix key bug; allow Plotly to scale as needed

* Remove scale

* Version 2.2.9, removed tornado, added opentsne
dsblank authored Apr 14, 2023
1 parent 8e23aab commit df2f755
Showing 8 changed files with 271 additions and 104 deletions.
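
The "secure unpickle" and "trust all imported, but only builtins by name" bullets refer to pickle helpers in backend/kangas/server/utils.py (pickle_dumps is imported in embedding.py below), which are not among the hunks shown on this page. The following is a rough, hypothetical sketch of that restriction pattern, not the Kangas implementation:

# Hypothetical restricted unpickler -- illustrates the pattern only.
import builtins
import io
import pickle
import sys

SAFE_BUILTINS = {"range", "complex", "set", "frozenset", "slice"}  # example allowlist

class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Builtins are only trusted by explicit name.
        if module == "builtins":
            if name in SAFE_BUILTINS:
                return getattr(builtins, name)
            raise pickle.UnpicklingError("builtins.%s is not allowed" % name)
        # "For now, trust all imported" modules (e.g. numpy, openTSNE).
        if module in sys.modules:
            return getattr(sys.modules[module], name)
        raise pickle.UnpicklingError("module %r has not been imported" % module)

def restricted_loads(data):
    # Counterpart to a pickle_dumps() helper; unpickles untrusted bytes.
    return RestrictedUnpickler(io.BytesIO(data)).load()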
2 changes: 1 addition & 1 deletion backend/kangas/_version.py
@@ -11,5 +11,5 @@
 # All rights reserved #
 ######################################################
 
-version_info = (2, 2, 8)
+version_info = (2, 2, 9)
 __version__ = ".".join(map(str, version_info))
73 changes: 55 additions & 18 deletions backend/kangas/datatypes/embedding.py
@@ -13,6 +13,7 @@
 
 import json
 
+from ..server.utils import pickle_dumps
 from .base import Asset
 from .utils import flatten, get_color, get_file_extension, is_valid_file_path
 
@@ -28,6 +29,7 @@ def __init__(
         self,
         embedding=None,
         label=None,
+        projection="pca",
         file_name=None,
         metadata=None,
         source=None,
@@ -53,6 +55,7 @@
 
         self.metadata["label"] = label
         self.metadata["color"] = color
+        self.metadata["projection"] = projection
 
         if file_name:
             if is_valid_file_path(file_name):
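
The new projection argument defaults to "pca" and is stored in the asset metadata, where get_statistics() later reads it back with json_extract. A hypothetical usage sketch, assuming the usual top-level kangas imports (the DataGrid setup is illustrative; only the projection= parameter comes from this diff):

from kangas import DataGrid, Embedding

dg = DataGrid(name="embeddings-demo", columns=["label", "embedding"])
dg.append(["cat", Embedding([0.1, 0.5, 0.9], label="cat", projection="t-sne")])
dg.append(["dog", Embedding([0.7, 0.2, 0.4], label="dog", projection="t-sne")])
dg.save()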
@@ -73,7 +76,7 @@
 
     @classmethod
     def get_statistics(cls, datagrid, col_name, field_name):
-        from sklearn.decomposition import IncrementalPCA
+        import numpy as np
 
         # FIXME: compute min and max of eigenspace
         minimum = None
@@ -86,29 +89,63 @@ def get_statistics(cls, datagrid, col_name, field_name):
         name = col_name
 
         kwargs = {}
 
-        pca = IncrementalPCA(**kwargs)
+        projection = None
         batch = []
         for row in datagrid.conn.execute(
-            """SELECT {field_name} as assetId, asset_data from datagrid JOIN assets ON assetId = assets.asset_id;""".format(
+            """SELECT {field_name} as assetId, asset_data, json_extract(asset_metadata, '$.projection') from datagrid JOIN assets ON assetId = assets.asset_id;""".format(
                 field_name=field_name
             )
         ):
             embedding = json.loads(row[1])
             vectors = embedding["vector"]
             vector = flatten(vectors)
             # FIXME: could scale them here; leave to user for now
-            batch.append(vector)
-            if len(batch) == 10:
-                pca.partial_fit(batch)
-                batch = []
-        if len(batch) > 0:
-            pca.partial_fit(batch)
-
-        other = json.dumps(
-            {
-                "pca_eigen_vectors": pca.components_.tolist(),
-                "pca_mean": pca.mean_.tolist(),
-            }
-        )
+            if row[2] is None or row[2] == "pca":
+                if projection is None:
+                    from sklearn.decomposition import IncrementalPCA
+
+                    projection = IncrementalPCA(**kwargs)
+                    projection_name = "pca"
+                batch.append(vector)
+                if len(batch) == 10:
+                    projection.partial_fit(batch)
+                    batch = []
+            elif row[2] == "t-sne":
+                if projection is None:
+                    from openTSNE import TSNE
+
+                    projection = TSNE(perplexity=30, learning_rate=10, n_iter=500)
+                    projection_name = "t-sne"
+                batch.append(vector)
+            elif row[2] == "umap":
+                if projection is None:
+                    projection_name = "umap"
+            else:
+                raise Exception(
+                    "unknown projection %r; should be 'pca', 't-sne', or 'umap'"
+                    % row[2]
+                )
+
+        if projection_name == "pca":
+            if len(batch) > 0:
+                projection.partial_fit(batch)
+            other = json.dumps(
+                {
+                    "pca_eigen_vectors": projection.components_.tolist(),
+                    "pca_mean": projection.mean_.tolist(),
+                    "projection": projection_name,
+                }
+            )
+        elif projection_name == "t-sne":
+            embedding = projection.fit(np.array(batch))
+            other = json.dumps(
+                {"projection": projection_name, "embedding": pickle_dumps(embedding)}
+            )
+        elif projection_name == "umap":
+            other = json.dumps(
+                {
+                    "projection": projection_name,
+                }
+            )
 
         return [minimum, maximum, avg, variance, total, stddev, other, name]
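
For "pca", the column statistics now carry the fitted eigenvectors and mean; for "t-sne", they carry the pickled openTSNE embedding (the cached trace mentioned in the commit message). A minimal sketch of how the stored PCA statistics could be applied to a flattened vector, assuming the standard center-then-project transform (the function and loading step are illustrative, not Kangas server code):

# Illustrative consumer of the "pca" column statistics; not Kangas server code.
import json
import numpy as np

def project_with_stored_pca(vector, other_json):
    stats = json.loads(other_json)
    components = np.array(stats["pca_eigen_vectors"])  # shape (k, n_features)
    mean = np.array(stats["pca_mean"])                 # shape (n_features,)
    # Standard PCA transform: center on the fitted mean, project onto components.
    return (np.asarray(vector) - mean) @ components.T

# For "t-sne", the statistics instead hold a pickled openTSNE embedding, which
# already contains the fitted low-dimensional coordinates.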
10 changes: 5 additions & 5 deletions backend/kangas/server/flask_server.py
@@ -47,7 +47,7 @@
     select_asset_task,
     select_category_task,
     select_histogram_task,
-    select_pca_data_task,
+    select_projection_data_task,
 )
 from .translogger import TransLogger
 from .utils import get_node_version
@@ -627,7 +627,7 @@ def get_datagrid_about_handler():
 
 @application.route("/datagrid/embeddings-as-pca", methods=["GET"])
 @auth_wrapper
-def get_embeddings_as_pca():
+def get_embeddings_as_projection():
     dgid = request.args.get("dgid")
     timestamp = request.args.get("timestamp")
     # if one asset:
@@ -643,7 +643,7 @@ def get_embeddings_as_pca():
     width = int(request.args.get("width", "0"))
 
     if ensure_datagrid_path(dgid):
-        pca_data = select_pca_data_task.apply(
+        projection_data = select_projection_data_task.apply(
             args=(
                 dgid,
                 timestamp,
@@ -656,14 +656,14 @@ def get_embeddings_as_pca():
         ).get()
         if thumbnail:
             image = generate_chart_image_task.apply(
-                args=("scatter", pca_data, width, height)
+                args=("scatter", projection_data, width, height)
             ).get()
             response = make_response(image)
             response.headers.add("Cache-Control", "max-age=604800")
             response.headers.add("Content-type", "image/png")
             return response
         else:
-            return pca_data
+            return projection_data
     else:
         return error(404)
 
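The handler keeps serving the original /datagrid/embeddings-as-pca route; only its internals move from PCA-specific to projection-generic names. A hypothetical request against a locally running Kangas server (host, port, and parameter values are assumptions; dgid, timestamp, and width are read explicitly above, while thumbnail and height are inferred from the handler body):

# Hypothetical client call; URL, dgid value, and thumbnail/height names are assumptions.
import requests

resp = requests.get(
    "http://localhost:4000/datagrid/embeddings-as-pca",
    params={
        "dgid": "example.datagrid",  # assumed DataGrid id
        "timestamp": "0",
        "thumbnail": "true",         # request the rendered PNG scatter chart
        "width": "400",
        "height": "300",
    },
)
print(resp.headers.get("Content-type"))  # "image/png" when a thumbnail is returned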
