From 58286f6f1c6f6532f3e68d667d4b4ded2e038423 Mon Sep 17 00:00:00 2001 From: Andras Schmelczer Date: Sun, 7 Aug 2022 14:12:43 +0200 Subject: [PATCH] Improve TinyDB performance for *_ground_truth --- .../persistence/parallel_tinydb_driver.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/great_ai/persistence/parallel_tinydb_driver.py b/great_ai/persistence/parallel_tinydb_driver.py index 92b5fc8..e648477 100644 --- a/great_ai/persistence/parallel_tinydb_driver.py +++ b/great_ai/persistence/parallel_tinydb_driver.py @@ -63,15 +63,11 @@ def does_match(d: Dict[str, Any]) -> bool: ) ) - documents: List[Trace] = [ - Trace.parse_obj(t) - for t in self._safe_execute(lambda db: db.search(does_match)) - ] - + documents = self._safe_execute(lambda db: db.search(does_match)) if not documents: return [], 0 - df = pd.DataFrame([d.to_flat_dict() for d in documents]) + df = pd.DataFrame(documents) for f in conjunctive_filters: operator = f.operator.lower() @@ -91,10 +87,7 @@ def does_match(d: Dict[str, Any]) -> bool: count = len(df) result = df.iloc[skip:] if take is None else df.iloc[skip : skip + take] - return [ - next(d for d in documents if d.trace_id == trace_id) - for trace_id in result["trace_id"] - ], count + return [Trace.parse_obj(trace) for _, trace in result.iterrows()], count def update(self, id: str, new_version: Trace) -> None: self._safe_execute( @@ -105,8 +98,10 @@ def delete(self, id: str) -> None: self._safe_execute(lambda db: db.remove(lambda d: d["trace_id"] == id)) def delete_batch(self, ids: List[str]) -> None: - for i in ids: - self.delete(i) + with lock: + with TinyDB(self.path_to_db) as db: + for id in ids: + db.remove(lambda d: d["trace_id"] == id) def _safe_execute(self, func: Callable[[TinyDB], Any]) -> Any: with lock: