Skip to content

Commit

Permalink
remove embedding object and use dicts instead (#46)
Browse files Browse the repository at this point in the history
* remove embedding object and use dicts instead

* pin to 3.11 for now

* fix tests

* fix updates

* fix tests

---------

Co-authored-by: scottwey <[email protected]>
  • Loading branch information
DuongTyler and scottwey committed Oct 28, 2023
1 parent 57eb112 commit e068ad7
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 46 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.x'
cache: 'pip'
python-version: "3.11"
cache: "pip"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
12 changes: 11 additions & 1 deletion python/starpoint/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Optional
from typing import Dict, Optional, List
from uuid import UUID

import requests
Expand Down Expand Up @@ -102,3 +102,13 @@ def _check_collection_identifier_collision(
raise ValueError(NO_COLLECTION_VALUE_ERROR)
elif collection_id and collection_name:
raise ValueError(MULTI_COLLECTION_VALUE_ERROR)


def _ensure_embedding_dict(embeddings: List[float] | Dict[str, List[float] | int] | None):
if isinstance(embeddings, list):
dict_embeddings = {
"values": embeddings,
"dimensionality": len(embeddings)
}
return dict_embeddings
return embeddings
16 changes: 6 additions & 10 deletions python/starpoint/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import validators

from starpoint import reader, writer, _utils
from starpoint.embedding import Embedding

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,7 +87,7 @@ def insert(

def column_insert(
self,
embeddings: List[Embedding],
embeddings: List[Dict[str, List[float] | int]],
document_metadatas: List[Dict[Any, Any]],
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
Expand Down Expand Up @@ -127,7 +126,7 @@ def query(
sql: Optional[str] = None,
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
query_embedding: Optional[List[float] | Embedding] = None,
query_embedding: Optional[List[float] | Dict[str, List[float] | int]] = None,
params: Optional[List[Any]] = None,
text_search_query: Optional[List[str]] = None,
text_search_weight: Optional[float] = None,
Expand All @@ -143,6 +142,7 @@ def query(
collection_name: The collection's name where the query will happen.
This or the `collection_id` needs to be provided.
query_embedding: An embedding to query against the collection using similarity search.
This is of the shape {"values": List[float], "dimensionality": int}
params: values for parameterized sql
Returns:
Expand All @@ -154,12 +154,6 @@ def query(
requests.exceptions.SSLError: Failure likely due to network issues.
"""

# check if query embedding is a float, if it is, convert to a embedding object
if isinstance(query_embedding, list):
query_embedding = Embedding(
vectors=query_embedding,
dim=len(query_embedding))

return self.reader.query(
sql=sql,
collection_id=collection_id,
Expand Down Expand Up @@ -231,7 +225,8 @@ def update(

def column_update(
self,
embeddings: List[Embedding],
ids: List[str],
embeddings: List[Dict[str, List[float] | int]],
document_metadatas: List[Dict[Any, Any]],
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
Expand Down Expand Up @@ -259,6 +254,7 @@ def column_update(
requests.exceptions.SSLError: Failure likely due to network issues.
"""
return self.writer.column_update(
ids=ids,
embeddings=embeddings,
document_metadatas=document_metadatas,
collection_id=collection_id,
Expand Down
9 changes: 0 additions & 9 deletions python/starpoint/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,6 @@
)


class Embedding(object):
values: List[float]
dimensionality: int

def __init__(self, values: List[float], dimensionality: Optional[int] = None):
self.values = values
self.dimensionality = len(values) if dimensionality is None else dimensionality


class EmbeddingModel(Enum):
MINILM = "MINI_LM"

Expand Down
9 changes: 6 additions & 3 deletions python/starpoint/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
_build_header,
_check_collection_identifier_collision,
_validate_host,
_ensure_embedding_dict
)

from starpoint.embedding import Embedding

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,7 +49,7 @@ def query(
sql: Optional[str] = None,
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
query_embeddings: Optional[Embedding] = None,
query_embeddings: Optional[Dict[str, int | List[float]] | List[float]] = None,
params: Optional[List[Any]] = None,
text_search_query: Optional[List[str]] = None,
text_search_weight: Optional[float] = None,
Expand Down Expand Up @@ -89,10 +89,13 @@ def query(
)
"""

# check if type of query embeddings is list of float, if so convert to a dict
query_embeddings = _ensure_embedding_dict(query_embeddings)

request_data = dict(
collection_id=collection_id,
collection_name=collection_name,
query_embeddings=query_embeddings,
query_embedding=query_embeddings,
sql=sql,
params=params,
text_search_query=text_search_query,
Expand Down
11 changes: 6 additions & 5 deletions python/starpoint/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
_validate_host,
)

from starpoint.embedding import Embedding

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -171,7 +170,7 @@ def insert(

def column_insert(
self,
embeddings: List[Embedding],
embeddings: List[Dict[str, List[float] | int]],
document_metadatas: List[Dict[Any, Any]],
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
Expand Down Expand Up @@ -282,7 +281,8 @@ def update(

def column_update(
self,
embeddings: List[Embedding],
ids: List[str],
embeddings: List[Dict[str, List[float] | int]],
document_metadatas: List[Dict[Any, Any]],
collection_id: Optional[str] = None,
collection_name: Optional[str] = None,
Expand All @@ -308,15 +308,16 @@ def column_update(
ValueError: If both collection id and collection name are provided.
requests.exceptions.SSLError: Failure likely due to network issues.
"""
if len(embeddings) != len(document_metadatas):
if len(embeddings) != len(document_metadatas) or len(embeddings) != len(ids):
LOGGER.warning(EMBEDDING_METADATA_LENGTH_MISMATCH_WARNING)

documents = [
{
"id": id,
"embeddings": embedding,
"metadata": document_metadata,
}
for embedding, document_metadata in zip(embeddings, document_metadatas)
for embedding, document_metadata, id in zip(embeddings, document_metadatas, ids)
]

return self.update(
Expand Down
6 changes: 2 additions & 4 deletions python/tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from tempfile import NamedTemporaryFile
from uuid import uuid4
from unittest.mock import MagicMock, patch
from starpoint.embedding import Embedding

import pytest
from _pytest.monkeypatch import MonkeyPatch
Expand Down Expand Up @@ -47,7 +46,7 @@ def test_client_insert(mock_writer: MagicMock, mock_reader: MagicMock):
def test_client_column_insert(mock_writer: MagicMock, mock_reader: MagicMock):
client = db.Client(api_key=uuid4())

client.column_insert(embeddings=[Embedding([1.1])], document_metadatas=[{"mock": "value"}])
client.column_insert(embeddings=[{"values": [1.1], "dimensionality": 1}], document_metadatas=[{"mock": "value"}])

mock_reader.assert_called_once() # Only called during init
mock_writer().column_insert.assert_called_once()
Expand Down Expand Up @@ -91,7 +90,6 @@ def test_client_update(mock_writer: MagicMock, mock_reader: MagicMock):
def test_client_column_update(mock_writer: MagicMock, mock_reader: MagicMock):
client = db.Client(api_key=uuid4())

client.column_update(embeddings=[Embedding([1.1])], document_metadatas=[{"mock": "value"}])

client.column_update(ids=["a"], embeddings=[{"values": [1.1], "dimensionality": 1}], document_metadatas=[{"mock": "value"}])
mock_reader.assert_called_once() # Only called during init
mock_writer().column_update.assert_called_once()
52 changes: 40 additions & 12 deletions python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from requests.exceptions import SSLError

from starpoint import writer
from starpoint.embedding import Embedding


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -161,7 +160,10 @@ def test_writer_insert_SSLError(

@patch("starpoint.writer.Writer.insert")
def test_writer_column_insert(insert_mock: MagicMock, mock_writer: writer.Writer):
test_embeddings = [Embedding([0.88]), Embedding([0.71])]
test_embeddings = [
{"values": [0.88], "dimensionality": 1},
{"values": [0.71], "dimensionality": 1}
]
test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}]
expected_insert_document = [
{
Expand Down Expand Up @@ -189,7 +191,10 @@ def test_writer_column_insert(insert_mock: MagicMock, mock_writer: writer.Writer
def test_writer_column_insert_collection_id_collection_name_passed_through(
insert_mock: MagicMock, mock_writer: writer.Writer
):
test_embeddings = [Embedding([0.88])]
test_embeddings = [
{"values": [0.88], "dimensionality": 1},
]

test_document_metadatas = [{"mock": "metadata"}]
expected_insert_document = [
{
Expand Down Expand Up @@ -218,7 +223,10 @@ def test_writer_column_insert_collection_id_collection_name_passed_through(
def test_writer_column_insert_shorter_metadatas_length(
insert_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch
):
test_embeddings = [Embedding([0.88]), Embedding([0.71])]
test_embeddings = [
{"values": [0.88], "dimensionality": 1},
{"values": [0.71], "dimensionality": 1}
]
test_document_metadatas = [{"mock": "metadata"}]
expected_insert_document = [
{
Expand Down Expand Up @@ -248,7 +256,9 @@ def test_writer_column_insert_shorter_metadatas_length(
def test_writer_column_insert_shorter_embeddings_length(
insert_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch
):
test_embeddings = [Embedding([0.88])]
test_embeddings = [
{"values": [0.88], "dimensionality": 1},
]
test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}]
expected_insert_document = [
{
Expand Down Expand Up @@ -338,21 +348,27 @@ def test_writer_update_SSLError(

@patch("starpoint.writer.Writer.update")
def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer):
test_embeddings = [Embedding([0.88]), Embedding([0.71])]
ids = ["a", "b"]
test_embeddings = [
{"values": [0.88], "dimensionality": 1},
{"values": [0.71], "dimensionality": 1}
]
test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}]
expected_update_document = [
{
"id": "a",
"embeddings": test_embeddings[0],
"metadata": test_document_metadatas[0],
},
{
"id": "b",
"embeddings": test_embeddings[1],
"metadata": test_document_metadatas[1],
},
]

mock_writer.column_update(
embeddings=test_embeddings, document_metadatas=test_document_metadatas
ids=ids, embeddings=test_embeddings, document_metadatas=test_document_metadatas
)

update_mock.assert_called_once_with(
Expand All @@ -366,10 +382,12 @@ def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer
def test_writer_column_update_collection_id_collection_name_passed_through(
update_mock: MagicMock, mock_writer: writer.Writer
):
test_embeddings = [Embedding([0.88])]
ids = ["a"]
test_embeddings = [{"values": [0.88], "dimensionality": 1}]
test_document_metadatas = [{"mock": "metadata"}]
expected_update_document = [
{
"id": "a",
"embeddings": test_embeddings[0],
"metadata": test_document_metadatas[0],
},
Expand All @@ -378,6 +396,7 @@ def test_writer_column_update_collection_id_collection_name_passed_through(
expected_collection_name = "mock_name"

mock_writer.column_update(
ids=ids,
embeddings=test_embeddings,
document_metadatas=test_document_metadatas,
collection_id=expected_collection_id,
Expand All @@ -395,10 +414,15 @@ def test_writer_column_update_collection_id_collection_name_passed_through(
def test_writer_column_insert_shorter_metadatas_length(
update_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch
):
test_embeddings = [Embedding([0.88]), Embedding([0.71])]
ids = ["a", "b"]
test_embeddings = [
{"values": [0.88], "dimensionality": 1},
{"values": [0.71], "dimensionality": 1}
]
test_document_metadatas = [{"mock": "metadata"}]
expected_update_document = [
{
"id": "a",
"embeddings": test_embeddings[0],
"metadata": test_document_metadatas[0],
},
Expand All @@ -408,7 +432,7 @@ def test_writer_column_insert_shorter_metadatas_length(
monkeypatch.setattr(writer, "LOGGER", logger_mock)

mock_writer.column_update(
embeddings=test_embeddings, document_metadatas=test_document_metadatas
ids=ids, embeddings=test_embeddings, document_metadatas=test_document_metadatas
)

logger_mock.warning.assert_called_once_with(
Expand All @@ -425,10 +449,14 @@ def test_writer_column_insert_shorter_metadatas_length(
def test_writer_column_update_shorter_embeddings_length(
update_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch
):
test_embeddings = [Embedding([0.88])]
ids = ["a", "b"]
test_embeddings = [
{"values": [0.88], "dimensionality": 1},
]
test_document_metadatas = [{"mock": "metadata"}, {"mock2": "metadata2"}]
expected_update_document = [
{
"id": "a",
"embeddings": test_embeddings[0],
"metadata": test_document_metadatas[0],
},
Expand All @@ -438,7 +466,7 @@ def test_writer_column_update_shorter_embeddings_length(
monkeypatch.setattr(writer, "LOGGER", logger_mock)

mock_writer.column_update(
embeddings=test_embeddings, document_metadatas=test_document_metadatas
ids=ids, embeddings=test_embeddings, document_metadatas=test_document_metadatas
)

logger_mock.warning.assert_called_once_with(
Expand Down

0 comments on commit e068ad7

Please sign in to comment.