Add supported parameter type, validation and tests (#88)

* Add supported parameter type and validation

* make the tools format more compliant with openai

* more updates

* fix more

* fix unit test
This commit is contained in:
Adil Hafeez 2024-09-27 13:33:05 -07:00 committed by GitHub
parent 59229b8fc9
commit 75cf5e5304
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 273 additions and 11 deletions

View file

@ -17,7 +17,7 @@ prompt_targets:
- type: function_resolver
name: weather_forecast
description: This function resolver provides weather forecast information for a given city.
description: This function provides realtime weather forecast information for a given city.
parameters:
- name: city
required: true

View file

@ -60,8 +60,7 @@ services:
- OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
# uncomment following line to use ollama endpoint that is hosted by docker
# - OLLAMA_ENDPOINT=ollama
- OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M
# - OLLAMA_MODEL=Bolt-Function-Calling-1B:Q4_K_M
# - OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M
api_server:
build:

View file

@ -17,7 +17,7 @@ use proxy_wasm::types::*;
use public_types::common_types::open_ai::{
ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
StreamOptions, ToolType,
ParameterType, StreamOptions, ToolType,
};
use public_types::common_types::{
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
@ -369,7 +369,9 @@ impl StreamContext {
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = FunctionParameter {
parameter_type: entity.parameter_type.clone(),
parameter_type: ParameterType::from(
entity.parameter_type.clone().unwrap_or("str".to_string()),
),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),

View file

@ -33,7 +33,7 @@ class ArchHandler:
def _format_system(self, tools: List[Dict[str, Any]]):
def convert_tools(tools):
return "\n".join([json.dumps(tool["function"]) for tool in tools])
return "\n".join([json.dumps(tool) for tool in tools])
tool_text = convert_tools(tools)

View file

@ -2,6 +2,12 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "diff"
version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8"
[[package]]
name = "equivalent"
version = "1.0.1"
@ -30,6 +36,22 @@ version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]]
name = "memchr"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "pretty_assertions"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ae130e2f271fbc2ac3a40fb1d07180839cdbbe443c7a27e1e3c13c5cac0116d"
dependencies = [
"diff",
"yansi",
]
[[package]]
name = "proc-macro2"
version = "1.0.86"
@ -43,7 +65,9 @@ dependencies = [
name = "public_types"
version = "0.1.0"
dependencies = [
"pretty_assertions",
"serde",
"serde_json",
"serde_yaml",
]
@ -82,6 +106,18 @@ dependencies = [
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.128"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
]
[[package]]
name = "serde_yaml"
version = "0.9.34+deprecated"
@ -117,3 +153,9 @@ name = "unsafe-libyaml"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861"
[[package]]
name = "yansi"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"

View file

@ -6,3 +6,7 @@ edition = "2021"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_yaml = "0.9.34"
[dev-dependencies]
pretty_assertions = "1.4.1"
serde_json = "1.0.64"

View file

@ -36,7 +36,7 @@ pub struct SearchPointResult {
pub mod open_ai {
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -57,6 +57,7 @@ pub mod open_ai {
#[serde(rename = "function")]
Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionTool {
#[serde(rename = "type")]
@ -71,16 +72,37 @@ pub mod open_ai {
pub parameters: FunctionParameters,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameters {
pub properties: HashMap<String, FunctionParameter>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
impl Serialize for FunctionParameters {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// select all requried parameters
let required: Vec<&String> = self
.properties
.iter()
.filter(|(_, v)| v.required.unwrap_or(false))
.map(|(k, _)| k)
.collect();
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("properties", &self.properties)?;
if !required.is_empty() {
map.serialize_entry("required", &required)?;
}
map.end()
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameter {
#[serde(rename = "type")]
#[serde(skip_serializing_if = "Option::is_none")]
pub parameter_type: Option<String>,
#[serde(default = "ParameterType::string")]
pub parameter_type: ParameterType,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
@ -91,6 +113,60 @@ pub mod open_ai {
pub default: Option<String>,
}
impl Serialize for FunctionParameter {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(5))?;
map.serialize_entry("type", &self.parameter_type)?;
map.serialize_entry("description", &self.description)?;
if let Some(enum_values) = &self.enum_values {
map.serialize_entry("enum", enum_values)?;
}
if let Some(default) = &self.default {
map.serialize_entry("default", default)?;
}
map.end()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ParameterType {
#[serde(rename = "int")]
Int,
#[serde(rename = "float")]
Float,
#[serde(rename = "bool")]
Bool,
#[serde(rename = "str")]
String,
#[serde(rename = "list")]
List,
#[serde(rename = "dict")]
Dict,
}
impl From<String> for ParameterType {
fn from(s: String) -> Self {
match s.as_str() {
"int" => ParameterType::Int,
"float" => ParameterType::Float,
"bool" => ParameterType::Bool,
"string" => ParameterType::String,
"list" => ParameterType::List,
"dict" => ParameterType::Dict,
_ => ParameterType::String,
}
}
}
impl ParameterType {
pub fn string() -> ParameterType {
ParameterType::String
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
pub include_usage: bool,
@ -99,9 +175,11 @@ pub mod open_ai {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
@ -195,3 +273,140 @@ pub struct PromptGuardResponse {
pub toxic_verdict: Option<bool>,
pub jailbreak_verdict: Option<bool>,
}
#[cfg(test)]
mod test {
use crate::common_types::open_ai::Message;
use pretty_assertions::{assert_eq, assert_ne};
use std::collections::HashMap;
const TOOL_SERIALIZED: &str = r#"{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "What city do you want to know the weather for?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "weather_forecast",
"description": "function to retrieve weather forecast",
"parameters": {
"properties": {
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
},
"required": [
"city"
]
}
}
}
],
"stream": true,
"stream_options": {
"include_usage": true
}
}"#;
#[test]
fn test_tool_type_request() {
use super::open_ai::{
ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
};
let mut properties = HashMap::new();
properties.insert(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
},
);
let function_definition = FunctionDefinition {
name: "weather_forecast".to_string(),
description: "function to retrieve weather forecast".to_string(),
parameters: super::open_ai::FunctionParameters { properties },
};
let chat_completions_request = ChatCompletionsRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: Some("What city do you want to know the weather for?".to_string()),
model: None,
tool_calls: None,
}],
tools: Some(vec![super::open_ai::ChatCompletionTool {
tool_type: ToolType::Function,
function: function_definition,
}]),
stream: true,
stream_options: Some(super::open_ai::StreamOptions {
include_usage: true,
}),
};
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
println!("{}", serialized);
assert_eq!(TOOL_SERIALIZED, serialized);
}
#[test]
fn test_parameter_types() {
use super::open_ai::{
ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
};
const PARAMETER_SERIALZIED: &str = r#"{
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
}"#;
let properties = HashMap::from([(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
},
)]);
let serialized = serde_json::to_string_pretty(&properties).unwrap();
assert_eq!(PARAMETER_SERIALZIED, serialized);
// ensure that if type is missing it is set to string
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
{
"city": {
"description": "city for weather forecast"
}
}"#;
let missing_type_deserialized: HashMap<String, FunctionParameter> =
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
println!("{:?}", missing_type_deserialized);
assert_eq!(
missing_type_deserialized
.get("city")
.unwrap()
.parameter_type,
ParameterType::String
);
}
}