Skip to content

Commit

Permalink
feat(model): Support tongyi embedding (#1552)
Browse files Browse the repository at this point in the history
Co-authored-by: 无剑 <[email protected]>
Co-authored-by: csunny <[email protected]>
Co-authored-by: aries_ckt <[email protected]>
  • Loading branch information
4 people committed Jun 25, 2024
1 parent 47d205f commit fda1a56
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 2 deletions.
5 changes: 5 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ KNOWLEDGE_SEARCH_REWRITE=False
# proxy_openai_proxy_api_key={your-openai-sk}
# proxy_openai_proxy_backend=text-embedding-ada-002


## qwen embedding model, See dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_tongyi
# proxy_tongyi_proxy_backend=text-embedding-v1

## Common HTTP embedding model
# EMBEDDING_MODEL=proxy_http_openapi
# proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings
Expand Down
1 change: 1 addition & 0 deletions dbgpt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def get_device() -> str:
# Common HTTP embedding model
"proxy_http_openapi": "proxy_http_openapi",
"proxy_ollama": "proxy_ollama",
"proxy_tongyi": "proxy_tongyi",
# Rerank model, rerank mode is a special embedding model
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
Expand Down
8 changes: 8 additions & 0 deletions dbgpt/model/adapter/embeddings_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddin
if proxy_param.proxy_backend:
openapi_param["model_name"] = proxy_param.proxy_backend
return OpenAPIEmbeddings(**openapi_param)
elif model_name in ["proxy_tongyi"]:
from dbgpt.rag.embedding import TongYiEmbeddings

proxy_param = cast(ProxyEmbeddingParameters, param)
tongyi_param = {"api_key": proxy_param.proxy_api_key}
if proxy_param.proxy_backend:
tongyi_param["model_name"] = proxy_param.proxy_backend
return TongYiEmbeddings(**tongyi_param)
elif model_name in ["proxy_ollama"]:
from dbgpt.rag.embedding import OllamaEmbeddings

Expand Down
3 changes: 1 addition & 2 deletions dbgpt/model/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,8 +665,7 @@ def is_rerank_model(self) -> bool:


_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
"proxy_ollama,rerank_proxy_http_openapi",
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,rerank_proxy_http_openapi",
}

EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
Expand Down
2 changes: 2 additions & 0 deletions dbgpt/rag/embedding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
JinaEmbeddings,
OllamaEmbeddings,
OpenAPIEmbeddings,
TongYiEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401

Expand All @@ -29,6 +30,7 @@
"DefaultEmbeddingFactory",
"EmbeddingFactory",
"WrappedEmbeddingFactory",
"TongYiEmbeddings",
"CrossEncoderRerankEmbeddings",
"OpenAPIRerankEmbeddings",
]
80 changes: 80 additions & 0 deletions dbgpt/rag/embedding/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,3 +826,83 @@ async def aembed_query(self, text: str) -> List[float]:
return embedding["embedding"]
except ollama.ResponseError as e:
raise ValueError(f"**Ollama Response Error, Please CheckErrorInfo.**: {e}")


class TongYiEmbeddings(BaseModel, Embeddings):
"""The tongyi embeddings.
import dashscope
from http import HTTPStatus
from dashscope import TextEmbedding
dashscope.api_key = ''
def embed_with_list_of_str():
resp = TextEmbedding.call(
model=TextEmbedding.Models.text_embedding_v1,
# 最多支持10条,每条最长支持2048tokens
input=['风急天高猿啸哀', '渚清沙白鸟飞回', '无边落木萧萧下', '不尽长江滚滚来']
)
if resp.status_code == HTTPStatus.OK:
print(resp)
else:
print(resp)
if __name__ == '__main__':
embed_with_list_of_str()
"""

model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_key: Optional[str] = Field(
default=None, description="The API key for the embeddings API."
)
model_name: str = Field(
default="text-embedding-v1", description="The name of the model to use."
)

def __init__(self, **kwargs):
"""Initialize the OpenAPIEmbeddings."""
try:
import dashscope # type: ignore
except ImportError as exc:
raise ValueError(
"Could not import python package: dashscope "
"Please install dashscope by command `pip install dashscope"
) from exc
dashscope.TextEmbedding.api_key = kwargs.get("api_key")
super().__init__(**kwargs)
self._api_key = kwargs.get("api_key")

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Get the embeddings for a list of texts.
Args:
texts (Documents): A list of texts to get embeddings for.
Returns:
Embedded texts as List[List[float]], where each inner List[float]
corresponds to a single input text.
"""
from dashscope import TextEmbedding

# 最多支持10条,每条最长支持2048tokens
resp = TextEmbedding.call(
model=self.model_name, input=texts, api_key=self._api_key
)
if "output" not in resp:
raise RuntimeError(resp["message"])

embeddings = resp["output"]["embeddings"]
sorted_embeddings = sorted(embeddings, key=lambda e: e["text_index"])

return [result["embedding"] for result in sorted_embeddings]

def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a OpenAPI embedding model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]

0 comments on commit fda1a56

Please sign in to comment.