diff --git a/examples/function_call.rs b/examples/function_call.rs index 1961778..7bff433 100644 --- a/examples/function_call.rs +++ b/examples/function_call.rs @@ -46,7 +46,7 @@ fn main() -> Result<(), Box> { required: Some(vec![String::from("coin")]), }, }]), - function_call: Some(FunctionCallType::auto), //Some(FunctionCallType::Function { name: "test".to_string() }) + function_call: Some(FunctionCallType::Auto), // Some(FunctionCallType::Function { name: "test".to_string() }), temperature: None, top_p: None, n: None, @@ -59,6 +59,10 @@ fn main() -> Result<(), Box> { user: None, }; + // debug reuqest json + // let serialized = serde_json::to_string(&req).unwrap(); + // println!("{}", serialized); + let result = client.chat_completion(req)?; match result.choices[0].finish_reason { diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index 8d590a7..03d7d98 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -1,4 +1,5 @@ -use serde::{Deserialize, Serialize}; +use serde::ser::SerializeMap; +use serde::{Deserialize, Serialize, Serializer}; use std::collections::HashMap; use crate::v1::common; @@ -14,11 +15,10 @@ pub const GPT4_32K_0314: &str = "gpt-4-32k-0314"; pub const GPT4_0613: &str = "gpt-4-0613"; #[derive(Debug, Serialize)] -#[allow(non_camel_case_types)] pub enum FunctionCallType { - none, - auto, - function { name: String }, + None, + Auto, + Function { name: String }, } #[derive(Debug, Serialize)] @@ -28,6 +28,7 @@ pub struct ChatCompletionRequest { #[serde(skip_serializing_if = "Option::is_none")] pub functions: Option>, #[serde(skip_serializing_if = "Option::is_none")] + #[serde(serialize_with = "serialize_function_call")] pub function_call: Option, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, @@ -160,3 +161,22 @@ pub struct FunctionCall { #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option, } + +fn serialize_function_call( + value: &Option, + serializer: S, +) -> Result +where + S: Serializer, +{ + match value { + Some(FunctionCallType::None) => serializer.serialize_str("none"), + Some(FunctionCallType::Auto) => serializer.serialize_str("auto"), + Some(FunctionCallType::Function { name }) => { + let mut map = serializer.serialize_map(Some(1))?; + map.serialize_entry("name", name)?; + map.end() + } + None => serializer.serialize_none(), + } +}