Added ANN and optimization ability for item2item task #55

Open · wants to merge 51 commits into main

Commits (51)
bd1a4ec
Add ANN interface and implementations
netang Mar 16, 2023
8eeeff8
Add dependencies
netang Mar 16, 2023
4ee8e58
Format code via black code formatter. Add doctest examples.
netang Mar 23, 2023
724a149
Add docstrings to ANNMixin methods
netang Mar 24, 2023
2b25dfa
Add test for ANN models (except word2vec)
netang Mar 28, 2023
e2e1f2b
Merge remote-tracking branch 'sb-repo/main' into sb-main-ann
netang May 10, 2023
0b43d7b
Replace `functools.cached_property` with `cached_property.cached_prop…
netang May 10, 2023
23ef21d
Reformat docstring
netang May 10, 2023
7da34f2
Disable/Fix pylint warns. Add docstrings.
netang May 11, 2023
eb4329d
Fix pycodestyle warn.
netang May 11, 2023
f8302f5
Move `HnswlibIndexFileManager` and `NmslibIndexFileManager` to `repla…
netang May 16, 2023
e7caad0
Add `BaseHnswParam`. Fix sphinx warn.
netang May 16, 2023
c59c16a
Remove commented lines.
netang May 16, 2023
48d0b78
Add `DriverHnswlibIndexBuilder` and `ExecutorHnswlibIndexBuilder`
netang May 16, 2023
461cb7d
Add `DriverNmslibIndexBuilder` and `ExecutorNmslibIndexBuilder`
netang May 16, 2023
77787a1
Move `NeighbourRec` to `base_neighbour_rec.py`. Move ann mixins to an…
netang May 16, 2023
14cef4c
Add ANN to `ADMMSLIM` and `AssociationRulesItemRec` models
netang May 16, 2023
4020e56
Disable pylint R0902
netang May 16, 2023
1d3c8d3
Replace `typing.Literal` with `typing_extensions.Literal`
netang May 16, 2023
daaba5d
Fix tests
netang May 16, 2023
516bfd8
Make `HnswlibMixin` and `NmslibHnswMixin` abstract
netang May 17, 2023
3c48caf
Fix saving/loading
netang May 19, 2023
31bfd1a
Remove duplicated code
netang May 19, 2023
3a50457
Update index builders, add index stores and add index inferers.
netang May 21, 2023
8dd8fc0
Fix index saving/loading
netang May 28, 2023
2ded076
Add clean upping index files
netang May 28, 2023
f8c2959
Add tests of save/load ANN models
netang May 28, 2023
83f236a
Fix pylint warns
netang May 28, 2023
33fec4b
Fix pytest error
netang May 28, 2023
de9e622
Add new save/load tests. Fix error in test.
netang May 30, 2023
a44575a
Set poetry-core version interval
netang May 30, 2023
b9f0130
Revert "Set poetry-core version interval"
netang May 30, 2023
f7d42dc
Add `poetry-core` to dependencies
netang May 30, 2023
e763d43
Add tests
netang May 30, 2023
c6d8d80
Fix pycodestyle warn
netang May 30, 2023
281c816
Add `.coveragerc`
netang May 30, 2023
85a8875
Remove `_inner_predict_wrap` method
netang Jun 6, 2023
c0b442f
Merge remote-tracking branch 'sb-repo/main' into sb-main-ann
netang Jun 16, 2023
397663c
Add test of 'get_csr_matrix' function
netang Jun 21, 2023
535999a
Add `make_build_index_udf` builder function
netang Jun 21, 2023
7180f19
Add test of `build_index_udf`
netang Jun 21, 2023
252f3c4
Format code via black
netang Jun 21, 2023
afa6427
Rewrite index building udf to test inner functionality of udf
netang Jun 25, 2023
4173e7e
Merge remote-tracking branch 'sb-repo/main' into sb-main-ann
netang Jun 25, 2023
f061c91
Add docstring to `build_and_save_index` function
netang Jun 25, 2023
40cc806
Move some functions from `replay.utils` to `replay.ann.index_stores`
netang Jun 25, 2023
136a60d
Move `build_and_save_index` method to `NmslibIndexBuilderMixin`
netang Jun 25, 2023
af0c492
Fix pylint warns
netang Jun 26, 2023
214dacf
Added ann for get_nearest_items method for ALS and Word2Vec models
mralexdmitriy Jul 12, 2023
abe1b6f
Added the optimization ability for item2item task
mralexdmitriy Jul 12, 2023
8cb3bf9
Minor fixes for pycodestyle and pylint
mralexdmitriy Jul 12, 2023
4 changes: 3 additions & 1 deletion .coveragerc
@@ -1,2 +1,4 @@
 [run]
-omit = replay/spark_custom_models/*
+omit =
+    replay/ann/index_stores/hdfs_index_store.py
+    replay/spark_custom_models/*
1,578 changes: 810 additions & 768 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -50,6 +50,9 @@ pytorch-ranger = "^0.1.1"
 d3rlpy = "*"
 # required by d3rlpy
 gym = "0.17.2"
+nmslib = "*"
+hnswlib = "*"
+cached-property = "*"

 [tool.poetry.dev-dependencies]
 # dev only
Empty file added replay/ann/__init__.py
Empty file.
295 changes: 295 additions & 0 deletions replay/ann/ann_mixin.py
@@ -0,0 +1,295 @@
import importlib
from abc import abstractmethod
from typing import Optional, Dict, Any, Union, Iterable

from pyspark.sql import DataFrame
from pyspark.sql import functions as sf

from replay.ann.index_builders.base_index_builder import IndexBuilder
from replay.ann.index_stores.spark_files_index_store import (
    SparkFilesIndexStore,
)
from replay.models.base_rec import BaseRecommender
from replay.utils import get_unique_entities, get_top_k_recs, get_top_k, return_recs


class ANNMixin(BaseRecommender):
    """
    This class overrides the `_fit_wrap` and `_predict_wrap` methods of the base class,
    adding index construction at the `_fit_wrap` step
    and index inference at the `_predict_wrap` step.
    """

    index_builder: Optional[IndexBuilder] = None

    @property
    def _use_ann(self) -> bool:
        """
        Property that determines whether the ANN (index) is used.
        If `True`, the index is built (at the `fit` stage)
        and queried (at the `predict` stage).
        """
        return self.index_builder is not None

    @abstractmethod
    def _get_vectors_to_build_ann(self, log: DataFrame) -> DataFrame:
        """Implementations of this method must return a dataframe with item vectors.
        Item vectors from this method are used to build the index.

        Args:
            log: DataFrame with interactions

        Returns: DataFrame[item_idx int, vector array<double>] or DataFrame[vector array<double>].
            Column names in the dataframe can be anything.
        """

    @abstractmethod
    def _get_ann_build_params(self, log: DataFrame) -> Dict[str, Any]:
        """Implementations of this method must return a dictionary
        with arguments for the `_build_ann_index` method.

        Args:
            log: DataFrame with interactions

        Returns: Dictionary with arguments to build the index. For example: {
            "id_col": "item_idx",
            "features_col": "item_factors",
            ...
        }
        """

    def _fit_wrap(
        self,
        log: DataFrame,
        user_features: Optional[DataFrame] = None,
        item_features: Optional[DataFrame] = None,
    ) -> None:
        """Wrapper that extends `_fit_wrap` of the base class,
        adding construction of the ANN index when `_use_ann` is set.

        Args:
            log: historical log of interactions
                ``[user_idx, item_idx, timestamp, relevance]``
            user_features: user features
                ``[user_idx, timestamp]`` + feature columns
            item_features: item features
                ``[item_idx, timestamp]`` + feature columns
        """
        super()._fit_wrap(log, user_features, item_features)

        if self._use_ann:
            vectors = self._get_vectors_to_build_ann(log)
            ann_params = self._get_ann_build_params(log)
            self.index_builder.build_index(vectors, **ann_params)

    @abstractmethod
    def _get_vectors_to_infer_ann_inner(
        self, log: DataFrame, users: DataFrame
    ) -> DataFrame:
        """Implementations of this method must return a dataframe with user vectors.
        User vectors from this method are used to infer the index.

        Args:
            log: DataFrame with interactions
            users: DataFrame with users

        Returns: DataFrame[user_idx int, vector array<double>] or DataFrame[vector array<double>].
            The vector column name in the dataframe can be anything.
        """

    def _get_vectors_to_infer_ann(
        self, log: DataFrame, users: DataFrame, filter_seen_items: bool
    ) -> DataFrame:
        """This method wraps `_get_vectors_to_infer_ann_inner` and,
        if `filter_seen_items` is set, adds seen items to the dataframe with user vectors.

        Args:
            log: DataFrame with interactions
            users: DataFrame with users
            filter_seen_items: flag to remove seen items from recommendations based on ``log``.

        Returns: DataFrame with user vectors
            (plus ``seen_item_idxs`` if ``filter_seen_items`` is set).
        """
        users = self._get_vectors_to_infer_ann_inner(log, users)

        # here we add `seen_item_idxs` to filter the viewed items in UDFs (see infer_index_udf)
        if filter_seen_items:
            user_to_max_items = log.groupBy("user_idx").agg(
                sf.count("item_idx").alias("num_items"),
                sf.collect_set("item_idx").alias("seen_item_idxs"),
            )
            users = users.join(user_to_max_items, on="user_idx")

        return users

    @abstractmethod
    def _get_ann_infer_params(self) -> Dict[str, Any]:
        """Implementations of this method must return a dictionary
        with arguments for the `_infer_ann_index` method.

        Returns: Dictionary with arguments to infer the index. For example: {
            "features_col": "user_vector",
            ...
        }
        """

    # pylint: disable=too-many-arguments, too-many-locals
    def _predict_wrap(
        self,
        log: Optional[DataFrame],
        k: int,
        users: Optional[Union[DataFrame, Iterable]] = None,
        items: Optional[Union[DataFrame, Iterable]] = None,
        user_features: Optional[DataFrame] = None,
        item_features: Optional[DataFrame] = None,
        filter_seen_items: bool = True,
        recs_file_path: Optional[str] = None,
    ) -> Optional[DataFrame]:
        self.logger.debug("Starting predict %s", type(self).__name__)
        user_data = users or log or user_features or self.fit_users
        users = get_unique_entities(user_data, "user_idx")
        users, log = self._filter_cold_for_predict(users, log, "user")

        item_data = items or self.fit_items
        items = get_unique_entities(item_data, "item_idx")
        items, log = self._filter_cold_for_predict(items, log, "item")
        num_items = items.count()
        if num_items < k:
            message = f"k = {k} > number of items = {num_items}"
            self.logger.debug(message)

        if self._use_ann:
            vectors = self._get_vectors_to_infer_ann(
                log, users, filter_seen_items
            )
            ann_params = self._get_ann_infer_params()
            inferer = self.index_builder.produce_inferer(filter_seen_items)
            recs = inferer.infer(vectors, ann_params["features_col"], k)
        else:
            recs = self._predict(
                log,
                k,
                users,
                items,
                user_features,
                item_features,
                filter_seen_items,
            )

        if not self._use_ann:
            if filter_seen_items and log:
                recs = self._filter_seen(recs=recs, log=log, users=users, k=k)

            recs = get_top_k_recs(recs, k=k).select(
                "user_idx", "item_idx", "relevance"
            )

        output = return_recs(recs, recs_file_path)
        self._clear_model_temp_view("filter_seen_users_log")
        self._clear_model_temp_view("filter_seen_num_seen")
        return output

    def _save_index(self, path):
        self.index_builder.index_store.dump_index(path)

    def _load_index(self, path: str):
        self.index_builder.index_store = SparkFilesIndexStore()
        self.index_builder.index_store.load_from_path(path)

    def init_builder_from_dict(self, init_meta: dict):
        """Inits an index builder instance from a dict with init meta."""

        # index param entity instance initialization
        module = importlib.import_module(init_meta["index_param"]["module"])
        class_ = getattr(module, init_meta["index_param"]["class"])
        index_params = class_(**init_meta["index_param"]["init_args"])

        # index builder instance initialization
        module = importlib.import_module(init_meta["builder"]["module"])
        class_ = getattr(module, init_meta["builder"]["class"])
        index_builder = class_(index_params=index_params, index_store=None)

        self.index_builder = index_builder


class ANNItem2itemMixin(ANNMixin):
    """
    This class overrides the `_get_nearest_items_wrap` method of the base class,
    adding index inference at the `_get_nearest_items_wrap` step.
    """

    @abstractmethod
    def _get_item_vectors_to_infer_ann(self, items: DataFrame) -> DataFrame:
        """
        Implementations of this method must return a dataframe with item vectors.
        Item vectors from this method are used to infer the index.

        Args:
            items: DataFrame with items

        Returns: DataFrame[item_idx int, vector array<double>] or DataFrame[vector array<double>].
            The vector column name in the dataframe can be anything.
        """

    @abstractmethod
    def _get_ann_nearest_items_infer_params(self) -> Dict[str, Any]:
        """
        Implementations of this method must return a dictionary
        with arguments for the `inferer.infer` method.

        Returns: Dictionary with arguments to infer the index. For example: {
            "features_col": "item_vector",
            ...
        }
        """

    def _get_nearest_items_wrap(
        self,
        items: Union[DataFrame, Iterable],
        k: int,
        metric: Optional[str] = "cosine_similarity",
        candidates: Optional[Union[DataFrame, Iterable]] = None,
    ) -> Optional[DataFrame]:
        items = get_unique_entities(items, "item_idx")
        if candidates is not None:
            candidates = get_unique_entities(candidates, "item_idx")

        if self._use_ann:
            item_vectors = self._get_item_vectors_to_infer_ann(items)
            ann_params = self._get_ann_nearest_items_infer_params()
            inferer = self.index_builder.produce_inferer(filter_seen_items=False)
            nearest_items = inferer.infer(item_vectors, ann_params["features_col"], k)
        else:
            nearest_items_to_filter = self._get_nearest_items(
                items=items,
                metric=metric,
                candidates=candidates,
            )

            rel_col_name = metric if metric is not None else "similarity"
            nearest_items = get_top_k(
                dataframe=nearest_items_to_filter,
                partition_by_col=sf.col("item_idx_one"),
                order_by_col=[
                    sf.col(rel_col_name).desc(),
                    sf.col("item_idx_two").desc(),
                ],
                k=k,
            )

        nearest_items = nearest_items.withColumnRenamed(
            "item_idx_two", "neighbour_item_idx"
        )
        nearest_items = nearest_items.withColumnRenamed(
            "item_idx_one", "item_idx"
        )
        return nearest_items
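
As a rough illustration of how these hooks are meant to be used, the sketch below (not part of the diff) wires an index builder into a recommender so that `_use_ann` becomes true; the builder import path, the `ALSWrap` model, and the parameter values are assumptions based on the class names in the commit history, and `log` stands for an already prepared interactions DataFrame.

# Sketch only: module paths for the builders are assumptions, not part of this PR.
from replay.ann.entities.base_hnsw_param import BaseHnswParam
from replay.ann.index_builders.executor_hnswlib_index_builder import (
    ExecutorHnswlibIndexBuilder,  # hypothetical path; class name appears in commit 48d0b78
)
from replay.ann.index_stores.spark_files_index_store import SparkFilesIndexStore
from replay.models import ALSWrap  # assumed: any model that mixes in ANNMixin

# Builders take `index_params` and `index_store`,
# mirroring the keywords used in `init_builder_from_dict` above.
index_builder = ExecutorHnswlibIndexBuilder(
    index_params=BaseHnswParam(space="ip", m=100, ef_c=2000, ef_s=2000),
    index_store=SparkFilesIndexStore(),
)

model = ALSWrap(rank=100)
model.index_builder = index_builder   # makes ANNMixin._use_ann return True

model.fit(log)                        # _fit_wrap additionally builds the ANN index
recs = model.predict(log, k=10)       # _predict_wrap queries the index instead of _predict
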
Empty file added replay/ann/entities/__init__.py
Empty file.
32 changes: 32 additions & 0 deletions replay/ann/entities/base_hnsw_param.py
@@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
class BaseHnswParam:
    """
    Base hnsw params.
    """

    space: str
    m: int = 200  # pylint: disable=invalid-name
    ef_c: int = 20000
    post: int = 0
    ef_s: Optional[int] = None

    def init_meta_as_dict(self) -> dict:
        """
        Returns meta-information for class instance initialization. Used to save the entity to disk.

        :return: dictionary with init meta.
        """
        return {
            "module": type(self).__module__,
            "class": type(self).__name__,
            "init_args": {
                "space": self.space,
                "m": self.m,
                "ef_c": self.ef_c,
                "post": self.post,
                "ef_s": self.ef_s,
            },
        }
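
For reference, a short round-trip of `init_meta_as_dict` (a sketch, not part of the diff) shows the dictionary shape that `ANNMixin.init_builder_from_dict` consumes; the parameter values are illustrative only.

import importlib

from replay.ann.entities.base_hnsw_param import BaseHnswParam

param = BaseHnswParam(space="cosinesimil", m=16, ef_c=200)
meta = param.init_meta_as_dict()
# meta == {
#     "module": "replay.ann.entities.base_hnsw_param",
#     "class": "BaseHnswParam",
#     "init_args": {"space": "cosinesimil", "m": 16, "ef_c": 200, "post": 0, "ef_s": None},
# }

# Re-create the dataclass from the stored meta, the same way init_builder_from_dict does:
module = importlib.import_module(meta["module"])
restored = getattr(module, meta["class"])(**meta["init_args"])
assert restored == param  # dataclasses compare field by field
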