Add chatmodels for langchain (#193)
Signed-off-by: shiyu22 <[email protected]>
shiyu22 committed Apr 13, 2023
1 parent 274fb85 commit 23c7e3b
Showing 10 changed files with 262 additions and 120 deletions.
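At a glance: the gptcache.adapter.langchain_llms module is renamed to gptcache.adapter.langchain_models, a LangChainChat wrapper is added alongside LangChainLLMs so LangChain chat models can be cached too, and cache hits now refresh an access timestamp through the data manager. For callers upgrading, the import change looks roughly like this (only the module path and class names are taken from the diff below):

# before: module deleted by this commit
from gptcache.adapter.langchain_llms import LangChainLLMs

# after: same wrapper, new module, plus the new chat wrapper
from gptcache.adapter.langchain_models import LangChainLLMs, LangChainChat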
27 changes: 24 additions & 3 deletions examples/adapter/langchain_llms.py
@@ -2,16 +2,20 @@
 
 from langchain import Cohere
 from langchain.llms import OpenAI
+from langchain.chat_models import ChatOpenAI
+from langchain.schema import HumanMessage
 
-from gptcache.adapter.langchain_llms import LangChainLLMs
+from gptcache.adapter.langchain_models import LangChainLLMs
 from gptcache import cache
 from gptcache.processor.pre import get_prompt
 
+from gptcache.adapter.langchain_models import LangChainChat
+
 OpenAI.api_key = os.getenv("OPENAI_API_KEY")
 Cohere.cohere_api_key = os.getenv("COHERE_API_KEY")
 
 
-def run():
+def run_llm():
     cache.init(
         pre_embedding_func=get_prompt,
     )
@@ -30,5 +34,22 @@ def run():
     print(answer)
 
 
+def get_msg(data, **_):
+    return data.get("messages")[-1].content
+
+
+def run_chat_model():
+    cache.init(
+        pre_embedding_func=get_msg,
+    )
+
+    # chat=ChatOpenAI(temperature=0)
+    chat = LangChainChat(chat=ChatOpenAI(temperature=0))
+
+    answer = chat([HumanMessage(content="Translate this sentence from English to Chinese. I love programming.")])
+    print(answer)
+
+
 if __name__ == '__main__':
-    run()
+    run_llm()
+    run_chat_model()
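Worth noting: get_msg keys the cache on the final message only, so two conversations that happen to end with the same user turn will share a cache entry. A hedged alternative for multi-turn use, concatenating every message into the pre-embedding string (a hypothetical helper, not part of this commit):

def get_all_msg(data, **_):
    # hypothetical: join all message contents so that different
    # conversation histories produce different cache keys
    return "\n".join(m.content for m in data.get("messages"))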
2 changes: 1 addition & 1 deletion examples/langchain_examples/langchain_llms_mock.py
@@ -3,7 +3,7 @@
 from langchain import Cohere
 from langchain.llms import OpenAI
 
-from gptcache.adapter.langchain_llms import LangChainLLMs
+from gptcache.adapter.langchain_models import LangChainLLMs
 from gptcache import cache, Cache
 from gptcache.processor.pre import get_prompt
 
2 changes: 1 addition & 1 deletion examples/langchain_examples/langchain_prompt_openai.py
@@ -5,7 +5,7 @@
 from langchain.llms import OpenAI
 from langchain import PromptTemplate, LLMChain
 
-from gptcache.adapter.langchain_llms import LangChainLLMs
+from gptcache.adapter.langchain_models import LangChainLLMs
 from gptcache import Cache
 from gptcache.processor.pre import get_prompt
 
@@ -5,7 +5,7 @@
 from langchain.llms import OpenAI
 from langchain import PromptTemplate
 
-from gptcache.adapter.langchain_llms import LangChainLLMs
+from gptcache.adapter.langchain_models import LangChainLLMs
 from gptcache.manager import get_data_manager, CacheBase, VectorBase
 from gptcache import Cache
 from gptcache.embedding import Onnx
1 change: 1 addition & 0 deletions gptcache/adapter/adapter.py
@@ -65,6 +65,7 @@ def adapt(llm_handler, cache_data_convert, update_cache_callback, *args, **kwargs):
             )
             if rank_threshold <= rank:
                 cache_answers.append((rank, cache_answer))
+                chat_cache.data_manager.update_access_time(cache_data)
         cache_answers = sorted(cache_answers, key=lambda x: x[0], reverse=True)
         if len(cache_answers) != 0:
             return_message = chat_cache.post_process_messages_func(
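The one-line adapter change touches the scalar store on every sufficiently similar hit, which only pays off if something later reads that timestamp. A sketch of the staleness-based eviction this enables (the storage accessors below are assumptions, not part of this commit):

import time

SEVEN_DAYS = 7 * 24 * 3600


def evict_stale(scalar_storage, max_idle=SEVEN_DAYS):
    # hypothetical sweep: drop rows whose last_access, refreshed by
    # update_access_time() on each cache hit, is older than max_idle seconds
    cutoff = time.time() - max_idle
    for row in scalar_storage.all_rows():  # hypothetical accessor
        if row.last_access < cutoff:
            scalar_storage.delete(row.id)  # hypothetical accessor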
62 changes: 0 additions & 62 deletions gptcache/adapter/langchain_llms.py

This file was deleted.

124 changes: 124 additions & 0 deletions gptcache/adapter/langchain_models.py
@@ -0,0 +1,124 @@
from typing import Optional, List, Any

from gptcache.adapter.adapter import adapt
from gptcache.utils import import_pydantic, import_langchain

import_pydantic()
import_langchain()

# pylint: disable=C0413
from pydantic import BaseModel
from langchain.llms.base import LLM
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, LLMResult, AIMessage, ChatGeneration, ChatResult


class LangChainLLMs(LLM, BaseModel):
    """LangChain LLM wrapper.

    :param llm: LLM from langchain.llms.
    :type llm: Any

    Example:
        .. code-block:: python

            from gptcache import cache
            from gptcache.processor.pre import get_prompt
            # init gptcache
            cache.init(pre_embedding_func=get_prompt)
            cache.set_openai_key()

            from langchain.llms import OpenAI
            from gptcache.adapter.langchain_models import LangChainLLMs
            # run llm with gptcache
            llm = LangChainLLMs(llm=OpenAI(temperature=0))
            llm("Hello world")
    """

    llm: Any

    @property
    def _llm_type(self) -> str:
        return "gptcache_llm"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        return adapt(
            self.llm,
            cache_data_convert,
            update_cache_callback,
            prompt=prompt,
            stop=stop,
            **kwargs
        )

    def __call__(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
        return self._call(prompt=prompt, stop=stop, **kwargs)


# pylint: disable=protected-access
class LangChainChat(BaseChatModel, BaseModel):
    """LangChain chat model wrapper.

    :param chat: chat model from langchain.chat_models.
    :type chat: Any

    Example:
        .. code-block:: python

            from gptcache import cache
            # init gptcache, keying chat requests on the last message
            # (see get_msg in examples/adapter/langchain_llms.py)
            def get_msg(data, **_):
                return data.get("messages")[-1].content
            cache.init(pre_embedding_func=get_msg)
            cache.set_openai_key()

            from langchain.chat_models import ChatOpenAI
            from langchain.schema import HumanMessage
            from gptcache.adapter.langchain_models import LangChainChat
            # run chat model with gptcache
            chat = LangChainChat(chat=ChatOpenAI(temperature=0))
            chat([HumanMessage(content="Translate this sentence from English to French. I love programming.")])
    """

    chat: Any

    def _generate(self, messages: Any, stop: Optional[List[str]] = None, **kwargs):
        return adapt(
            self.chat._generate,
            cache_msg_data_convert,
            update_cache_msg_callback,
            messages=messages,
            stop=stop,
            **kwargs
        )

    async def _agenerate(self, messages: List[BaseMessage], stop: Optional[List[str]] = None, **kwargs) -> LLMResult:
        return adapt(
            self.chat._agenerate,
            cache_msg_data_convert,
            update_cache_msg_callback,
            messages=messages,
            stop=stop,
            **kwargs
        )

    def __call__(self, messages: Any, stop: Optional[List[str]] = None, **kwargs):
        res = self._generate(messages=messages, stop=stop, **kwargs)
        return res.generations[0].message


def cache_data_convert(cache_data):
    return cache_data


def update_cache_callback(llm_data, update_cache_func):
    update_cache_func(llm_data)
    return llm_data


def cache_msg_data_convert(cache_data):
    llm_res = ChatResult(generations=[ChatGeneration(text="",
                                                     generation_info=None,
                                                     message=AIMessage(content=cache_data, additional_kwargs={}))],
                         llm_output=None)
    return llm_res


def update_cache_msg_callback(llm_data, update_cache_func):
    update_cache_func(llm_data.generations[0].text)
    return llm_data
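The two *_msg_* helpers define the chat cache contract: on a miss only generations[0].text of the model's ChatResult is persisted, and on a hit the cached string is rehydrated into a fresh ChatResult. A small self-contained check of that round trip, reusing the converters above:

captured = []


def fake_update(value):
    captured.append(value)


# miss path: the model returned a ChatResult; only its text is cached
model_out = ChatResult(
    generations=[ChatGeneration(text="Bonjour le monde",
                                message=AIMessage(content="Bonjour le monde"))],
    llm_output=None,
)
update_cache_msg_callback(model_out, fake_update)
assert captured == ["Bonjour le monde"]

# hit path: the cached string comes back as a full ChatResult
hit = cache_msg_data_convert(captured[0])
assert hit.generations[0].message.content == "Bonjour le monde"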
6 changes: 6 additions & 0 deletions gptcache/manager/data_manager.py
@@ -29,6 +29,9 @@ def import_data(
     def get_scalar_data(self, res_data, **kwargs):
         pass
 
+    def update_access_time(self, res_data, **kwargs):
+        pass
+
     @abstractmethod
     def search(self, embedding_data, **kwargs):
         pass
@@ -181,6 +184,9 @@ def import_data(
     def get_scalar_data(self, res_data, **kwargs):
         return self.s.get_data_by_id(res_data[1])
 
+    def update_access_time(self, res_data, **kwargs):
+        return self.s.update_access_time(res_data[1])
+
     def search(self, embedding_data, **kwargs):
         embedding_data = normalize(embedding_data)
         return self.v.search(embedding_data)
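update_access_time here just delegates to the scalar storage (self.s), whose side of the change is not visible in this view. A minimal in-memory stand-in for the interface the data manager relies on (a sketch; gptcache's real scalar storage is SQL-backed and not shown in this diff):

import time


class InMemoryScalarStorage:
    # hypothetical stand-in for gptcache's scalar storage
    def __init__(self):
        self._rows = {}  # id -> {"data": ..., "last_access": float}

    def get_data_by_id(self, key):
        return self._rows[key]["data"]

    def update_access_time(self, key):
        # refreshed on every cache hit so an eviction pass can spot idle entries
        self._rows[key]["last_access"] = time.time()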
52 changes: 0 additions & 52 deletions tests/unit_tests/adapter/test_langchain_llms.py

This file was deleted.
