diff --git a/demos/function_calling/bolt_config.yaml b/demos/function_calling/bolt_config.yaml index 9b99364b..941eb457 100644 --- a/demos/function_calling/bolt_config.yaml +++ b/demos/function_calling/bolt_config.yaml @@ -17,7 +17,7 @@ prompt_targets: - type: function_resolver name: weather_forecast - description: This function resolver provides weather forecast information for a given city. + description: This function provides realtime weather forecast information for a given city. parameters: - name: city required: true diff --git a/demos/function_calling/docker-compose.yaml b/demos/function_calling/docker-compose.yaml index 1476b967..20e3002a 100644 --- a/demos/function_calling/docker-compose.yaml +++ b/demos/function_calling/docker-compose.yaml @@ -60,8 +60,7 @@ services: - OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal} # uncomment following line to use ollama endpoint that is hosted by docker # - OLLAMA_ENDPOINT=ollama - - OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M - # - OLLAMA_MODEL=Bolt-Function-Calling-1B:Q4_K_M + # - OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M api_server: build: diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs index c7b6ae22..5607dc61 100644 --- a/envoyfilter/src/stream_context.rs +++ b/envoyfilter/src/stream_context.rs @@ -17,7 +17,7 @@ use proxy_wasm::types::*; use public_types::common_types::open_ai::{ ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message, - StreamOptions, ToolType, + ParameterType, StreamOptions, ToolType, }; use public_types::common_types::{ EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask, @@ -369,7 +369,9 @@ impl StreamContext { let mut properties: HashMap = HashMap::new(); for entity in entities.iter() { let param = FunctionParameter { - parameter_type: entity.parameter_type.clone(), + parameter_type: ParameterType::from( + 
entity.parameter_type.clone().unwrap_or("str".to_string()), + ), description: entity.description.clone(), required: entity.required, enum_values: entity.enum_values.clone(), diff --git a/function_resolver/app/arch_handler.py b/function_resolver/app/arch_handler.py index a92a62bc..77b0a65d 100644 --- a/function_resolver/app/arch_handler.py +++ b/function_resolver/app/arch_handler.py @@ -33,7 +33,7 @@ class ArchHandler: def _format_system(self, tools: List[Dict[str, Any]]): def convert_tools(tools): - return "\n".join([json.dumps(tool["function"]) for tool in tools]) + return "\n".join([json.dumps(tool) for tool in tools]) tool_text = convert_tools(tools) diff --git a/public_types/Cargo.lock b/public_types/Cargo.lock index 8b63d2ac..5a176a11 100644 --- a/public_types/Cargo.lock +++ b/public_types/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "diff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" + [[package]] name = "equivalent" version = "1.0.1" @@ -30,6 +36,22 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "pretty_assertions" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d" +dependencies = [ + "diff", + "yansi", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -43,7 +65,9 @@ dependencies = [ name = "public_types" version = "0.1.0" dependencies = [ + "pretty_assertions", "serde", + "serde_json", "serde_yaml", ] @@ -82,6 +106,18 @@ dependencies 
= [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.128" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + [[package]] name = "serde_yaml" version = "0.9.34+deprecated" @@ -117,3 +153,9 @@ name = "unsafe-libyaml" version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" + +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" diff --git a/public_types/Cargo.toml b/public_types/Cargo.toml index ec81e4a8..3c94335a 100644 --- a/public_types/Cargo.toml +++ b/public_types/Cargo.toml @@ -6,3 +6,7 @@ edition = "2021" [dependencies] serde = { version = "1.0", features = ["derive"] } serde_yaml = "0.9.34" + +[dev-dependencies] +pretty_assertions = "1.4.1" +serde_json = "1.0.64" diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs index 385c7bef..df89a66f 100644 --- a/public_types/src/common_types.rs +++ b/public_types/src/common_types.rs @@ -36,7 +36,7 @@ pub struct SearchPointResult { pub mod open_ai { use std::collections::HashMap; - use serde::{Deserialize, Serialize}; + use serde::{ser::SerializeMap, Deserialize, Serialize}; use serde_yaml::Value; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -57,6 +57,7 @@ pub mod open_ai { #[serde(rename = "function")] Function, } + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionTool { #[serde(rename = "type")] @@ -71,16 +72,37 @@ pub mod open_ai { pub parameters: FunctionParameters, } - #[derive(Debug, Clone, Serialize, Deserialize)] + #[derive(Debug, Clone, Deserialize)] pub struct FunctionParameters { pub properties: HashMap, } - #[derive(Debug, Clone, Serialize, 
Deserialize)] + impl Serialize for FunctionParameters { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + // select all required parameters + let required: Vec<&String> = self + .properties + .iter() + .filter(|(_, v)| v.required.unwrap_or(false)) + .map(|(k, _)| k) + .collect(); + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("properties", &self.properties)?; + if !required.is_empty() { + map.serialize_entry("required", &required)?; + } + map.end() + } + } + + #[derive(Debug, Clone, Deserialize)] pub struct FunctionParameter { #[serde(rename = "type")] - #[serde(skip_serializing_if = "Option::is_none")] - pub parameter_type: Option, + #[serde(default = "ParameterType::string")] + pub parameter_type: ParameterType, pub description: String, #[serde(skip_serializing_if = "Option::is_none")] pub required: Option, @@ -91,6 +113,60 @@ pub mod open_ai { pub default: Option, } + impl Serialize for FunctionParameter { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut map = serializer.serialize_map(Some(5))?; + map.serialize_entry("type", &self.parameter_type)?; + map.serialize_entry("description", &self.description)?; + if let Some(enum_values) = &self.enum_values { + map.serialize_entry("enum", enum_values)?; + } + if let Some(default) = &self.default { + map.serialize_entry("default", default)?; + } + map.end() + } + } + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + pub enum ParameterType { + #[serde(rename = "int")] + Int, + #[serde(rename = "float")] + Float, + #[serde(rename = "bool")] + Bool, + #[serde(rename = "str")] + String, + #[serde(rename = "list")] + List, + #[serde(rename = "dict")] + Dict, + } + + impl From for ParameterType { + fn from(s: String) -> Self { + match s.as_str() { + "int" => ParameterType::Int, + "float" => ParameterType::Float, + "bool" => ParameterType::Bool, + "string" => ParameterType::String, + "list" => 
ParameterType::List, + "dict" => ParameterType::Dict, + _ => ParameterType::String, + } + } + } + + impl ParameterType { + pub fn string() -> ParameterType { + ParameterType::String + } + } + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct StreamOptions { pub include_usage: bool, @@ -99,9 +175,11 @@ pub mod open_ai { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { pub role: String, + #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, } @@ -195,3 +273,140 @@ pub struct PromptGuardResponse { pub toxic_verdict: Option, pub jailbreak_verdict: Option, } + +#[cfg(test)] +mod test { + use crate::common_types::open_ai::Message; + use pretty_assertions::{assert_eq, assert_ne}; + use std::collections::HashMap; + + const TOOL_SERIALIZED: &str = r#"{ + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "What city do you want to know the weather for?" 
+ } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "weather_forecast", + "description": "function to retrieve weather forecast", + "parameters": { + "properties": { + "city": { + "type": "str", + "description": "city for weather forecast", + "default": "test" + } + }, + "required": [ + "city" + ] + } + } + } + ], + "stream": true, + "stream_options": { + "include_usage": true + } +}"#; + + #[test] + fn test_tool_type_request() { + use super::open_ai::{ + ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType, + }; + + let mut properties = HashMap::new(); + properties.insert( + "city".to_string(), + FunctionParameter { + parameter_type: ParameterType::String, + description: "city for weather forecast".to_string(), + required: Some(true), + enum_values: None, + default: Some("test".to_string()), + }, + ); + + let function_definition = FunctionDefinition { + name: "weather_forecast".to_string(), + description: "function to retrieve weather forecast".to_string(), + parameters: super::open_ai::FunctionParameters { properties }, + }; + + let chat_completions_request = ChatCompletionsRequest { + model: "gpt-3.5-turbo".to_string(), + messages: vec![Message { + role: "user".to_string(), + content: Some("What city do you want to know the weather for?".to_string()), + model: None, + tool_calls: None, + }], + tools: Some(vec![super::open_ai::ChatCompletionTool { + tool_type: ToolType::Function, + function: function_definition, + }]), + stream: true, + stream_options: Some(super::open_ai::StreamOptions { + include_usage: true, + }), + }; + + let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap(); + println!("{}", serialized); + assert_eq!(TOOL_SERIALIZED, serialized); + } + + #[test] + fn test_parameter_types() { + use super::open_ai::{ + ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType, + }; + + const PARAMETER_SERIALZIED: &str = r#"{ + "city": { + "type": 
"str", + "description": "city for weather forecast", + "default": "test" + } +}"#; + + let properties = HashMap::from([( + "city".to_string(), + FunctionParameter { + parameter_type: ParameterType::String, + description: "city for weather forecast".to_string(), + required: Some(true), + enum_values: None, + default: Some("test".to_string()), + }, + )]); + + let serialized = serde_json::to_string_pretty(&properties).unwrap(); + assert_eq!(PARAMETER_SERIALZIED, serialized); + + // ensure that if type is missing it is set to string + const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#" + { + "city": { + "description": "city for weather forecast" + } + }"#; + + let missing_type_deserialized: HashMap = + serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap(); + println!("{:?}", missing_type_deserialized); + assert_eq!( + missing_type_deserialized + .get("city") + .unwrap() + .parameter_type, + ParameterType::String + ); + } +}