From 4e6fbcedd5853d62720a036273b2c2c4b96b7738 Mon Sep 17 00:00:00 2001 From: Dongri Jin Date: Fri, 26 Jan 2024 11:13:24 +0900 Subject: [PATCH] Add embedding models --- examples/embedding.rs | 6 ++---- src/v1/chat_completion.rs | 4 ++-- src/v1/common.rs | 9 +++++++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/embedding.rs b/examples/embedding.rs index 70403c9..da7f201 100644 --- a/examples/embedding.rs +++ b/examples/embedding.rs @@ -1,14 +1,12 @@ use openai_api_rs::v1::api::Client; +use openai_api_rs::v1::common::TEXT_EMBEDDING_3_SMALL; use openai_api_rs::v1::embedding::EmbeddingRequest; use std::env; fn main() -> Result<(), Box> { let client = Client::new(env::var("OPENAI_API_KEY").unwrap().to_string()); - let req = EmbeddingRequest::new( - "text-embedding-ada-002".to_string(), - "story time".to_string(), - ); + let req = EmbeddingRequest::new(TEXT_EMBEDDING_3_SMALL.to_string(), "story time".to_string()); let result = client.embedding(req)?; println!("{:?}", result.data); diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index fc4465c..e9bf223 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -158,7 +158,7 @@ pub struct ChatCompletionMessageForResponse { pub tool_calls: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct ChatCompletionChoice { pub index: i64, pub message: ChatCompletionMessageForResponse, @@ -166,7 +166,7 @@ pub struct ChatCompletionChoice { pub finish_details: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct ChatCompletionResponse { pub id: String, pub object: String, diff --git a/src/v1/common.rs b/src/v1/common.rs index fe4561a..7cceced 100644 --- a/src/v1/common.rs +++ b/src/v1/common.rs @@ -1,6 +1,6 @@ -use serde::Deserialize; +use serde::{Deserialize, Serialize}; -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct Usage { pub prompt_tokens: i32, pub completion_tokens: i32, @@ -45,3 +45,8 @@ pub const GPT4_32K_0314: &str = "gpt-4-32k-0314"; // https://platform.openai.com/docs/api-reference/images/object pub const DALL_E_2: &str = "dall-e-2"; pub const DALL_E_3: &str = "dall-e-3"; + +// https://platform.openai.com/docs/guides/embeddings/embedding-models +pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small"; +pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large"; +pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002";