Skip to content

Commit

Permalink
fix updates
Browse files Browse the repository at this point in the history
  • Loading branch information
DuongTyler committed Oct 28, 2023
1 parent 94c58d9 commit aaa6a56
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
2 changes: 2 additions & 0 deletions python/starpoint/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def update(

def column_update(
self,
ids: List[str],
embeddings: List[Dict[str, List[float] | int]],
document_metadatas: List[Dict[Any, Any]],
collection_id: Optional[str] = None,
Expand Down Expand Up @@ -253,6 +254,7 @@ def column_update(
requests.exceptions.SSLError: Failure likely due to network issues.
"""
return self.writer.column_update(
ids=ids,
embeddings=embeddings,
document_metadatas=document_metadatas,
collection_id=collection_id,
Expand Down
6 changes: 4 additions & 2 deletions python/starpoint/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def update(

def column_update(
self,
ids: List[str],
embeddings: List[Dict[str, List[float] | int]],
document_metadatas: List[Dict[Any, Any]],
collection_id: Optional[str] = None,
Expand All @@ -307,15 +308,16 @@ def column_update(
ValueError: If both collection id and collection name are provided.
requests.exceptions.SSLError: Failure likely due to network issues.
"""
if len(embeddings) != len(document_metadatas):
if len(embeddings) != len(document_metadatas) or len(embeddings) != len(ids):
LOGGER.warning(EMBEDDING_METADATA_LENGTH_MISMATCH_WARNING)

documents = [
{
"id": id,
"embeddings": embedding,
"metadata": document_metadata,
}
for embedding, document_metadata in zip(embeddings, document_metadatas)
for embedding, document_metadata, id in zip(embeddings, document_metadatas, ids)
]

return self.update(
Expand Down
11 changes: 8 additions & 3 deletions python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def test_writer_update_SSLError(

@patch("starpoint.writer.Writer.update")
def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer):
ids = ["a", "b"]
test_embeddings = [
{"values": [0.88], "dimensionality": 1},
{"values": [0.71], "dimensionality": 1}
Expand All @@ -365,7 +366,7 @@ def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer
]

mock_writer.column_update(
embeddings=test_embeddings, document_metadatas=test_document_metadatas
ids=ids, embeddings=test_embeddings, document_metadatas=test_document_metadatas
)

update_mock.assert_called_once_with(
Expand All @@ -379,6 +380,7 @@ def test_writer_column_update(update_mock: MagicMock, mock_writer: writer.Writer
def test_writer_column_update_collection_id_collection_name_passed_through(
update_mock: MagicMock, mock_writer: writer.Writer
):
ids = ["a", "b"]
test_embeddings = [{"values": [0.88], "dimensionality": 1}]
test_document_metadatas = [{"mock": "metadata"}]
expected_update_document = [
Expand All @@ -391,6 +393,7 @@ def test_writer_column_update_collection_id_collection_name_passed_through(
expected_collection_name = "mock_name"

mock_writer.column_update(
ids=ids,
embeddings=test_embeddings,
document_metadatas=test_document_metadatas,
collection_id=expected_collection_id,
Expand All @@ -408,6 +411,7 @@ def test_writer_column_update_collection_id_collection_name_passed_through(
def test_writer_column_insert_shorter_metadatas_length(
update_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch
):
ids = ["a", "b"]
test_embeddings = [
{"values": [0.88], "dimensionality": 1},
{"values": [0.71], "dimensionality": 1}
Expand All @@ -424,7 +428,7 @@ def test_writer_column_insert_shorter_metadatas_length(
monkeypatch.setattr(writer, "LOGGER", logger_mock)

mock_writer.column_update(
embeddings=test_embeddings, document_metadatas=test_document_metadatas
ids=ids, embeddings=test_embeddings, document_metadatas=test_document_metadatas
)

logger_mock.warning.assert_called_once_with(
Expand All @@ -441,6 +445,7 @@ def test_writer_column_insert_shorter_metadatas_length(
def test_writer_column_update_shorter_embeddings_length(
update_mock: MagicMock, mock_writer: writer.Writer, monkeypatch: MonkeyPatch
):
ids = ["a", "b"]
test_embeddings = [
{"values": [0.88], "dimensionality": 1},
]
Expand All @@ -456,7 +461,7 @@ def test_writer_column_update_shorter_embeddings_length(
monkeypatch.setattr(writer, "LOGGER", logger_mock)

mock_writer.column_update(
embeddings=test_embeddings, document_metadatas=test_document_metadatas
ids=ids, embeddings=test_embeddings, document_metadatas=test_document_metadatas
)

logger_mock.warning.assert_called_once_with(
Expand Down

0 comments on commit aaa6a56

Please sign in to comment.