feat: add generator eval
datvodinh committed May 25, 2024
1 parent 85a377d commit e5a9027
Showing 9 changed files with 131 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -14,7 +14,7 @@ jobs:
pkg-manager: poetry
- run:
name: Run tests
command: poetry run pytest rag_chatbot/tests/test.py --junitxml=junit.xml || ((($? == 5)) && echo 'Did not find any tests to run.')
command: poetry run pytest rag_chatbot/test/test.py --junitxml=junit.xml || ((($? == 5)) && echo 'Did not find any tests to run.')
- store_test_results:
path: junit.xml
deploy:
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
@@ -36,4 +36,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest rag_chatbot/tests/test.py
pytest rag_chatbot/test/test.py
7 changes: 5 additions & 2 deletions .gitignore
@@ -163,5 +163,8 @@ data
.vscode
mlx_model
.idea
test
junit.xml
/test/
junit.xml
harry_potter_dataset
storage
result
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ llama-index-callbacks-wandb = "^0.1.2"
llama-index-retrievers-bm25 = "^0.1.3"
pytest = "^8.2.0"
pymupdf = "^1.24.3"
tqdm = "^4.66.4"



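The tqdm addition above provides the progress bars used by the new asynchronous evaluation code in rag_chatbot/eval/__main__.py below. As a point of reference only (this snippet is not part of the commit, and fake_query is a made-up stand-in for an awaitable such as query_engine.aquery), the tqdm.asyncio pattern works like this:

import asyncio

from tqdm.asyncio import tqdm_asyncio


async def fake_query(i: int) -> str:
    # Stand-in for an awaitable call such as query_engine.aquery(q).
    await asyncio.sleep(0.1)
    return f"answer {i}"


async def main() -> None:
    # tqdm_asyncio.gather behaves like asyncio.gather but draws a progress
    # bar labelled with `desc` while the coroutines run.
    tasks = [fake_query(i) for i in range(10)]
    results = await tqdm_asyncio.gather(*tasks, desc="querying")
    print(results)


asyncio.run(main())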
164 changes: 122 additions & 42 deletions rag_chatbot/eval/__main__.py
@@ -1,19 +1,18 @@
import os
import asyncio
import json
import argparse
import pandas as pd
from dotenv import load_dotenv
from tqdm.asyncio import tqdm_asyncio
from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.evaluation import (
RetrieverEvaluator,
CorrectnessEvaluator,
FaithfulnessEvaluator,
AnswerRelevancyEvaluator,
ContextRelevancyEvaluator,
)
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset
from llama_index.core.storage.docstore import DocumentStore
@@ -35,6 +34,10 @@ def __init__(
docstore_path: str = "val_dataset/docstore.json",
) -> None:
self._setting = RAGSettings()
if llm not in ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4-turbo"]:
print("Pulling LLM model")
LocalRAGModel.pull(host=host, model_name=llm)
print("Pulling complete")
self._llm = LocalRAGModel.set(model_name=llm, host=host)
self._teacher = LocalRAGModel.set(model_name=teacher, host=host)
self._engine = LocalChatEngine(host=host)
@@ -44,42 +47,29 @@ def __init__(
# dataset
docstore = DocumentStore.from_persist_path(docstore_path)
nodes = list(docstore.docs.values())
index = VectorStoreIndex(nodes=nodes)
self._index = VectorStoreIndex(nodes=nodes)
self._dataset = EmbeddingQAFinetuneDataset.from_json(dataset_path)
self._top_k = self._setting.retriever.similarity_top_k
self._top_k_rerank = self._setting.retriever.top_k_rerank

self._retriever = {
"base": VectorIndexRetriever(
index=index, similarity_top_k=self._setting.retriever.top_k_rerank, verbose=True
index=self._index, similarity_top_k=self._top_k_rerank, verbose=True
),
"bm25": BM25Retriever.from_defaults(
index=index, similarity_top_k=self._setting.retriever.top_k_rerank, verbose=True
index=self._index, similarity_top_k=self._top_k_rerank, verbose=True
),
"base_rerank": VectorIndexRetriever(
index=index, similarity_top_k=self._setting.retriever.similarity_top_k, verbose=True
index=self._index, similarity_top_k=self._top_k, verbose=True
),
"bm25_rerank": BM25Retriever.from_defaults(
index=index, similarity_top_k=self._setting.retriever.similarity_top_k, verbose=True
index=self._index, similarity_top_k=self._top_k, verbose=True
),
"router": LocalRetriever(host=host).get_retrievers(
llm=self._llm, nodes=nodes
),
}

self._query_engine = {
"base": RetrieverQueryEngine.from_args(
retriever=self._retriever["base"],
llm=self._llm,
),
"bm25": RetrieverQueryEngine.from_args(
retriever=self._retriever["bm25"],
llm=self._llm,
),
"router": RetrieverQueryEngine.from_args(
retriever=self._retriever["router"],
llm=self._llm,
),
}

self._retriever_evaluator = {
"base": RetrieverEvaluator.from_metric_names(
["mrr", "hit_rate"], retriever=self._retriever["base"]
@@ -91,7 +81,7 @@ def __init__(
["mrr", "hit_rate"], retriever=self._retriever["base_rerank"],
node_postprocessors=[
SentenceTransformerRerank(
top_n=self._setting.retriever.top_k_rerank,
top_n=self._top_k_rerank,
model=self._setting.retriever.rerank_llm,
)
],
@@ -100,7 +90,7 @@ def __init__(
["mrr", "hit_rate"], retriever=self._retriever["bm25_rerank"],
node_postprocessors=[
SentenceTransformerRerank(
top_n=self._setting.retriever.top_k_rerank,
top_n=self._top_k_rerank,
model=self._setting.retriever.rerank_llm,
)
],
@@ -110,23 +100,92 @@ def __init__(
),
}

self._faithfulness_evaluator = FaithfulnessEvaluator(
llm=self._teacher,
)
self._generator_evaluator = {
"faithfulness": FaithfulnessEvaluator(
llm=self._teacher,
),
"answer_relevancy": AnswerRelevancyEvaluator(
llm=self._teacher,
),
"context_relevancy": ContextRelevancyEvaluator(
llm=self._teacher
)
}

async def eval_retriever(self):
result = {}
for retriever_name in self._retriever_evaluator.keys():
print(f"Running {retriever_name} retriever")
result[retriever_name] = self.display_results(
result[retriever_name] = self._process_retriever_result(
retriever_name,
await self._retriever_evaluator[retriever_name].aevaluate_dataset(
self._dataset, show_progress=True
)
)
return result

def display_results(self, name, eval_results):
async def _query_with_delay(self, query_engine, q, delay):
await asyncio.sleep(delay)
return await query_engine.aquery(q)

async def eval_generator(self):
queries = list(self._dataset.queries.values())[:3]
context = list(self._dataset.corpus.values())[:3]
query_engine = self._index.as_query_engine(
llm=self._llm,
)
response = []
for i in range(0, len(queries), 10):
print(f"Running queries {i} to {i+10}")
task = [query_engine.aquery(q) for q in queries[i:i + 10]]
response += await tqdm_asyncio.gather(*task, desc="querying")
await asyncio.sleep(5)

response = [str(r) for r in response]

faithful_task = []
answer_relevancy_task = []
context_relevancy_task = []
for q, r, c in zip(queries, response, context):
faithful_task.append(
self._generator_evaluator["faithfulness"].aevaluate(
response=r, contexts=[c]
)
)
answer_relevancy_task.append(
self._generator_evaluator["answer_relevancy"].aevaluate(
query=q, response=r
)
)
context_relevancy_task.append(
self._generator_evaluator["context_relevancy"].aevaluate(
query=q, contexts=[c]
)
)

faithful_result = await tqdm_asyncio.gather(
*faithful_task, desc="faithfulness"
)
answer_relevancy_result = await tqdm_asyncio.gather(
*answer_relevancy_task, desc="answer_relevancy"
)
context_relevancy_result = await tqdm_asyncio.gather(
*context_relevancy_task, desc="context_relevancy"
)

return {
"faithfulness": self._process_generator_result(
"faithfulness", faithful_result
),
"answer_relevancy": self._process_generator_result(
"answer_relevancy", answer_relevancy_result
),
"context_relevancy": self._process_generator_result(
"context_relevancy", context_relevancy_result
),
}

def _process_retriever_result(self, name, eval_results):
"""Display results from evaluate."""

metric_dicts = []
@@ -138,14 +197,27 @@ def display_results(self, name, eval_results):

hit_rate = full_df["hit_rate"].mean()
mrr = full_df["mrr"].mean()
metrics = {"retrievers": [name], "hit_rate": [hit_rate], "mrr": [mrr]}
metrics = {"retrievers": name, "hit_rate": hit_rate, "mrr": mrr}

return metrics

def _process_generator_result(self, name, eval_results):
result = []
for r in eval_results:
result.append(json.loads(r.json()))
return {"generator": name, "result": result}


if __name__ == "__main__":
# OLLAMA SERVER
parser = argparse.ArgumentParser()
parser.add_argument(
"--type",
type=str,
default="retriever",
choices=["retriever", "generator"],
help="Set type to retriever or generator",
)
parser.add_argument(
"--llm",
type=str,
@@ -179,22 +251,30 @@ def display_results(self, name, eval_results):
args = parser.parse_args()
if args.host != "host.docker.internal":
port_number = 11434
if not is_port_open(port_number):
if not is_port_open(port_number) and args.llm not in ["gpt-3.5-turbo", "gpt-4", "gpt-4o", "gpt-4-turbo"]:
run_ollama_server()
evaluator = RAGPipelineEvaluator(
llm=args.llm,
teacher=args.teacher,
host=args.host,
dataset_path=args.dataset,
docstore_path=args.docstore,
)

async def main():
evaluator = RAGPipelineEvaluator(
llm=args.llm,
teacher=args.teacher,
host=args.host,
dataset_path=args.dataset,
docstore_path=args.docstore,
)

async def eval_retriever():
retriever_result = await evaluator.eval_retriever()
print(retriever_result)
# save results
with open("retriever_result.json", "w") as f:
json.dump(retriever_result, f)

asyncio.run(main())
async def eval_generator():
generator_result = await evaluator.eval_generator()
# save results
with open(f"generator_result_{args.llm}.json", "w") as f:
json.dump(generator_result, f)

if args.type == "retriever":
asyncio.run(eval_retriever())
else:
asyncio.run(eval_generator())
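For reference, eval_retriever() writes a single retriever_result.json with one {"retrievers": ..., "hit_rate": ..., "mrr": ...} entry per retriever, while eval_generator() writes generator_result_<llm>.json with a list of serialized evaluation results per metric. A purely illustrative way to summarize the generator file afterwards (the "gpt-4o" file name is just one of the model names the script accepts, and the "score"/"passing" keys assume llama_index's EvaluationResult schema rather than anything defined in this commit):

import json

import pandas as pd


# Hypothetical follow-up script, not part of the commit: load the output of
# eval_generator() and aggregate the per-metric evaluation results.
with open("generator_result_gpt-4o.json") as f:
    generator_result = json.load(f)

for metric_name, payload in generator_result.items():
    # Each payload is {"generator": <metric>, "result": [EvaluationResult dicts]}.
    # "score" and "passing" come from llama_index's EvaluationResult; adjust the
    # keys if the installed version serializes its results differently.
    df = pd.DataFrame(payload["result"])
    print(metric_name, "mean score:", df["score"].mean(), "pass rate:", df["passing"].mean())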
1 change: 0 additions & 1 deletion rag_chatbot/ollama.py
@@ -1,6 +1,5 @@
import asyncio
import threading
import os
import socket


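The eval script's entry point above calls is_port_open(port_number) and run_ollama_server(), both of which live in this rag_chatbot/ollama.py module; their bodies are not shown in this diff. Purely as an illustration of what a socket-based port probe of that kind typically looks like (an assumption, not the project's actual implementation):

import socket


def is_port_open_sketch(port: int, host: str = "localhost") -> bool:
    # Illustrative only: returns True if something is already listening on the
    # given host/port, e.g. a running Ollama server on 11434.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(1)
        return sock.connect_ex((host, port)) == 0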
File renamed without changes.
File renamed without changes.
