diff --git a/crates/prompt_gateway/src/hallucination.rs b/crates/common/src/api/hallucination.rs
similarity index 94%
rename from crates/prompt_gateway/src/hallucination.rs
rename to crates/common/src/api/hallucination.rs
index 130f8723..c0efc198 100644
--- a/crates/prompt_gateway/src/hallucination.rs
+++ b/crates/common/src/api/hallucination.rs
@@ -1,7 +1,23 @@
-use common::{
-    common_types::open_ai::Message,
+use std::collections::HashMap;
+
+use crate::{
+    api::open_ai::Message,
     consts::{ARCH_MODEL_PREFIX, HALLUCINATION_TEMPLATE, USER_ROLE},
 };
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct HallucinationClassificationRequest {
+    pub prompt: String,
+    pub parameters: HashMap<String, String>,
+    pub model: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct HallucinationClassificationResponse {
+    pub params_scores: HashMap<String, f64>,
+    pub model: String,
+}
 
 pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
     let mut arch_assistant = false;
@@ -42,7 +58,7 @@ pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
 
 #[cfg(test)]
 mod test {
-    use common::common_types::open_ai::Message;
+    use crate::api::open_ai::Message;
     use pretty_assertions::assert_eq;
 
     use super::extract_messages_for_hallucination;
diff --git a/crates/common/src/api/mod.rs b/crates/common/src/api/mod.rs
new file mode 100644
index 00000000..d9de5c86
--- /dev/null
+++ b/crates/common/src/api/mod.rs
@@ -0,0 +1,4 @@
+pub mod hallucination;
+pub mod open_ai;
+pub mod prompt_guard;
+pub mod zero_shot;
diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs
new file mode 100644
index 00000000..20b550ae
--- /dev/null
+++ b/crates/common/src/api/open_ai.rs
@@ -0,0 +1,653 @@
+use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
+use serde::{ser::SerializeMap, Deserialize, Serialize};
+use serde_yaml::Value;
+use std::{
+    collections::{HashMap, VecDeque},
+    fmt::Display,
+};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionsRequest {
+    #[serde(default)]
+    pub model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<ChatCompletionTool>>,
+    #[serde(default)]
+    pub stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream_options: Option<StreamOptions>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<HashMap<String, String>>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum ToolType {
+    #[serde(rename = "function")]
+    Function,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionTool {
+    #[serde(rename = "type")]
+    pub tool_type: ToolType,
+    pub function: FunctionDefinition,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+    pub name: String,
+    pub description: String,
+    pub parameters: FunctionParameters,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct FunctionParameters {
+    pub properties: HashMap<String, FunctionParameter>,
+}
+
+impl Serialize for FunctionParameters {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        // select all required parameters
+        let required: Vec<&String> = self
+            .properties
+            .iter()
+            .filter(|(_, v)| v.required.unwrap_or(false))
+            .map(|(k, _)| k)
+            .collect();
+        let mut map = serializer.serialize_map(Some(2))?;
+        map.serialize_entry("properties", &self.properties)?;
+        if !required.is_empty() {
+            map.serialize_entry("required", &required)?;
+        }
+        map.end()
+    }
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct FunctionParameter {
+    #[serde(rename = "type")]
+    #[serde(default = "ParameterType::string")]
+    pub parameter_type: ParameterType,
+    pub description: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub required: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(rename = "enum")]
+    pub enum_values: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub default: Option<String>,
+}
+
+impl Serialize for FunctionParameter {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        let mut map = serializer.serialize_map(Some(5))?;
+        map.serialize_entry("type", &self.parameter_type)?;
+        map.serialize_entry("description", &self.description)?;
+        if let Some(enum_values) = &self.enum_values {
+            map.serialize_entry("enum", enum_values)?;
+        }
+        if let Some(default) = &self.default {
+            map.serialize_entry("default", default)?;
+        }
+        map.end()
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub enum ParameterType {
+    #[serde(rename = "int")]
+    Int,
+    #[serde(rename = "float")]
+    Float,
+    #[serde(rename = "bool")]
+    Bool,
+    #[serde(rename = "str")]
+    String,
+    #[serde(rename = "list")]
+    List,
+    #[serde(rename = "dict")]
+    Dict,
+}
+
+impl From<String> for ParameterType {
+    fn from(s: String) -> Self {
+        match s.as_str() {
+            "int" => ParameterType::Int,
+            "integer" => ParameterType::Int,
+            "float" => ParameterType::Float,
+            "bool" => ParameterType::Bool,
+            "boolean" => ParameterType::Bool,
+            "str" => ParameterType::String,
+            "string" => ParameterType::String,
+            "list" => ParameterType::List,
+            "array" => ParameterType::List,
+            "dict" => ParameterType::Dict,
+            "dictionary" => ParameterType::Dict,
+            _ => ParameterType::String,
+        }
+    }
+}
+
+impl ParameterType {
+    pub fn string() -> ParameterType {
+        ParameterType::String
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct StreamOptions {
+    pub include_usage: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Message {
+    pub role: String,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model: Option<String>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCall>>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Choice {
+    pub finish_reason: String,
+    pub index: usize,
+    pub message: Message,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCall {
+    pub id: String,
+    #[serde(rename = "type")]
+    pub tool_type: ToolType,
+    pub function: FunctionCallDetail,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FunctionCallDetail {
+    pub name: String,
+    pub arguments: HashMap<String, Value>,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct ToolCallState {
+    pub key: String,
+    pub message: Option<Message>,
+    pub tool_call: FunctionCallDetail,
+    pub tool_response: String,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+#[serde(untagged)]
+pub enum ArchState {
+    ToolCall(Vec<ToolCallState>),
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionsResponse {
+    pub usage: Option<Usage>,
+    pub choices: Vec<Choice>,
+    pub model: String,
+    pub metadata: Option<HashMap<String, String>>,
+}
+
+impl ChatCompletionsResponse {
+    pub fn new(message: String) -> Self {
+        ChatCompletionsResponse {
+            choices: vec![Choice {
+                message: Message {
+                    role: ASSISTANT_ROLE.to_string(),
+                    content: Some(message),
+                    model: Some(ARCH_FC_MODEL_NAME.to_string()),
+                    tool_calls: None,
+                    tool_call_id: None,
+                },
+                index: 0,
+                finish_reason: "done".to_string(),
+            }],
+            usage: None,
+            model: ARCH_FC_MODEL_NAME.to_string(),
+            metadata: None,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Usage {
+    pub completion_tokens: usize,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatCompletionStreamResponse {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model: Option<String>,
+    pub choices: Vec<ChunkChoice>,
+}
+
+impl ChatCompletionStreamResponse {
+    pub fn new(
+        response: Option<String>,
+        role: Option<String>,
+        model: Option<String>,
+        tool_calls: Option<Vec<ToolCall>>,
+    ) -> Self {
+        ChatCompletionStreamResponse {
+            model,
+            choices: vec![ChunkChoice {
+                delta: Delta {
+                    role,
+                    content: response,
+                    tool_calls,
+                    model: None,
+                    tool_call_id: None,
+                },
+                finish_reason: None,
+            }],
+        }
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ChatCompletionChunkResponseError {
+    #[error("failed to deserialize")]
+    Deserialization(#[from] serde_json::Error),
+    #[error("empty content in data chunk")]
+    EmptyContent,
+    #[error("no chunks present")]
+    NoChunks,
+}
+
+pub struct ChatCompletionStreamResponseServerEvents {
+    pub events: Vec<ChatCompletionStreamResponse>,
+}
+
+impl Display for ChatCompletionStreamResponseServerEvents {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let tokens_str = self
+            .events
+            .iter()
+            .map(|response_chunk| {
+                if response_chunk.choices.is_empty() {
+                    return "".to_string();
+                }
+                response_chunk.choices[0]
+                    .delta
+                    .content
+                    .clone()
+                    .unwrap_or("".to_string())
+            })
+            .collect::<Vec<String>>()
+            .join("");
+
+        write!(f, "{}", tokens_str)
+    }
+}
+
+impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
+    type Error = ChatCompletionChunkResponseError;
+
+    fn try_from(value: &str) -> Result<Self, Self::Error> {
+        let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
+            .lines()
+            .filter(|line| line.starts_with("data: "))
+            .map(|line| line.get(6..).unwrap())
+            .filter(|data_chunk| *data_chunk != "[DONE]")
+            .map(serde_json::from_str::<ChatCompletionStreamResponse>)
+            .collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
+
+        Ok(ChatCompletionStreamResponseServerEvents {
+            events: response_chunks.into(),
+        })
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChunkChoice {
+    pub delta: Delta,
+    // TODO: could this be an enum?
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Delta {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCall>>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model: Option<String>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
+    let mut response_str = String::new();
+    for chunk in chunks.iter() {
+        response_str.push_str("data: ");
+        response_str.push_str(&serde_json::to_string(&chunk).unwrap());
+        response_str.push_str("\n\n");
+    }
+    response_str
+}
+
+#[cfg(test)]
+mod test {
+    use super::{ChatCompletionStreamResponseServerEvents, Message};
+    use pretty_assertions::assert_eq;
+    use std::collections::HashMap;
+
+    const TOOL_SERIALIZED: &str = r#"{
+  "model": "gpt-3.5-turbo",
+  "messages": [
+    {
+      "role": "user",
+      "content": "What city do you want to know the weather for?"
+    }
+  ],
+  "tools": [
+    {
+      "type": "function",
+      "function": {
+        "name": "weather_forecast",
+        "description": "function to retrieve weather forecast",
+        "parameters": {
+          "properties": {
+            "city": {
+              "type": "str",
+              "description": "city for weather forecast",
+              "default": "test"
+            }
+          },
+          "required": [
+            "city"
+          ]
+        }
+      }
+    }
+  ],
+  "stream": true,
+  "stream_options": {
+    "include_usage": true
+  }
+}"#;
+
+    #[test]
+    fn test_tool_type_request() {
+        use super::{
+            ChatCompletionTool, ChatCompletionsRequest, FunctionDefinition, FunctionParameter,
+            FunctionParameters, ParameterType, StreamOptions, ToolType,
+        };
+
+        let mut properties = HashMap::new();
+        properties.insert(
+            "city".to_string(),
+            FunctionParameter {
+                parameter_type: ParameterType::String,
+                description: "city for weather forecast".to_string(),
+                required: Some(true),
+                enum_values: None,
+                default: Some("test".to_string()),
+            },
+        );
+
+        let function_definition = FunctionDefinition {
+            name: "weather_forecast".to_string(),
+            description: "function to retrieve weather forecast".to_string(),
+            parameters: FunctionParameters { properties },
+        };
+
+        let chat_completions_request = ChatCompletionsRequest {
+            model: "gpt-3.5-turbo".to_string(),
+            messages: vec![Message {
+                role: "user".to_string(),
+                content: Some("What city do you want to know the weather for?".to_string()),
+                model: None,
+                tool_calls: None,
+                tool_call_id: None,
+            }],
+            tools: Some(vec![ChatCompletionTool {
+                tool_type: ToolType::Function,
+                function: function_definition,
+            }]),
+            stream: true,
+            stream_options: Some(StreamOptions {
+                include_usage: true,
+            }),
+            metadata: None,
+        };
+
+        let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
+        println!("{}", serialized);
+        assert_eq!(TOOL_SERIALIZED, serialized);
+    }
+
+    #[test]
+    fn test_parameter_types() {
+        use super::{FunctionParameter, ParameterType};
+
+        const PARAMETER_SERIALIZED: &str = r#"{
+  "city": {
+    "type": "str",
+    "description": "city for weather forecast",
+    "default": "test"
+  }
+}"#;
+
+        let properties = HashMap::from([(
+            "city".to_string(),
+            FunctionParameter {
+                parameter_type: ParameterType::String,
+                description: "city for weather forecast".to_string(),
+                required: Some(true),
+                enum_values: None,
+                default: Some("test".to_string()),
+            },
+        )]);
+
+        let serialized = serde_json::to_string_pretty(&properties).unwrap();
+        assert_eq!(PARAMETER_SERIALIZED, serialized);
+
+        // ensure that if type is missing it is set to string
+        const PARAMETER_SERIALIZED_MISSING_TYPE: &str = r#"
+        {
+            "city": {
+                "description": "city for weather forecast"
+            }
+        }"#;
+
+        let missing_type_deserialized: HashMap<String, FunctionParameter> =
+            serde_json::from_str(PARAMETER_SERIALIZED_MISSING_TYPE).unwrap();
+        println!("{:?}", missing_type_deserialized);
+        assert_eq!(
+            missing_type_deserialized
+                .get("city")
+                .unwrap()
+                .parameter_type,
+            ParameterType::String
+        );
+    }
+
+    #[test]
+    fn stream_chunk_parse() {
+        const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
+
+data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
+
{"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]} + + +"#; + + let sever_events = + ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap(); + assert_eq!(sever_events.events.len(), 5); + assert_eq!( + sever_events.events[0].choices[0] + .delta + .content + .as_ref() + .unwrap(), + "" + ); + assert_eq!( + sever_events.events[1].choices[0] + .delta + .content + .as_ref() + .unwrap(), + "Hello" + ); + assert_eq!( + sever_events.events[2].choices[0] + .delta + .content + .as_ref() + .unwrap(), + "!" + ); + assert_eq!( + sever_events.events[3].choices[0] + .delta + .content + .as_ref() + .unwrap(), + " How" + ); + assert_eq!( + sever_events.events[4].choices[0] + .delta + .content + .as_ref() + .unwrap(), + " can" + ); + assert_eq!(sever_events.to_string(), "Hello! How can"); + } + + #[test] + fn stream_chunk_parse_done() { + const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]} + +data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + +data: [DONE] +"#; + + let sever_events: ChatCompletionStreamResponseServerEvents = + ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap(); + assert_eq!(sever_events.events.len(), 6); + assert_eq!( + sever_events.events[0].choices[0] + .delta + .content + .as_ref() + .unwrap(), + " I" + ); + assert_eq!( + sever_events.events[1].choices[0] + .delta + .content + .as_ref() + .unwrap(), + " assist" + ); + assert_eq!( + 
+            server_events.events[2].choices[0]
+                .delta
+                .content
+                .as_ref()
+                .unwrap(),
+            " you"
+        );
+        assert_eq!(
+            server_events.events[3].choices[0]
+                .delta
+                .content
+                .as_ref()
+                .unwrap(),
+            " today"
+        );
+        assert_eq!(
+            server_events.events[4].choices[0]
+                .delta
+                .content
+                .as_ref()
+                .unwrap(),
+            "?"
+        );
+        assert_eq!(server_events.events[5].choices[0].delta.content, None);
+
+        assert_eq!(server_events.to_string(), " I assist you today?");
+    }
+
+    #[test]
+    fn stream_chunk_parse_mistral() {
+        const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
+
+data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
+
+data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
+
+data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
+
+data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
+
+data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
+
+data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
+
+data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
+
+data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
+
+data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
+
+data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
+
+data: [DONE]
+"#;
+
+        let server_events: ChatCompletionStreamResponseServerEvents =
+            ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
+        assert_eq!(server_events.events.len(), 11);
+
+        assert_eq!(
+            server_events.to_string(),
+            "Hello! How can I assist you today?"
+        );
+    }
+}
diff --git a/crates/common/src/api/prompt_guard.rs b/crates/common/src/api/prompt_guard.rs
new file mode 100644
index 00000000..e474e3a9
--- /dev/null
+++ b/crates/common/src/api/prompt_guard.rs
@@ -0,0 +1,25 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum PromptGuardTask {
+    #[serde(rename = "jailbreak")]
+    Jailbreak,
+    #[serde(rename = "toxicity")]
+    Toxicity,
+    #[serde(rename = "both")]
+    Both,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PromptGuardRequest {
+    pub input: String,
+    pub task: PromptGuardTask,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PromptGuardResponse {
+    pub toxic_prob: Option<f64>,
+    pub jailbreak_prob: Option<f64>,
+    pub toxic_verdict: Option<bool>,
+    pub jailbreak_verdict: Option<bool>,
+}
diff --git a/crates/common/src/api/zero_shot.rs b/crates/common/src/api/zero_shot.rs
new file mode 100644
index 00000000..fe08e797
--- /dev/null
+++ b/crates/common/src/api/zero_shot.rs
@@ -0,0 +1,18 @@
+use std::collections::HashMap;
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ZeroShotClassificationRequest {
+    pub input: String,
+    pub labels: Vec<String>,
+    pub model: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ZeroShotClassificationResponse {
+    pub predicted_class: String,
+    pub predicted_class_score: f64,
+    pub scores: HashMap<String, f64>,
+    pub model: String,
+}
diff --git a/crates/common/src/common_types.rs b/crates/common/src/common_types.rs
deleted file mode 100644
index 97b08c06..00000000
--- a/crates/common/src/common_types.rs
+++ /dev/null
@@ -1,743 +0,0 @@
-use crate::configuration::PromptTarget;
-use serde::{Deserialize, Serialize};
-use std::collections::HashMap;
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct EmbeddingRequest {
-    pub prompt_target: PromptTarget,
-}
-
-#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
-pub enum EmbeddingType {
-    Name,
-    Description,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct VectorPoint {
-    pub id: String,
-    pub payload: HashMap<String, String>,
-    pub vector: Vec<f64>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct StoreVectorEmbeddingsRequest {
-    pub points: Vec<VectorPoint>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct SearchPointResult {
-    pub id: String,
-    pub version: i32,
-    pub score: f64,
-    pub payload: HashMap<String, String>,
-}
-
-pub mod open_ai {
-    use std::{
-        collections::{HashMap, VecDeque},
-        fmt::Display,
-    };
-
-    use serde::{ser::SerializeMap, Deserialize, Serialize};
-    use serde_yaml::Value;
-
-    use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct ChatCompletionsRequest {
-        #[serde(default)]
-        pub model: String,
-        pub messages: Vec<Message>,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub tools: Option<Vec<ChatCompletionTool>>,
-        #[serde(default)]
-        pub stream: bool,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub stream_options: Option<StreamOptions>,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub metadata: Option<HashMap<String, String>>,
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub enum ToolType {
-        #[serde(rename = "function")]
-        Function,
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct ChatCompletionTool {
-        #[serde(rename = "type")]
-        pub tool_type: ToolType,
-        pub function: FunctionDefinition,
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct FunctionDefinition {
-        pub name: String,
-        pub description: String,
-        pub parameters: FunctionParameters,
-    }
-
-    #[derive(Debug, Clone, Deserialize)]
-    pub struct FunctionParameters {
-        pub properties: HashMap<String, FunctionParameter>,
-    }
-
-    impl Serialize for FunctionParameters {
-        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-        where
-            S: serde::Serializer,
-        {
-            // select all requried parameters
-            let required: Vec<&String> = self
-                .properties
-                .iter()
-                .filter(|(_, v)| v.required.unwrap_or(false))
-                .map(|(k, _)| k)
-                .collect();
-            let mut map = serializer.serialize_map(Some(2))?;
-            map.serialize_entry("properties", &self.properties)?;
-            if !required.is_empty() {
-                map.serialize_entry("required", &required)?;
-            }
-            map.end()
-        }
-    }
-
-    #[derive(Debug, Clone, Deserialize)]
-    pub struct FunctionParameter {
-        #[serde(rename = "type")]
-        #[serde(default = "ParameterType::string")]
-        pub parameter_type: ParameterType,
-        pub description: String,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub required: Option<bool>,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        #[serde(rename = "enum")]
-        pub enum_values: Option<Vec<String>>,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub default: Option<String>,
-    }
-
-    impl Serialize for FunctionParameter {
-        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-        where
-            S: serde::Serializer,
-        {
-            let mut map = serializer.serialize_map(Some(5))?;
-            map.serialize_entry("type", &self.parameter_type)?;
-            map.serialize_entry("description", &self.description)?;
-            if let Some(enum_values) = &self.enum_values {
-                map.serialize_entry("enum", enum_values)?;
-            }
-            if let Some(default) = &self.default {
-                map.serialize_entry("default", default)?;
-            }
-            map.end()
-        }
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
-    pub enum ParameterType {
-        #[serde(rename = "int")]
-        Int,
-        #[serde(rename = "float")]
-        Float,
-        #[serde(rename = "bool")]
-        Bool,
-        #[serde(rename = "str")]
-        String,
-        #[serde(rename = "list")]
-        List,
-        #[serde(rename = "dict")]
-        Dict,
-    }
-
-    impl From<String> for ParameterType {
-        fn from(s: String) -> Self {
-            match s.as_str() {
-                "int" => ParameterType::Int,
-                "integer" => ParameterType::Int,
-                "float" => ParameterType::Float,
-                "bool" => ParameterType::Bool,
-                "boolean" => ParameterType::Bool,
-                "str" => ParameterType::String,
-                "string" => ParameterType::String,
-                "list" => ParameterType::List,
-                "array" => ParameterType::List,
-                "dict" => ParameterType::Dict,
-                "dictionary" => ParameterType::Dict,
-                _ => ParameterType::String,
-            }
-        }
-    }
-
-    impl ParameterType {
-        pub fn string() -> ParameterType {
-            ParameterType::String
-        }
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct StreamOptions {
-        pub include_usage: bool,
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct Message {
-        pub role: String,
-
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub content: Option<String>,
-
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub model: Option<String>,
-
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub tool_calls: Option<Vec<ToolCall>>,
-
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub tool_call_id: Option<String>,
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct Choice {
-        pub finish_reason: String,
-        pub index: usize,
-        pub message: Message,
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct ToolCall {
-        pub id: String,
-        #[serde(rename = "type")]
-        pub tool_type: ToolType,
-        pub function: FunctionCallDetail,
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct FunctionCallDetail {
-        pub name: String,
-        pub arguments: HashMap<String, Value>,
-    }
-
-    #[derive(Debug, Deserialize, Serialize)]
-    pub struct ToolCallState {
-        pub key: String,
-        pub message: Option<Message>,
-        pub tool_call: FunctionCallDetail,
-        pub tool_response: String,
-    }
-
-    #[derive(Debug, Deserialize, Serialize)]
-    #[serde(untagged)]
-    pub enum ArchState {
-        ToolCall(Vec<ToolCallState>),
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct ChatCompletionsResponse {
-        pub usage: Option<Usage>,
-        pub choices: Vec<Choice>,
-        pub model: String,
-        pub metadata: Option<HashMap<String, String>>,
-    }
-
-    impl ChatCompletionsResponse {
-        pub fn new(message: String) -> Self {
-            ChatCompletionsResponse {
-                choices: vec![Choice {
-                    message: Message {
-                        role: ASSISTANT_ROLE.to_string(),
-                        content: Some(message),
-                        model: Some(ARCH_FC_MODEL_NAME.to_string()),
-                        tool_calls: None,
-                        tool_call_id: None,
-                    },
-                    index: 0,
-                    finish_reason: "done".to_string(),
-                }],
-                usage: None,
-                model: ARCH_FC_MODEL_NAME.to_string(),
-                metadata: None,
-            }
-        }
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct Usage {
-        pub completion_tokens: usize,
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct ChatCompletionStreamResponse {
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub model: Option<String>,
-        pub choices: Vec<ChunkChoice>,
-    }
-
-    impl ChatCompletionStreamResponse {
-        pub fn new(
-            response: Option<String>,
-            role: Option<String>,
-            model: Option<String>,
-            tool_calls: Option<Vec<ToolCall>>,
-        ) -> Self {
-            ChatCompletionStreamResponse {
-                model,
-                choices: vec![ChunkChoice {
-                    delta: Delta {
-                        role,
-                        content: response,
-                        tool_calls,
-                        model: None,
-                        tool_call_id: None,
-                    },
-                    finish_reason: None,
-                }],
-            }
-        }
-    }
-
-    #[derive(Debug, thiserror::Error)]
-    pub enum ChatCompletionChunkResponseError {
-        #[error("failed to deserialize")]
-        Deserialization(#[from] serde_json::Error),
-        #[error("empty content in data chunk")]
-        EmptyContent,
-        #[error("no chunks present")]
-        NoChunks,
-    }
-
-    pub struct ChatCompletionStreamResponseServerEvents {
-        pub events: Vec<ChatCompletionStreamResponse>,
-    }
-
-    impl Display for ChatCompletionStreamResponseServerEvents {
-        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-            let tokens_str = self
-                .events
-                .iter()
-                .map(|response_chunk| {
-                    if response_chunk.choices.is_empty() {
-                        return "".to_string();
-                    }
-                    response_chunk.choices[0]
-                        .delta
-                        .content
-                        .clone()
-                        .unwrap_or("".to_string())
-                })
-                .collect::<Vec<String>>()
-                .join("");
-
-            write!(f, "{}", tokens_str)
-        }
-    }
-
-    impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
-        type Error = ChatCompletionChunkResponseError;
-
-        fn try_from(value: &str) -> Result<Self, Self::Error> {
-            let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
-                .lines()
-                .filter(|line| line.starts_with("data: "))
-                .map(|line| line.get(6..).unwrap())
-                .filter(|data_chunk| *data_chunk != "[DONE]")
-                .map(serde_json::from_str::<ChatCompletionStreamResponse>)
-                .collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
-
-            Ok(ChatCompletionStreamResponseServerEvents {
-                events: response_chunks.into(),
-            })
-        }
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct ChunkChoice {
-        pub delta: Delta,
-        // TODO: could this be an enum?
-        pub finish_reason: Option<String>,
-    }
-
-    #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct Delta {
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub role: Option<String>,
-
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub content: Option<String>,
-
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub tool_calls: Option<Vec<ToolCall>>,
-
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub model: Option<String>,
-
-        #[serde(skip_serializing_if = "Option::is_none")]
-        pub tool_call_id: Option<String>,
-    }
-
-    pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
-        let mut response_str = String::new();
-        for chunk in chunks.iter() {
-            response_str.push_str("data: ");
-            response_str.push_str(&serde_json::to_string(&chunk).unwrap());
-            response_str.push_str("\n\n");
-        }
-        response_str
-    }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ZeroShotClassificationRequest {
-    pub input: String,
-    pub labels: Vec<String>,
-    pub model: String,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ZeroShotClassificationResponse {
-    pub predicted_class: String,
-    pub predicted_class_score: f64,
-    pub scores: HashMap<String, f64>,
-    pub model: String,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct HallucinationClassificationRequest {
-    pub prompt: String,
-    pub parameters: HashMap<String, String>,
-    pub model: String,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct HallucinationClassificationResponse {
-    pub params_scores: HashMap<String, f64>,
-    pub model: String,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub enum PromptGuardTask {
-    #[serde(rename = "jailbreak")]
-    Jailbreak,
-    #[serde(rename = "toxicity")]
-    Toxicity,
-    #[serde(rename = "both")]
-    Both,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PromptGuardRequest {
-    pub input: String,
-    pub task: PromptGuardTask,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PromptGuardResponse {
-    pub toxic_prob: Option<f64>,
-    pub jailbreak_prob: Option<f64>,
-    pub toxic_verdict: Option<bool>,
-    pub jailbreak_verdict: Option<bool>,
-}
-
-#[cfg(test)]
-mod test {
-    use crate::common_types::open_ai::{ChatCompletionStreamResponseServerEvents, Message};
-    use pretty_assertions::assert_eq;
-    use std::collections::HashMap;
-
-    const TOOL_SERIALIZED: &str = r#"{
-  "model": "gpt-3.5-turbo",
-  "messages": [
-    {
-      "role": "user",
-      "content": "What city do you want to know the weather for?"
-    }
-  ],
-  "tools": [
-    {
-      "type": "function",
-      "function": {
-        "name": "weather_forecast",
-        "description": "function to retrieve weather forecast",
-        "parameters": {
-          "properties": {
-            "city": {
-              "type": "str",
-              "description": "city for weather forecast",
-              "default": "test"
-            }
-          },
-          "required": [
-            "city"
-          ]
-        }
-      }
-    }
-  ],
-  "stream": true,
-  "stream_options": {
-    "include_usage": true
-  }
-}"#;
-
-    #[test]
-    fn test_tool_type_request() {
-        use super::open_ai::{
-            ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
-        };
-
-        let mut properties = HashMap::new();
-        properties.insert(
-            "city".to_string(),
-            FunctionParameter {
-                parameter_type: ParameterType::String,
-                description: "city for weather forecast".to_string(),
-                required: Some(true),
-                enum_values: None,
-                default: Some("test".to_string()),
-            },
-        );
-
-        let function_definition = FunctionDefinition {
-            name: "weather_forecast".to_string(),
-            description: "function to retrieve weather forecast".to_string(),
-            parameters: super::open_ai::FunctionParameters { properties },
-        };
-
-        let chat_completions_request = ChatCompletionsRequest {
-            model: "gpt-3.5-turbo".to_string(),
-            messages: vec![Message {
-                role: "user".to_string(),
-                content: Some("What city do you want to know the weather for?".to_string()),
-                model: None,
-                tool_calls: None,
-                tool_call_id: None,
-            }],
-            tools: Some(vec![super::open_ai::ChatCompletionTool {
-                tool_type: ToolType::Function,
-                function: function_definition,
-            }]),
-            stream: true,
-            stream_options: Some(super::open_ai::StreamOptions {
-                include_usage: true,
-            }),
-            metadata: None,
-        };
-
-        let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
-        println!("{}", serialized);
-        assert_eq!(TOOL_SERIALIZED, serialized);
-    }
-
-    #[test]
-    fn test_parameter_types() {
-        use super::open_ai::{FunctionParameter, ParameterType};
-
-        const PARAMETER_SERIALZIED: &str = r#"{
-  "city": {
-    "type": "str",
-    "description": "city for weather forecast",
-    "default": "test"
-  }
-}"#;
-
-        let properties = HashMap::from([(
-            "city".to_string(),
-            FunctionParameter {
-                parameter_type: ParameterType::String,
-                description: "city for weather forecast".to_string(),
-                required: Some(true),
-                enum_values: None,
-                default: Some("test".to_string()),
-            },
-        )]);
-
-        let serialized = serde_json::to_string_pretty(&properties).unwrap();
-        assert_eq!(PARAMETER_SERIALZIED, serialized);
-
-        // ensure that if type is missing it is set to string
-        const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
-        {
-            "city": {
-                "description": "city for weather forecast"
-            }
-        }"#;
-
-        let missing_type_deserialized: HashMap<String, FunctionParameter> =
-            serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
-        println!("{:?}", missing_type_deserialized);
-        assert_eq!(
-            missing_type_deserialized
-                .get("city")
-                .unwrap()
-                .parameter_type,
-            ParameterType::String
-        );
-    }
-
-    #[test]
-    fn stream_chunk_parse() {
-        const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
-
-data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
-
{"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]} - -data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]} - -data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]} - - -"#; - - let sever_events = - ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap(); - assert_eq!(sever_events.events.len(), 5); - assert_eq!( - sever_events.events[0].choices[0] - .delta - .content - .as_ref() - .unwrap(), - "" - ); - assert_eq!( - sever_events.events[1].choices[0] - .delta - .content - .as_ref() - .unwrap(), - "Hello" - ); - assert_eq!( - sever_events.events[2].choices[0] - .delta - .content - .as_ref() - .unwrap(), - "!" - ); - assert_eq!( - sever_events.events[3].choices[0] - .delta - .content - .as_ref() - .unwrap(), - " How" - ); - assert_eq!( - sever_events.events[4].choices[0] - .delta - .content - .as_ref() - .unwrap(), - " can" - ); - assert_eq!(sever_events.to_string(), "Hello! How can"); - } - - #[test] - fn stream_chunk_parse_done() { - const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]} - -data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]} - -data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]} - -data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]} - -data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]} - -data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} - -data: [DONE] -"#; - - let sever_events: ChatCompletionStreamResponseServerEvents = - ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap(); - assert_eq!(sever_events.events.len(), 6); - assert_eq!( - sever_events.events[0].choices[0] - .delta - .content - .as_ref() - .unwrap(), - " I" - ); - assert_eq!( - sever_events.events[1].choices[0] - .delta - .content - .as_ref() - .unwrap(), - " assist" - ); - assert_eq!( - 
-            sever_events.events[2].choices[0]
-                .delta
-                .content
-                .as_ref()
-                .unwrap(),
-            " you"
-        );
-        assert_eq!(
-            sever_events.events[3].choices[0]
-                .delta
-                .content
-                .as_ref()
-                .unwrap(),
-            " today"
-        );
-        assert_eq!(
-            sever_events.events[4].choices[0]
-                .delta
-                .content
-                .as_ref()
-                .unwrap(),
-            "?"
-        );
-        assert_eq!(sever_events.events[5].choices[0].delta.content, None);
-
-        assert_eq!(sever_events.to_string(), " I assist you today?");
-    }
-
-    #[test]
-    fn stream_chunk_parse_mistral() {
-        const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
-
-data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
-
-data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
-
-data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
-
-data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
-
-data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
-
-data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
-
-data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
-
-data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
-
-data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
-
-data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
-
-data: [DONE]
-"#;
-
-        let sever_events: ChatCompletionStreamResponseServerEvents =
-            ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
-        assert_eq!(sever_events.events.len(), 11);
-
-        assert_eq!(
-            sever_events.to_string(),
-            "Hello! How can I assist you today?"
-        );
-    }
-}
diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs
index 0c998660..0d9bea80 100644
--- a/crates/common/src/configuration.rs
+++ b/crates/common/src/configuration.rs
@@ -2,6 +2,22 @@ use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::fmt::Display;
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Configuration {
+    pub version: String,
+    pub listener: Listener,
+    pub endpoints: Option<HashMap<String, Endpoint>>,
+    pub llm_providers: Vec<LlmProvider>,
+    pub overrides: Option<Overrides>,
+    pub system_prompt: Option<String>,
+    pub prompt_guards: Option<PromptGuards>,
+    pub prompt_targets: Option<Vec<PromptTarget>>,
+    pub error_target: Option<ErrorTargetDetail>,
+    pub ratelimits: Option<Vec<Ratelimit>>,
+    pub tracing: Option<Tracing>,
+    pub mode: Option<GatewayMode>,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub struct Overrides {
     pub prompt_target_intent_matching_threshold: Option<f64>,
@@ -22,22 +38,6 @@ pub enum GatewayMode {
     Prompt,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Configuration {
-    pub version: String,
-    pub listener: Listener,
-    pub endpoints: Option<HashMap<String, Endpoint>>,
-    pub llm_providers: Vec<LlmProvider>,
-    pub overrides: Option<Overrides>,
-    pub system_prompt: Option<String>,
-    pub prompt_guards: Option<PromptGuards>,
-    pub prompt_targets: Option<Vec<PromptTarget>>,
-    pub error_target: Option<ErrorTargetDetail>,
-    pub ratelimits: Option<Vec<Ratelimit>>,
-    pub tracing: Option<Tracing>,
-    pub mode: Option<GatewayMode>,
-}
-
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ErrorTargetDetail {
     pub endpoint: Option<EndpointDetails>,
diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs
index 4dc7d647..17f19ebb 100644
--- a/crates/common/src/errors.rs
+++ b/crates/common/src/errors.rs
@@ -1,6 +1,6 @@
 use proxy_wasm::types::Status;
 
-use crate::{common_types::open_ai::ChatCompletionChunkResponseError, ratelimit};
+use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit};
 
 #[derive(thiserror::Error, Debug)]
 pub enum ClientError {
diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs
index aa34f2fd..cd5238a3 100644
--- a/crates/common/src/lib.rs
+++ b/crates/common/src/lib.rs
@@ -1,4 +1,4 @@
-pub mod common_types;
+pub mod api;
 pub mod configuration;
 pub mod consts;
 pub mod embeddings;
diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs
index 9bb81f11..5d8669a2 100644
--- a/crates/llm_gateway/src/stream_context.rs
+++ b/crates/llm_gateway/src/stream_context.rs
@@ -1,5 +1,5 @@
 use crate::filter_context::WasmMetrics;
-use common::common_types::open_ai::{
+use common::api::open_ai::{
     ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
     Message, StreamOptions,
 };
diff --git a/crates/prompt_gateway/src/embeddings.rs b/crates/prompt_gateway/src/embeddings.rs
new file mode 100644
index 00000000..f2883682
--- /dev/null
+++ b/crates/prompt_gateway/src/embeddings.rs
@@ -0,0 +1,5 @@
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
+pub enum EmbeddingType {
+    Name,
+    Description,
+}
diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs
index eb4085cf..966765fe 100644
--- a/crates/prompt_gateway/src/filter_context.rs
+++ b/crates/prompt_gateway/src/filter_context.rs
@@ -1,5 +1,5 @@
+use crate::embeddings::EmbeddingType;
 use crate::stream_context::StreamContext;
-use common::common_types::EmbeddingType;
 use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
 use common::consts::ARCH_UPSTREAM_HOST_HEADER;
 use common::consts::DEFAULT_EMBEDDING_MODEL;
diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs
index b3ce7d8f..7508f852 100644
--- a/crates/prompt_gateway/src/http_context.rs
+++ b/crates/prompt_gateway/src/http_context.rs
@@ -1,10 +1,8 @@
 use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
 use common::{
-    common_types::{
-        open_ai::{
-            to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
-        },
-        PromptGuardRequest, PromptGuardTask,
+    api::{
+        open_ai::{self, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest},
+        prompt_guard::{PromptGuardRequest, PromptGuardTask},
     },
     consts::{
         ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
@@ -324,7 +322,7 @@ impl HttpContext for StreamContext {
             ),
         ];
 
-        let mut response_str = to_server_events(chunks);
+        let mut response_str = open_ai::to_server_events(chunks);
         // append the original response from the model to the stream
         response_str.push_str(&body_utf8);
         self.set_http_response_body(0, body_size, response_str.as_bytes());
diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs
index f873b9bf..8fe85f75 100644
--- a/crates/prompt_gateway/src/lib.rs
+++ b/crates/prompt_gateway/src/lib.rs
@@ -3,8 +3,8 @@ use proxy_wasm::traits::*;
 use proxy_wasm::types::*;
 
 mod context;
+mod embeddings;
 mod filter_context;
-mod hallucination;
 mod http_context;
 mod stream_context;
 
diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs
index 96d82cc7..3a4d2733 100644
--- a/crates/prompt_gateway/src/stream_context.rs
+++ b/crates/prompt_gateway/src/stream_context.rs
@@ -1,15 +1,17 @@
+use crate::embeddings::EmbeddingType;
 use crate::filter_context::{EmbeddingsStore, WasmMetrics};
-use crate::hallucination::extract_messages_for_hallucination;
 use acap::cos;
-use common::common_types::open_ai::{
+use common::api::hallucination::{
+    extract_messages_for_hallucination, HallucinationClassificationRequest,
+    HallucinationClassificationResponse,
+};
+use common::api::open_ai::{
     to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool,
     ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
     FunctionParameters, Message, ParameterType, ToolCall, ToolType,
 };
-use common::common_types::{
-    EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
-    PromptGuardResponse, ZeroShotClassificationRequest, ZeroShotClassificationResponse,
-};
+use common::api::prompt_guard::PromptGuardResponse;
+use common::api::zero_shot::{ZeroShotClassificationRequest, ZeroShotClassificationResponse};
 use common::configuration::{Overrides, PromptGuards, PromptTarget, Tracing};
 use common::consts::{
     ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs
index 57678a06..ac6009f8 100644
--- a/crates/prompt_gateway/tests/integration.rs
+++ b/crates/prompt_gateway/tests/integration.rs
@@ -1,11 +1,14 @@
-use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
-use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
-use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
+use common::api::hallucination::HallucinationClassificationResponse;
+use common::api::open_ai::{
+    ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
+};
+use common::api::prompt_guard::PromptGuardResponse;
+use common::api::zero_shot::ZeroShotClassificationResponse;
+use common::configuration::Configuration;
 use common::embeddings::{
     create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
     Embedding,
 };
-use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
 use http::StatusCode;
 use proxy_wasm_test_framework::tester::{self, Tester};
 use proxy_wasm_test_framework::types::{