mirror of
https://github.com/katanemo/plano.git
synced 2026-06-20 15:28:07 +02:00
Merge branch 'main' into adil/add_acm_demo
This commit is contained in:
commit
68097fde07
166 changed files with 9507 additions and 11803 deletions
|
|
@ -1,7 +1,23 @@
|
|||
use common::{
|
||||
common_types::open_ai::Message,
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::{
|
||||
api::open_ai::Message,
|
||||
consts::{ARCH_MODEL_PREFIX, HALLUCINATION_TEMPLATE, USER_ROLE},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationRequest {
|
||||
pub prompt: String,
|
||||
pub parameters: HashMap<String, String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationResponse {
|
||||
pub params_scores: HashMap<String, f64>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
|
||||
let mut arch_assistant = false;
|
||||
|
|
@ -42,7 +58,7 @@ pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use common::common_types::open_ai::Message;
|
||||
use crate::api::open_ai::Message;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use super::extract_messages_for_hallucination;
|
||||
4
crates/common/src/api/mod.rs
Normal file
4
crates/common/src/api/mod.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pub mod hallucination;
|
||||
pub mod open_ai;
|
||||
pub mod prompt_guard;
|
||||
pub mod zero_shot;
|
||||
672
crates/common/src/api/open_ai.rs
Normal file
672
crates/common/src/api/open_ai.rs
Normal file
|
|
@ -0,0 +1,672 @@
|
|||
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt::Display,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<ChatCompletionTool>>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum ToolType {
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionTool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameters {
|
||||
pub properties: HashMap<String, FunctionParameter>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameters {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
// select all requried parameters
|
||||
let required: Vec<&String> = self
|
||||
.properties
|
||||
.iter()
|
||||
.filter(|(_, v)| v.required.unwrap_or(false))
|
||||
.map(|(k, _)| k)
|
||||
.collect();
|
||||
let mut map = serializer.serialize_map(Some(2))?;
|
||||
map.serialize_entry("properties", &self.properties)?;
|
||||
if !required.is_empty() {
|
||||
map.serialize_entry("required", &required)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameter {
|
||||
#[serde(rename = "type")]
|
||||
#[serde(default = "ParameterType::string")]
|
||||
pub parameter_type: ParameterType,
|
||||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "enum")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub format: Option<String>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameter {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_map(Some(5))?;
|
||||
map.serialize_entry("type", &self.parameter_type)?;
|
||||
map.serialize_entry("description", &self.description)?;
|
||||
if let Some(enum_values) = &self.enum_values {
|
||||
map.serialize_entry("enum", enum_values)?;
|
||||
}
|
||||
if let Some(default) = &self.default {
|
||||
map.serialize_entry("default", default)?;
|
||||
}
|
||||
if let Some(format) = &self.format {
|
||||
map.serialize_entry("format", format)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum ParameterType {
|
||||
#[serde(rename = "int")]
|
||||
Int,
|
||||
#[serde(rename = "float")]
|
||||
Float,
|
||||
#[serde(rename = "bool")]
|
||||
Bool,
|
||||
#[serde(rename = "str")]
|
||||
String,
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
#[serde(rename = "dict")]
|
||||
Dict,
|
||||
}
|
||||
|
||||
impl From<String> for ParameterType {
|
||||
fn from(s: String) -> Self {
|
||||
match s.as_str() {
|
||||
"int" => ParameterType::Int,
|
||||
"integer" => ParameterType::Int,
|
||||
"float" => ParameterType::Float,
|
||||
"bool" => ParameterType::Bool,
|
||||
"boolean" => ParameterType::Bool,
|
||||
"str" => ParameterType::String,
|
||||
"string" => ParameterType::String,
|
||||
"list" => ParameterType::List,
|
||||
"array" => ParameterType::List,
|
||||
"dict" => ParameterType::Dict,
|
||||
"dictionary" => ParameterType::Dict,
|
||||
_ => ParameterType::String,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ParameterType {
|
||||
pub fn string() -> ParameterType {
|
||||
ParameterType::String
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamOptions {
|
||||
pub include_usage: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub finish_reason: Option<String>,
|
||||
pub index: Option<usize>,
|
||||
pub message: Message,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionCallDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCallDetail {
|
||||
pub name: String,
|
||||
pub arguments: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ToolCallState {
|
||||
pub key: String,
|
||||
pub message: Option<Message>,
|
||||
pub tool_call: FunctionCallDetail,
|
||||
pub tool_response: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ArchState {
|
||||
ToolCall(Vec<ToolCallState>),
|
||||
}
|
||||
#[derive(Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ModelServerResponse {
|
||||
ChatCompletionsResponse(ChatCompletionsResponse),
|
||||
ModelServerErrorResponse(ModelServerErrorResponse),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelServerErrorResponse {
|
||||
pub result: String,
|
||||
pub intent_latency: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub usage: Option<Usage>,
|
||||
pub choices: Vec<Choice>,
|
||||
pub model: String,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl ChatCompletionsResponse {
|
||||
pub fn new(message: String) -> Self {
|
||||
ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: Some(message),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
index: Some(0),
|
||||
finish_reason: Some("done".to_string()),
|
||||
}],
|
||||
usage: None,
|
||||
model: ARCH_FC_MODEL_NAME.to_string(),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub completion_tokens: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionStreamResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
impl ChatCompletionStreamResponse {
|
||||
pub fn new(
|
||||
response: Option<String>,
|
||||
role: Option<String>,
|
||||
model: Option<String>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
) -> Self {
|
||||
ChatCompletionStreamResponse {
|
||||
model,
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role,
|
||||
content: response,
|
||||
tool_calls,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChatCompletionChunkResponseError {
|
||||
#[error("failed to deserialize")]
|
||||
Deserialization(#[from] serde_json::Error),
|
||||
#[error("empty content in data chunk")]
|
||||
EmptyContent,
|
||||
#[error("no chunks present")]
|
||||
NoChunks,
|
||||
}
|
||||
|
||||
pub struct ChatCompletionStreamResponseServerEvents {
|
||||
pub events: Vec<ChatCompletionStreamResponse>,
|
||||
}
|
||||
|
||||
impl Display for ChatCompletionStreamResponseServerEvents {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let tokens_str = self
|
||||
.events
|
||||
.iter()
|
||||
.map(|response_chunk| {
|
||||
if response_chunk.choices.is_empty() {
|
||||
return "".to_string();
|
||||
}
|
||||
response_chunk.choices[0]
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or("".to_string())
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("");
|
||||
|
||||
write!(f, "{}", tokens_str)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
|
||||
type Error = ChatCompletionChunkResponseError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
|
||||
.lines()
|
||||
.filter(|line| line.starts_with("data: "))
|
||||
.map(|line| line.get(6..).unwrap())
|
||||
.filter(|data_chunk| *data_chunk != "[DONE]")
|
||||
.map(serde_json::from_str::<ChatCompletionStreamResponse>)
|
||||
.collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
|
||||
|
||||
Ok(ChatCompletionStreamResponseServerEvents {
|
||||
events: response_chunks.into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkChoice {
|
||||
pub delta: Delta,
|
||||
// TODO: could this be an enum?
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Delta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
|
||||
let mut response_str = String::new();
|
||||
for chunk in chunks.iter() {
|
||||
response_str.push_str("data: ");
|
||||
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
|
||||
response_str.push_str("\n\n");
|
||||
}
|
||||
response_str
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::{ChatCompletionStreamResponseServerEvents, Message};
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
|
||||
const TOOL_SERIALIZED: &str = r#"{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What city do you want to know the weather for?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"description": "function to retrieve weather forecast",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "str",
|
||||
"description": "city for weather forecast",
|
||||
"default": "test"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"city"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"stream": true,
|
||||
"stream_options": {
|
||||
"include_usage": true
|
||||
}
|
||||
}"#;
|
||||
|
||||
#[test]
|
||||
fn test_tool_type_request() {
|
||||
use super::{
|
||||
ChatCompletionTool, ChatCompletionsRequest, FunctionDefinition, FunctionParameter,
|
||||
FunctionParameters, ParameterType, StreamOptions, ToolType,
|
||||
};
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert(
|
||||
"city".to_string(),
|
||||
FunctionParameter {
|
||||
parameter_type: ParameterType::String,
|
||||
description: "city for weather forecast".to_string(),
|
||||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
format: None,
|
||||
},
|
||||
);
|
||||
|
||||
let function_definition = FunctionDefinition {
|
||||
name: "weather_forecast".to_string(),
|
||||
description: "function to retrieve weather forecast".to_string(),
|
||||
parameters: FunctionParameters { properties },
|
||||
};
|
||||
|
||||
let chat_completions_request = ChatCompletionsRequest {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What city do you want to know the weather for?".to_string()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
tools: Some(vec![ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: function_definition,
|
||||
}]),
|
||||
stream: true,
|
||||
stream_options: Some(StreamOptions {
|
||||
include_usage: true,
|
||||
}),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
|
||||
println!("{}", serialized);
|
||||
assert_eq!(TOOL_SERIALIZED, serialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameter_types() {
|
||||
use super::{FunctionParameter, ParameterType};
|
||||
|
||||
const PARAMETER_SERIALZIED: &str = r#"{
|
||||
"city": {
|
||||
"type": "str",
|
||||
"description": "city for weather forecast",
|
||||
"default": "test"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let properties = HashMap::from([(
|
||||
"city".to_string(),
|
||||
FunctionParameter {
|
||||
parameter_type: ParameterType::String,
|
||||
description: "city for weather forecast".to_string(),
|
||||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
format: None,
|
||||
},
|
||||
)]);
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&properties).unwrap();
|
||||
assert_eq!(PARAMETER_SERIALZIED, serialized);
|
||||
|
||||
// ensure that if type is missing it is set to string
|
||||
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
|
||||
{
|
||||
"city": {
|
||||
"description": "city for weather forecast"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let missing_type_deserialized: HashMap<String, FunctionParameter> =
|
||||
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
|
||||
println!("{:?}", missing_type_deserialized);
|
||||
assert_eq!(
|
||||
missing_type_deserialized
|
||||
.get("city")
|
||||
.unwrap()
|
||||
.parameter_type,
|
||||
ParameterType::String
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
"#;
|
||||
|
||||
let sever_events =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 5);
|
||||
assert_eq!(
|
||||
sever_events.events[0].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
""
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[1].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"Hello"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[2].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"!"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[3].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" How"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[4].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" can"
|
||||
);
|
||||
assert_eq!(sever_events.to_string(), "Hello! How can");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_done() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let sever_events: ChatCompletionStreamResponseServerEvents =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 6);
|
||||
assert_eq!(
|
||||
sever_events.events[0].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" I"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[1].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" assist"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[2].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" you"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[3].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" today"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[4].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"?"
|
||||
);
|
||||
assert_eq!(sever_events.events[5].choices[0].delta.content, None);
|
||||
|
||||
assert_eq!(sever_events.to_string(), " I assist you today?");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_mistral() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let sever_events: ChatCompletionStreamResponseServerEvents =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 11);
|
||||
|
||||
assert_eq!(
|
||||
sever_events.to_string(),
|
||||
"Hello! How can I assist you today?"
|
||||
);
|
||||
}
|
||||
}
|
||||
25
crates/common/src/api/prompt_guard.rs
Normal file
25
crates/common/src/api/prompt_guard.rs
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum PromptGuardTask {
|
||||
#[serde(rename = "jailbreak")]
|
||||
Jailbreak,
|
||||
#[serde(rename = "toxicity")]
|
||||
Toxicity,
|
||||
#[serde(rename = "both")]
|
||||
Both,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptGuardRequest {
|
||||
pub input: String,
|
||||
pub task: PromptGuardTask,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptGuardResponse {
|
||||
pub toxic_prob: Option<f64>,
|
||||
pub jailbreak_prob: Option<f64>,
|
||||
pub toxic_verdict: Option<bool>,
|
||||
pub jailbreak_verdict: Option<bool>,
|
||||
}
|
||||
18
crates/common/src/api/zero_shot.rs
Normal file
18
crates/common/src/api/zero_shot.rs
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationRequest {
|
||||
pub input: String,
|
||||
pub labels: Vec<String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationResponse {
|
||||
pub predicted_class: String,
|
||||
pub predicted_class_score: f64,
|
||||
pub scores: HashMap<String, f64>,
|
||||
pub model: String,
|
||||
}
|
||||
|
|
@ -1,743 +0,0 @@
|
|||
use crate::configuration::PromptTarget;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
pub prompt_target: PromptTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
pub enum EmbeddingType {
|
||||
Name,
|
||||
Description,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorPoint {
|
||||
pub id: String,
|
||||
pub payload: HashMap<String, String>,
|
||||
pub vector: Vec<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StoreVectorEmbeddingsRequest {
|
||||
pub points: Vec<VectorPoint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchPointResult {
|
||||
pub id: String,
|
||||
pub version: i32,
|
||||
pub score: f64,
|
||||
pub payload: HashMap<String, String>,
|
||||
}
|
||||
|
||||
pub mod open_ai {
|
||||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt::Display,
|
||||
};
|
||||
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
|
||||
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<ChatCompletionTool>>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ToolType {
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionTool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameters {
|
||||
pub properties: HashMap<String, FunctionParameter>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameters {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
// select all requried parameters
|
||||
let required: Vec<&String> = self
|
||||
.properties
|
||||
.iter()
|
||||
.filter(|(_, v)| v.required.unwrap_or(false))
|
||||
.map(|(k, _)| k)
|
||||
.collect();
|
||||
let mut map = serializer.serialize_map(Some(2))?;
|
||||
map.serialize_entry("properties", &self.properties)?;
|
||||
if !required.is_empty() {
|
||||
map.serialize_entry("required", &required)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct FunctionParameter {
|
||||
#[serde(rename = "type")]
|
||||
#[serde(default = "ParameterType::string")]
|
||||
pub parameter_type: ParameterType,
|
||||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "enum")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<String>,
|
||||
}
|
||||
|
||||
impl Serialize for FunctionParameter {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_map(Some(5))?;
|
||||
map.serialize_entry("type", &self.parameter_type)?;
|
||||
map.serialize_entry("description", &self.description)?;
|
||||
if let Some(enum_values) = &self.enum_values {
|
||||
map.serialize_entry("enum", enum_values)?;
|
||||
}
|
||||
if let Some(default) = &self.default {
|
||||
map.serialize_entry("default", default)?;
|
||||
}
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum ParameterType {
|
||||
#[serde(rename = "int")]
|
||||
Int,
|
||||
#[serde(rename = "float")]
|
||||
Float,
|
||||
#[serde(rename = "bool")]
|
||||
Bool,
|
||||
#[serde(rename = "str")]
|
||||
String,
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
#[serde(rename = "dict")]
|
||||
Dict,
|
||||
}
|
||||
|
||||
impl From<String> for ParameterType {
|
||||
fn from(s: String) -> Self {
|
||||
match s.as_str() {
|
||||
"int" => ParameterType::Int,
|
||||
"integer" => ParameterType::Int,
|
||||
"float" => ParameterType::Float,
|
||||
"bool" => ParameterType::Bool,
|
||||
"boolean" => ParameterType::Bool,
|
||||
"str" => ParameterType::String,
|
||||
"string" => ParameterType::String,
|
||||
"list" => ParameterType::List,
|
||||
"array" => ParameterType::List,
|
||||
"dict" => ParameterType::Dict,
|
||||
"dictionary" => ParameterType::Dict,
|
||||
_ => ParameterType::String,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ParameterType {
|
||||
pub fn string() -> ParameterType {
|
||||
ParameterType::String
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamOptions {
|
||||
pub include_usage: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub finish_reason: String,
|
||||
pub index: usize,
|
||||
pub message: Message,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionCallDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCallDetail {
|
||||
pub name: String,
|
||||
pub arguments: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ToolCallState {
|
||||
pub key: String,
|
||||
pub message: Option<Message>,
|
||||
pub tool_call: FunctionCallDetail,
|
||||
pub tool_response: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ArchState {
|
||||
ToolCall(Vec<ToolCallState>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub usage: Option<Usage>,
|
||||
pub choices: Vec<Choice>,
|
||||
pub model: String,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl ChatCompletionsResponse {
|
||||
pub fn new(message: String) -> Self {
|
||||
ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message: Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: Some(message),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
index: 0,
|
||||
finish_reason: "done".to_string(),
|
||||
}],
|
||||
usage: None,
|
||||
model: ARCH_FC_MODEL_NAME.to_string(),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Usage {
|
||||
pub completion_tokens: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionStreamResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
impl ChatCompletionStreamResponse {
|
||||
pub fn new(
|
||||
response: Option<String>,
|
||||
role: Option<String>,
|
||||
model: Option<String>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
) -> Self {
|
||||
ChatCompletionStreamResponse {
|
||||
model,
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role,
|
||||
content: response,
|
||||
tool_calls,
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ChatCompletionChunkResponseError {
|
||||
#[error("failed to deserialize")]
|
||||
Deserialization(#[from] serde_json::Error),
|
||||
#[error("empty content in data chunk")]
|
||||
EmptyContent,
|
||||
#[error("no chunks present")]
|
||||
NoChunks,
|
||||
}
|
||||
|
||||
pub struct ChatCompletionStreamResponseServerEvents {
|
||||
pub events: Vec<ChatCompletionStreamResponse>,
|
||||
}
|
||||
|
||||
impl Display for ChatCompletionStreamResponseServerEvents {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let tokens_str = self
|
||||
.events
|
||||
.iter()
|
||||
.map(|response_chunk| {
|
||||
if response_chunk.choices.is_empty() {
|
||||
return "".to_string();
|
||||
}
|
||||
response_chunk.choices[0]
|
||||
.delta
|
||||
.content
|
||||
.clone()
|
||||
.unwrap_or("".to_string())
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("");
|
||||
|
||||
write!(f, "{}", tokens_str)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
|
||||
type Error = ChatCompletionChunkResponseError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
|
||||
.lines()
|
||||
.filter(|line| line.starts_with("data: "))
|
||||
.map(|line| line.get(6..).unwrap())
|
||||
.filter(|data_chunk| *data_chunk != "[DONE]")
|
||||
.map(serde_json::from_str::<ChatCompletionStreamResponse>)
|
||||
.collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
|
||||
|
||||
Ok(ChatCompletionStreamResponseServerEvents {
|
||||
events: response_chunks.into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkChoice {
|
||||
pub delta: Delta,
|
||||
// TODO: could this be an enum?
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Delta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
|
||||
let mut response_str = String::new();
|
||||
for chunk in chunks.iter() {
|
||||
response_str.push_str("data: ");
|
||||
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
|
||||
response_str.push_str("\n\n");
|
||||
}
|
||||
response_str
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationRequest {
|
||||
pub input: String,
|
||||
pub labels: Vec<String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ZeroShotClassificationResponse {
|
||||
pub predicted_class: String,
|
||||
pub predicted_class_score: f64,
|
||||
pub scores: HashMap<String, f64>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationRequest {
|
||||
pub prompt: String,
|
||||
pub parameters: HashMap<String, String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationResponse {
|
||||
pub params_scores: HashMap<String, f64>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum PromptGuardTask {
|
||||
#[serde(rename = "jailbreak")]
|
||||
Jailbreak,
|
||||
#[serde(rename = "toxicity")]
|
||||
Toxicity,
|
||||
#[serde(rename = "both")]
|
||||
Both,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptGuardRequest {
|
||||
pub input: String,
|
||||
pub task: PromptGuardTask,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptGuardResponse {
|
||||
pub toxic_prob: Option<f64>,
|
||||
pub jailbreak_prob: Option<f64>,
|
||||
pub toxic_verdict: Option<bool>,
|
||||
pub jailbreak_verdict: Option<bool>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::common_types::open_ai::{ChatCompletionStreamResponseServerEvents, Message};
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::HashMap;
|
||||
|
||||
const TOOL_SERIALIZED: &str = r#"{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What city do you want to know the weather for?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"description": "function to retrieve weather forecast",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "str",
|
||||
"description": "city for weather forecast",
|
||||
"default": "test"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"city"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"stream": true,
|
||||
"stream_options": {
|
||||
"include_usage": true
|
||||
}
|
||||
}"#;
|
||||
|
||||
#[test]
|
||||
fn test_tool_type_request() {
|
||||
use super::open_ai::{
|
||||
ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
|
||||
};
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
properties.insert(
|
||||
"city".to_string(),
|
||||
FunctionParameter {
|
||||
parameter_type: ParameterType::String,
|
||||
description: "city for weather forecast".to_string(),
|
||||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
},
|
||||
);
|
||||
|
||||
let function_definition = FunctionDefinition {
|
||||
name: "weather_forecast".to_string(),
|
||||
description: "function to retrieve weather forecast".to_string(),
|
||||
parameters: super::open_ai::FunctionParameters { properties },
|
||||
};
|
||||
|
||||
let chat_completions_request = ChatCompletionsRequest {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: Some("What city do you want to know the weather for?".to_string()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
tools: Some(vec![super::open_ai::ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: function_definition,
|
||||
}]),
|
||||
stream: true,
|
||||
stream_options: Some(super::open_ai::StreamOptions {
|
||||
include_usage: true,
|
||||
}),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
|
||||
println!("{}", serialized);
|
||||
assert_eq!(TOOL_SERIALIZED, serialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parameter_types() {
|
||||
use super::open_ai::{FunctionParameter, ParameterType};
|
||||
|
||||
const PARAMETER_SERIALZIED: &str = r#"{
|
||||
"city": {
|
||||
"type": "str",
|
||||
"description": "city for weather forecast",
|
||||
"default": "test"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let properties = HashMap::from([(
|
||||
"city".to_string(),
|
||||
FunctionParameter {
|
||||
parameter_type: ParameterType::String,
|
||||
description: "city for weather forecast".to_string(),
|
||||
required: Some(true),
|
||||
enum_values: None,
|
||||
default: Some("test".to_string()),
|
||||
},
|
||||
)]);
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&properties).unwrap();
|
||||
assert_eq!(PARAMETER_SERIALZIED, serialized);
|
||||
|
||||
// ensure that if type is missing it is set to string
|
||||
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
|
||||
{
|
||||
"city": {
|
||||
"description": "city for weather forecast"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let missing_type_deserialized: HashMap<String, FunctionParameter> =
|
||||
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
|
||||
println!("{:?}", missing_type_deserialized);
|
||||
assert_eq!(
|
||||
missing_type_deserialized
|
||||
.get("city")
|
||||
.unwrap()
|
||||
.parameter_type,
|
||||
ParameterType::String
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
||||
"#;
|
||||
|
||||
let sever_events =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 5);
|
||||
assert_eq!(
|
||||
sever_events.events[0].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
""
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[1].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"Hello"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[2].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"!"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[3].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" How"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[4].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" can"
|
||||
);
|
||||
assert_eq!(sever_events.to_string(), "Hello! How can");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_done() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let sever_events: ChatCompletionStreamResponseServerEvents =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 6);
|
||||
assert_eq!(
|
||||
sever_events.events[0].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" I"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[1].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" assist"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[2].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" you"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[3].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
" today"
|
||||
);
|
||||
assert_eq!(
|
||||
sever_events.events[4].choices[0]
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap(),
|
||||
"?"
|
||||
);
|
||||
assert_eq!(sever_events.events[5].choices[0].delta.content, None);
|
||||
|
||||
assert_eq!(sever_events.to_string(), " I assist you today?");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stream_chunk_parse_mistral() {
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
|
||||
|
||||
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
|
||||
|
||||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let sever_events: ChatCompletionStreamResponseServerEvents =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 11);
|
||||
|
||||
assert_eq!(
|
||||
sever_events.to_string(),
|
||||
"Hello! How can I assist you today?"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,26 @@ use serde::{Deserialize, Serialize};
|
|||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
|
||||
use crate::api::open_ai::{
|
||||
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
pub listener: Listener,
|
||||
pub endpoints: Option<HashMap<String, Endpoint>>,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
pub prompt_targets: Option<Vec<PromptTarget>>,
|
||||
pub error_target: Option<ErrorTargetDetail>,
|
||||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct Overrides {
|
||||
pub prompt_target_intent_matching_threshold: Option<f64>,
|
||||
|
|
@ -22,22 +42,6 @@ pub enum GatewayMode {
|
|||
Prompt,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Configuration {
|
||||
pub version: String,
|
||||
pub listener: Listener,
|
||||
pub endpoints: Option<HashMap<String, Endpoint>>,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
pub prompt_targets: Option<Vec<PromptTarget>>,
|
||||
pub error_target: Option<ErrorTargetDetail>,
|
||||
pub ratelimits: Option<Vec<Ratelimit>>,
|
||||
pub tracing: Option<Tracing>,
|
||||
pub mode: Option<GatewayMode>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ErrorTargetDetail {
|
||||
pub endpoint: Option<EndpointDetails>,
|
||||
|
|
@ -192,6 +196,7 @@ pub struct Parameter {
|
|||
pub enum_values: Option<Vec<String>>,
|
||||
pub default: Option<String>,
|
||||
pub in_path: Option<bool>,
|
||||
pub format: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
|
||||
|
|
@ -233,11 +238,47 @@ pub struct PromptTarget {
|
|||
pub auto_llm_dispatch_on_response: Option<bool>,
|
||||
}
|
||||
|
||||
// convert PromptTarget to ChatCompletionTool
|
||||
impl From<&PromptTarget> for ChatCompletionTool {
|
||||
fn from(val: &PromptTarget) -> Self {
|
||||
let properties: HashMap<String, FunctionParameter> = match val.parameters {
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = FunctionParameter {
|
||||
parameter_type: ParameterType::from(
|
||||
entity.parameter_type.clone().unwrap_or("str".to_string()),
|
||||
),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
enum_values: entity.enum_values.clone(),
|
||||
default: entity.default.clone(),
|
||||
format: entity.format.clone(),
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
|
||||
ChatCompletionTool {
|
||||
tool_type: crate::api::open_ai::ToolType::Function,
|
||||
function: FunctionDefinition {
|
||||
name: val.name.clone(),
|
||||
description: val.description.clone(),
|
||||
parameters: FunctionParameters { properties },
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::fs;
|
||||
|
||||
use crate::configuration::GuardType;
|
||||
use crate::{api::open_ai::ToolType, configuration::GuardType};
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_configuration() {
|
||||
|
|
@ -309,4 +350,76 @@ mod test {
|
|||
let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt);
|
||||
assert_eq!(*mode, super::GatewayMode::Prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let ref_config = fs::read_to_string(
|
||||
"../../docs/source/resources/includes/arch_config_full_reference.yaml",
|
||||
)
|
||||
.expect("reference config file not found");
|
||||
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
|
||||
let prompt_targets = &config.prompt_targets;
|
||||
let prompt_target = prompt_targets
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.find(|p| p.name == "reboot_network_device")
|
||||
.unwrap();
|
||||
let chat_completion_tool: super::ChatCompletionTool = prompt_target.into();
|
||||
assert_eq!(chat_completion_tool.tool_type, ToolType::Function);
|
||||
assert_eq!(chat_completion_tool.function.name, "reboot_network_device");
|
||||
assert_eq!(
|
||||
chat_completion_tool.function.description,
|
||||
"Reboot a specific network device"
|
||||
);
|
||||
assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2);
|
||||
assert_eq!(
|
||||
chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.contains_key("device_id"),
|
||||
true
|
||||
);
|
||||
assert_eq!(
|
||||
chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.get("device_id")
|
||||
.unwrap()
|
||||
.parameter_type,
|
||||
crate::api::open_ai::ParameterType::String
|
||||
);
|
||||
assert_eq!(
|
||||
chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.get("device_id")
|
||||
.unwrap()
|
||||
.description,
|
||||
"Identifier of the network device to reboot.".to_string()
|
||||
);
|
||||
assert_eq!(
|
||||
chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.get("device_id")
|
||||
.unwrap()
|
||||
.required,
|
||||
Some(true)
|
||||
);
|
||||
assert_eq!(
|
||||
chat_completion_tool
|
||||
.function
|
||||
.parameters
|
||||
.properties
|
||||
.get("confirmation")
|
||||
.unwrap()
|
||||
.parameter_type,
|
||||
crate::api::open_ai::ParameterType::Bool
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,3 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "katanemo/bge-large-en-v1.5";
|
||||
pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
|
||||
pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
|
|
@ -9,11 +5,6 @@ pub const TOOL_ROLE: &str = "tool";
|
|||
pub const ASSISTANT_ROLE: &str = "assistant";
|
||||
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||
pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot";
|
||||
pub const ARCH_FC_INTERNAL_HOST: &str = "arch_fc";
|
||||
pub const HALLUCINATION_INTERNAL_HOST: &str = "hallucination";
|
||||
pub const EMBEDDINGS_INTERNAL_HOST: &str = "embeddings";
|
||||
pub const GUARD_INTERNAL_HOST: &str = "guard";
|
||||
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
||||
pub const MESSAGES_KEY: &str = "messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
|
|
@ -25,7 +16,6 @@ pub const REQUEST_ID_HEADER: &str = "x-request-id";
|
|||
pub const TRACE_PARENT_HEADER: &str = "traceparent";
|
||||
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
|
||||
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
|
||||
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
|
||||
pub const ARCH_MODEL_PREFIX: &str = "Arch";
|
||||
pub const HALLUCINATION_TEMPLATE: &str =
|
||||
"It seems I'm missing some information. Could you provide the following details ";
|
||||
|
|
|
|||
|
|
@ -1,59 +0,0 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use crate::embeddings;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingRequest {
|
||||
#[serde(rename = "input")]
|
||||
pub input: Box<embeddings::CreateEmbeddingRequestInput>,
|
||||
/// ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.
|
||||
#[serde(rename = "model")]
|
||||
pub model: String,
|
||||
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
|
||||
#[serde(rename = "encoding_format", skip_serializing_if = "Option::is_none")]
|
||||
pub encoding_format: Option<EncodingFormat>,
|
||||
/// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.
|
||||
#[serde(rename = "dimensions", skip_serializing_if = "Option::is_none")]
|
||||
pub dimensions: Option<i32>,
|
||||
/// A unique identifier representing your end-user, which can help to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).
|
||||
#[serde(rename = "user", skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
impl CreateEmbeddingRequest {
|
||||
pub fn new(
|
||||
input: embeddings::CreateEmbeddingRequestInput,
|
||||
model: String,
|
||||
) -> CreateEmbeddingRequest {
|
||||
CreateEmbeddingRequest {
|
||||
input: Box::new(input),
|
||||
model,
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub enum EncodingFormat {
|
||||
#[serde(rename = "float")]
|
||||
Float,
|
||||
#[serde(rename = "base64")]
|
||||
Base64,
|
||||
}
|
||||
|
||||
impl Default for EncodingFormat {
|
||||
fn default() -> EncodingFormat {
|
||||
Self::Float
|
||||
}
|
||||
}
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// CreateEmbeddingRequestInput : Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
|
||||
/// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CreateEmbeddingRequestInput {
|
||||
/// The string that will be turned into an embedding.
|
||||
String(String),
|
||||
/// The array of integers that will be turned into an embedding.
|
||||
Array(Vec<i32>),
|
||||
}
|
||||
|
||||
impl Default for CreateEmbeddingRequestInput {
|
||||
fn default() -> Self {
|
||||
Self::String(Default::default())
|
||||
}
|
||||
}
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use crate::embeddings;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingResponse {
|
||||
/// The list of embeddings generated by the model.
|
||||
#[serde(rename = "data")]
|
||||
pub data: Vec<embeddings::Embedding>,
|
||||
/// The name of the model used to generate the embedding.
|
||||
#[serde(rename = "model")]
|
||||
pub model: String,
|
||||
/// The object type, which is always \"list\".
|
||||
#[serde(rename = "object")]
|
||||
pub object: Object,
|
||||
#[serde(rename = "usage")]
|
||||
pub usage: Box<embeddings::CreateEmbeddingResponseUsage>,
|
||||
}
|
||||
|
||||
impl CreateEmbeddingResponse {
|
||||
pub fn new(
|
||||
data: Vec<embeddings::Embedding>,
|
||||
model: String,
|
||||
object: Object,
|
||||
usage: embeddings::CreateEmbeddingResponseUsage,
|
||||
) -> CreateEmbeddingResponse {
|
||||
CreateEmbeddingResponse {
|
||||
data,
|
||||
model,
|
||||
object,
|
||||
usage: Box::new(usage),
|
||||
}
|
||||
}
|
||||
}
|
||||
/// The object type, which is always \"list\".
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub enum Object {
|
||||
#[serde(rename = "list")]
|
||||
List,
|
||||
}
|
||||
|
||||
impl Default for Object {
|
||||
fn default() -> Object {
|
||||
Self::List
|
||||
}
|
||||
}
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// CreateEmbeddingResponseUsage : The usage information for the request.
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CreateEmbeddingResponseUsage {
|
||||
/// The number of tokens used by the prompt.
|
||||
#[serde(rename = "prompt_tokens")]
|
||||
pub prompt_tokens: i32,
|
||||
/// The total number of tokens used by the request.
|
||||
#[serde(rename = "total_tokens")]
|
||||
pub total_tokens: i32,
|
||||
}
|
||||
|
||||
impl CreateEmbeddingResponseUsage {
|
||||
/// The usage information for the request.
|
||||
pub fn new(prompt_tokens: i32, total_tokens: i32) -> CreateEmbeddingResponseUsage {
|
||||
CreateEmbeddingResponseUsage {
|
||||
prompt_tokens,
|
||||
total_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
/*
|
||||
* OMF Embeddings
|
||||
*
|
||||
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
|
||||
*
|
||||
* The version of the OpenAPI document: 1.0.0
|
||||
*
|
||||
* Generated by: https://openapi-generator.tech
|
||||
*/
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Embedding : Represents an embedding vector returned by embedding endpoint.
|
||||
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Embedding {
|
||||
/// The index of the embedding in the list of embeddings.
|
||||
#[serde(rename = "index")]
|
||||
pub index: i32,
|
||||
/// The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).
|
||||
#[serde(rename = "embedding")]
|
||||
pub embedding: Vec<f64>,
|
||||
/// The object type, which is always \"embedding\"
|
||||
#[serde(rename = "object")]
|
||||
pub object: Object,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
/// Represents an embedding vector returned by embedding endpoint.
|
||||
pub fn new(index: i32, embedding: Vec<f64>, object: Object) -> Embedding {
|
||||
Embedding {
|
||||
index,
|
||||
embedding,
|
||||
object,
|
||||
}
|
||||
}
|
||||
}
|
||||
/// The object type, which is always \"embedding\"
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
pub enum Object {
|
||||
#[serde(rename = "embedding")]
|
||||
Embedding,
|
||||
}
|
||||
|
||||
impl Default for Object {
|
||||
fn default() -> Object {
|
||||
Self::Embedding
|
||||
}
|
||||
}
|
||||
|
|
@ -1,10 +0,0 @@
|
|||
pub mod create_embedding_request;
|
||||
pub use self::create_embedding_request::CreateEmbeddingRequest;
|
||||
pub mod create_embedding_request_input;
|
||||
pub use self::create_embedding_request_input::CreateEmbeddingRequestInput;
|
||||
pub mod create_embedding_response;
|
||||
pub use self::create_embedding_response::CreateEmbeddingResponse;
|
||||
pub mod create_embedding_response_usage;
|
||||
pub use self::create_embedding_response_usage::CreateEmbeddingResponseUsage;
|
||||
pub mod embedding;
|
||||
pub use self::embedding::Embedding;
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
use proxy_wasm::types::Status;
|
||||
|
||||
use crate::{common_types::open_ai::ChatCompletionChunkResponseError, ratelimit};
|
||||
use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit};
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ClientError {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
pub mod common_types;
|
||||
pub mod api;
|
||||
pub mod configuration;
|
||||
pub mod consts;
|
||||
pub mod embeddings;
|
||||
pub mod errors;
|
||||
pub mod http;
|
||||
pub mod llm_providers;
|
||||
pub mod path;
|
||||
pub mod pii;
|
||||
pub mod ratelimit;
|
||||
pub mod routing;
|
||||
pub mod stats;
|
||||
pub mod tokenizer;
|
||||
pub mod tracing;
|
||||
pub mod path;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> Result<String, String> {
|
||||
pub fn replace_params_in_path(
|
||||
path: &str,
|
||||
params: &HashMap<String, String>,
|
||||
) -> Result<String, String> {
|
||||
let mut result = String::new();
|
||||
let mut in_param = false;
|
||||
let mut current_param = String::new();
|
||||
|
|
@ -17,12 +20,10 @@ pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> R
|
|||
return Err(format!("Missing value for parameter `{}`", param_name));
|
||||
}
|
||||
current_param.clear();
|
||||
} else if in_param {
|
||||
current_param.push(c);
|
||||
} else {
|
||||
if in_param {
|
||||
current_param.push(c);
|
||||
} else {
|
||||
result.push(c);
|
||||
}
|
||||
result.push(c);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ use log::error;
|
|||
use proxy_wasm::hostcalls;
|
||||
use proxy_wasm::types::*;
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait Metric {
|
||||
fn id(&self) -> u32;
|
||||
fn value(&self) -> Result<u64, String> {
|
||||
|
|
@ -14,7 +13,6 @@ pub trait Metric {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait IncrementingMetric: Metric {
|
||||
fn increment(&self, offset: i64) {
|
||||
match hostcalls::increment_metric(self.id(), offset) {
|
||||
|
|
@ -24,7 +22,6 @@ pub trait IncrementingMetric: Metric {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub trait RecordingMetric: Metric {
|
||||
fn record(&self, value: u64) {
|
||||
match hostcalls::record_metric(self.id(), value) {
|
||||
|
|
@ -39,7 +36,6 @@ pub struct Counter {
|
|||
id: u32,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl Counter {
|
||||
pub fn new(name: String) -> Counter {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
|
||||
|
|
@ -85,7 +81,6 @@ pub struct Histogram {
|
|||
id: u32,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl Histogram {
|
||||
pub fn new(name: String) -> Histogram {
|
||||
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use crate::metrics::Metrics;
|
||||
use crate::stream_context::StreamContext;
|
||||
use common::configuration::Configuration;
|
||||
use common::consts::OTEL_COLLECTOR_HTTP;
|
||||
|
|
@ -6,9 +7,7 @@ use common::http::CallArgs;
|
|||
use common::http::Client;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use common::ratelimit;
|
||||
use common::stats::Counter;
|
||||
use common::stats::Gauge;
|
||||
use common::stats::Histogram;
|
||||
use common::tracing::TraceData;
|
||||
use log::debug;
|
||||
use log::warn;
|
||||
|
|
@ -22,39 +21,12 @@ use std::time::Duration;
|
|||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct WasmMetrics {
|
||||
pub active_http_calls: Gauge,
|
||||
pub ratelimited_rq: Counter,
|
||||
pub time_to_first_token: Histogram,
|
||||
pub time_per_output_token: Histogram,
|
||||
pub tokens_per_second: Histogram,
|
||||
pub request_latency: Histogram,
|
||||
pub output_sequence_length: Histogram,
|
||||
pub input_sequence_length: Histogram,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
ratelimited_rq: Counter::new(String::from("ratelimited_rq")),
|
||||
time_to_first_token: Histogram::new(String::from("time_to_first_token")),
|
||||
time_per_output_token: Histogram::new(String::from("time_per_output_token")),
|
||||
tokens_per_second: Histogram::new(String::from("tokens_per_second")),
|
||||
request_latency: Histogram::new(String::from("request_latency")),
|
||||
output_sequence_length: Histogram::new(String::from("output_sequence_length")),
|
||||
input_sequence_length: Histogram::new(String::from("input_sequence_length")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CallContext {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterContext {
|
||||
metrics: Rc<WasmMetrics>,
|
||||
metrics: Rc<Metrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: RefCell<HashMap<u32, CallContext>>,
|
||||
llm_providers: Option<Rc<LlmProviders>>,
|
||||
|
|
@ -65,7 +37,7 @@ impl FilterContext {
|
|||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
metrics: Rc::new(Metrics::new()),
|
||||
llm_providers: None,
|
||||
traces_queue: Arc::new(Mutex::new(VecDeque::new())),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use proxy_wasm::traits::*;
|
|||
use proxy_wasm::types::*;
|
||||
|
||||
mod filter_context;
|
||||
mod metrics;
|
||||
mod stream_context;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
|
|
|
|||
28
crates/llm_gateway/src/metrics.rs
Normal file
28
crates/llm_gateway/src/metrics.rs
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
use common::stats::{Counter, Gauge, Histogram};
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Metrics {
|
||||
pub active_http_calls: Gauge,
|
||||
pub ratelimited_rq: Counter,
|
||||
pub time_to_first_token: Histogram,
|
||||
pub time_per_output_token: Histogram,
|
||||
pub tokens_per_second: Histogram,
|
||||
pub request_latency: Histogram,
|
||||
pub output_sequence_length: Histogram,
|
||||
pub input_sequence_length: Histogram,
|
||||
}
|
||||
|
||||
impl Metrics {
|
||||
pub fn new() -> Metrics {
|
||||
Metrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
ratelimited_rq: Counter::new(String::from("ratelimited_rq")),
|
||||
time_to_first_token: Histogram::new(String::from("time_to_first_token")),
|
||||
time_per_output_token: Histogram::new(String::from("time_per_output_token")),
|
||||
tokens_per_second: Histogram::new(String::from("tokens_per_second")),
|
||||
request_latency: Histogram::new(String::from("request_latency")),
|
||||
output_sequence_length: Histogram::new(String::from("output_sequence_length")),
|
||||
input_sequence_length: Histogram::new(String::from("input_sequence_length")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
use crate::filter_context::WasmMetrics;
|
||||
use common::common_types::open_ai::{
|
||||
use crate::metrics::Metrics;
|
||||
use common::api::open_ai::{
|
||||
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
|
||||
Message, StreamOptions,
|
||||
};
|
||||
|
|
@ -12,25 +12,23 @@ use common::errors::ServerError;
|
|||
use common::llm_providers::LlmProviders;
|
||||
use common::pii::obfuscate_auth_header;
|
||||
use common::ratelimit::Header;
|
||||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
use common::tracing::{Event, Span, TraceData, Traceparent};
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use http::StatusCode;
|
||||
use log::{debug, trace, warn};
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use common::stats::{IncrementingMetric, RecordingMetric};
|
||||
|
||||
use proxy_wasm::hostcalls::get_current_time;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
pub struct StreamContext {
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
metrics: Rc<Metrics>,
|
||||
ratelimit_selector: Option<Header>,
|
||||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
|
|
@ -50,7 +48,7 @@ pub struct StreamContext {
|
|||
impl StreamContext {
|
||||
pub fn new(
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
metrics: Rc<Metrics>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
|
||||
) -> Self {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use common::errors::ServerError;
|
||||
use common::stats::IncrementingMetric;
|
||||
use http::StatusCode;
|
||||
use log::{debug, warn};
|
||||
use proxy_wasm::traits::Context;
|
||||
|
||||
use crate::stream_context::{ResponseHandlerType, StreamContext};
|
||||
|
|
@ -19,76 +23,34 @@ impl Context for StreamContext {
|
|||
.expect("invalid token_id");
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
/*
|
||||
state transition
|
||||
let body = self
|
||||
.get_http_call_response_body(0, body_size)
|
||||
.unwrap_or(vec![]);
|
||||
|
||||
graph LR
|
||||
|
||||
on_http_request_body --> prompt received
|
||||
prompt received --> get embeddings & arch guard
|
||||
arch guard --> get embeddings
|
||||
get embeddings --> zeroshot intent
|
||||
|
||||
┌──────────────────────┐ ┌─────────────────┐ ┌────────────────┐ ┌─────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
│ on_http_request_body ├──►│ prompt received ├──►│ get embeddings ├──►│ zeroshot intent │
|
||||
│ │ │ │ │ │ │ │
|
||||
└──────────────────────┘ └────────┬────────┘ └────────────────┘ └─────────────────┘
|
||||
│ ▲
|
||||
│ │
|
||||
│ │
|
||||
│ ┌────────┴───────┐
|
||||
│ │ │
|
||||
└───────────►│ arch guard │
|
||||
│ │
|
||||
└────────────────┘
|
||||
|
||||
|
||||
continue from zeroshot intent
|
||||
|
||||
graph LR
|
||||
|
||||
zeroshot intent --> arch_fc
|
||||
zeroshot intent --> default prompt target
|
||||
arch_fc --> developer api call & hallucination check
|
||||
hallucination check --> parameter gathering & developer api call
|
||||
developer api call --> resume request to llm
|
||||
|
||||
|
||||
┌─────────────────┐ ┌───────────────────────┐ ┌─────────────────────┐ ┌───────────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
│ zeroshot intent ├──►│ arch_fc ├──►│ developer api call ├──►│ resume request to llm │
|
||||
│ │ │ │ │ │ │ │
|
||||
└────────┬────────┘ └───────────┬───────────┘ └─────────────────────┘ └───────────────────────┘
|
||||
│ │ ▲
|
||||
│ └─────────────┐ │
|
||||
│ │ │
|
||||
│ ┌───────────────────────┐ │ ┌──────────┴──────────┐ ┌───────────────────────┐
|
||||
│ │ │ │ │ │ │ │
|
||||
└───────────►│ default prompt target │ └▲│ hallucination check ├──►│ parameter gathering │
|
||||
│ │ │ │ │ │
|
||||
└───────────────────────┘ └─────────────────────┘ └───────────────────────┘
|
||||
|
||||
|
||||
using https://mermaid-ascii.art/
|
||||
*/
|
||||
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
#[cfg_attr(any(), rustfmt::skip)]
|
||||
match callout_context.response_handler_type {
|
||||
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
||||
ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context),
|
||||
ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context),
|
||||
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
|
||||
ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context),
|
||||
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
|
||||
ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
|
||||
}
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(String::from("No response body in inline HTTP request")),
|
||||
None,
|
||||
let http_status = self
|
||||
.get_http_call_response_header(":status")
|
||||
.unwrap_or(StatusCode::OK.as_str().to_string());
|
||||
debug!("http call response code: {}", http_status);
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
let server_error = ServerError::Upstream {
|
||||
host: callout_context.upstream_cluster.unwrap(),
|
||||
path: callout_context.upstream_cluster_path.unwrap(),
|
||||
status: http_status.clone(),
|
||||
body: String::from_utf8(body).unwrap(),
|
||||
};
|
||||
warn!("filter received non 2xx code: {:?}", server_error);
|
||||
return self.send_server_error(
|
||||
server_error,
|
||||
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
|
||||
);
|
||||
}
|
||||
|
||||
debug!("http call response handler type: {:?}", callout_context.response_handler_type);
|
||||
#[cfg_attr(any(), rustfmt::skip)]
|
||||
match callout_context.response_handler_type {
|
||||
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
|
||||
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
|
||||
ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,60 +1,27 @@
|
|||
use crate::metrics::Metrics;
|
||||
use crate::stream_context::StreamContext;
|
||||
use common::common_types::EmbeddingType;
|
||||
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
|
||||
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
|
||||
use common::consts::DEFAULT_EMBEDDING_MODEL;
|
||||
use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST};
|
||||
use common::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use common::http::CallArgs;
|
||||
use common::http::Client;
|
||||
use common::stats::Gauge;
|
||||
use common::stats::IncrementingMetric;
|
||||
use http::StatusCode;
|
||||
use log::{debug, info, trace, warn};
|
||||
use log::debug;
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct WasmMetrics {
|
||||
pub active_http_calls: Gauge,
|
||||
}
|
||||
|
||||
impl WasmMetrics {
|
||||
fn new() -> WasmMetrics {
|
||||
WasmMetrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
|
||||
pub type EmbeddingsStore = HashMap<String, EmbeddingTypeMap>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterCallContext {
|
||||
pub prompt_target_name: String,
|
||||
pub embedding_type: EmbeddingType,
|
||||
}
|
||||
pub struct FilterCallContext {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterContext {
|
||||
metrics: Rc<WasmMetrics>,
|
||||
metrics: Rc<Metrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: RefCell<HashMap<u32, FilterCallContext>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
temp_embeddings_store: EmbeddingsStore,
|
||||
active_embedding_calls_count: u32,
|
||||
tracing: Rc<Option<Tracing>>,
|
||||
}
|
||||
|
||||
|
|
@ -62,136 +29,14 @@ impl FilterContext {
|
|||
pub fn new() -> FilterContext {
|
||||
FilterContext {
|
||||
callouts: RefCell::new(HashMap::new()),
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
metrics: Rc::new(Metrics::new()),
|
||||
system_prompt: Rc::new(None),
|
||||
prompt_targets: Rc::new(HashMap::new()),
|
||||
overrides: Rc::new(None),
|
||||
prompt_guards: Rc::new(PromptGuards::default()),
|
||||
embeddings_store: Some(Rc::new(HashMap::new())),
|
||||
temp_embeddings_store: HashMap::new(),
|
||||
active_embedding_calls_count: 0,
|
||||
tracing: Rc::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
let prompt_target_description: Vec<(String, String)> = self
|
||||
.prompt_targets
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.description.clone()))
|
||||
.collect();
|
||||
|
||||
prompt_target_description
|
||||
.iter()
|
||||
.for_each(|(name, description)| {
|
||||
self.schedule_embeddings_call(name, description, EmbeddingType::Description);
|
||||
});
|
||||
}
|
||||
|
||||
fn schedule_embeddings_call(
|
||||
&mut self,
|
||||
prompt_target_name: &str,
|
||||
input: &str,
|
||||
embedding_type: EmbeddingType,
|
||||
) {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
let json_data = serde_json::to_string(&embeddings_input).unwrap();
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
"/embeddings",
|
||||
vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", EMBEDDINGS_INTERNAL_HOST),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(60),
|
||||
);
|
||||
|
||||
let call_context = crate::filter_context::FilterCallContext {
|
||||
prompt_target_name: String::from(prompt_target_name),
|
||||
embedding_type,
|
||||
};
|
||||
|
||||
self.active_embedding_calls_count += 1;
|
||||
if let Err(error) = self.http_call(call_args, call_context) {
|
||||
panic!("{error}")
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_response_handler(
|
||||
&mut self,
|
||||
embedding_type: EmbeddingType,
|
||||
prompt_target_name: String,
|
||||
body: Vec<u8>,
|
||||
) {
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(&prompt_target_name)
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"Received embeddings response for unknown prompt target name={}",
|
||||
prompt_target_name
|
||||
)
|
||||
});
|
||||
|
||||
if !body.is_empty() {
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error deserializing embedding response. body: {:?}: {:?}",
|
||||
String::from_utf8(body).unwrap(),
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let embeddings = embedding_response.data.remove(0).embedding;
|
||||
debug!(
|
||||
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
|
||||
prompt_target.name,
|
||||
prompt_target.description,
|
||||
embedding_type
|
||||
);
|
||||
|
||||
let entry = self.temp_embeddings_store.entry(prompt_target_name);
|
||||
match entry {
|
||||
Entry::Occupied(_) => {
|
||||
entry.and_modify(|e| {
|
||||
if let Entry::Vacant(e) = e.entry(embedding_type) {
|
||||
e.insert(embeddings);
|
||||
} else {
|
||||
panic!(
|
||||
"Duplicate {:?} for prompt target with name=\"{}\"",
|
||||
&embedding_type, prompt_target.name
|
||||
)
|
||||
}
|
||||
});
|
||||
}
|
||||
Entry::Vacant(_) => {
|
||||
entry.or_insert(HashMap::from([(embedding_type, embeddings)]));
|
||||
}
|
||||
}
|
||||
|
||||
if self.prompt_targets.len() == self.temp_embeddings_store.len() {
|
||||
self.embeddings_store =
|
||||
Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for FilterContext {
|
||||
|
|
@ -206,46 +51,7 @@ impl Client for FilterContext {
|
|||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
token_id: u32,
|
||||
_num_headers: usize,
|
||||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
trace!(
|
||||
"filter_context: on_http_call_response called with token_id: {:?}",
|
||||
token_id
|
||||
);
|
||||
let callout_data = self
|
||||
.callouts
|
||||
.borrow_mut()
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id");
|
||||
|
||||
self.active_embedding_calls_count -= 1;
|
||||
self.metrics.active_http_calls.increment(-1);
|
||||
let body_bytes = self.get_http_call_response_body(0, body_size).unwrap();
|
||||
|
||||
if let Some(status_code) = self.get_http_call_response_header(":status") {
|
||||
if status_code == StatusCode::OK.as_str() {
|
||||
self.embedding_response_handler(
|
||||
callout_data.embedding_type,
|
||||
callout_data.prompt_target_name,
|
||||
body_bytes,
|
||||
);
|
||||
} else {
|
||||
warn!(
|
||||
"Received non-200 status code: {} for callout with token_id: {}: body_str: {}",
|
||||
status_code,
|
||||
token_id,
|
||||
String::from_utf8(body_bytes).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Context for FilterContext {}
|
||||
|
||||
// RootContext allows the Rust code to reach into the Envoy Config
|
||||
impl RootContext for FilterContext {
|
||||
|
|
@ -283,15 +89,12 @@ impl RootContext for FilterContext {
|
|||
context_id
|
||||
);
|
||||
|
||||
let embedding_store = self.embeddings_store.as_ref().map(Rc::clone);
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(&self.system_prompt),
|
||||
Rc::clone(&self.prompt_targets),
|
||||
Rc::clone(&self.prompt_guards),
|
||||
Rc::clone(&self.overrides),
|
||||
embedding_store,
|
||||
Rc::clone(&self.tracing),
|
||||
)))
|
||||
}
|
||||
|
|
@ -301,25 +104,6 @@ impl RootContext for FilterContext {
|
|||
}
|
||||
|
||||
fn on_vm_start(&mut self, _: usize) -> bool {
|
||||
self.set_tick_period(Duration::from_secs(1));
|
||||
true
|
||||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
if self.embeddings_store.is_some()
|
||||
&& self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len()
|
||||
{
|
||||
info!("embeddings store initialized");
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
} else {
|
||||
if self.active_embedding_calls_count == 0 {
|
||||
info!("retrieving embeddings from embedding server");
|
||||
self.process_prompt_targets();
|
||||
} else {
|
||||
info!("waiting for embeddings store to be initialized");
|
||||
}
|
||||
|
||||
self.set_tick_period(Duration::from_secs(5));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,12 @@
|
|||
use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
|
||||
use common::{
|
||||
common_types::{
|
||||
open_ai::{
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
|
||||
},
|
||||
PromptGuardRequest, PromptGuardTask,
|
||||
api::open_ai::{
|
||||
self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
},
|
||||
consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, GUARD_INTERNAL_HOST,
|
||||
HEALTHZ_PATH, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
|
||||
MODEL_SERVER_NAME, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
|
||||
},
|
||||
errors::ServerError,
|
||||
http::{CallArgs, Client},
|
||||
|
|
@ -37,11 +34,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let request_path = self.get_http_request_header(":path").unwrap_or_default();
|
||||
if request_path == HEALTHZ_PATH {
|
||||
if self.is_embedding_store_initialized() {
|
||||
self.send_http_response(200, vec![], None);
|
||||
} else {
|
||||
self.send_http_response(503, vec![], None);
|
||||
}
|
||||
self.send_http_response(200, vec![], None);
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
|
|
@ -140,43 +133,25 @@ impl HttpContext for StreamContext {
|
|||
|
||||
self.user_prompt = Some(last_user_prompt.clone());
|
||||
|
||||
let user_message_str = self.user_prompt.as_ref().unwrap().content.clone();
|
||||
// convert prompt targets to ChatCompletionTool
|
||||
let tool_calls: Vec<ChatCompletionTool> = self
|
||||
.prompt_targets
|
||||
.iter()
|
||||
.map(|(_, pt)| pt.into())
|
||||
.collect();
|
||||
|
||||
let prompt_guard_jailbreak_task = self
|
||||
.prompt_guards
|
||||
.input_guards
|
||||
.contains_key(&common::configuration::GuardType::Jailbreak);
|
||||
let arch_fc_chat_completion_request = ChatCompletionsRequest {
|
||||
messages: deserialized_body.messages.clone(),
|
||||
metadata: deserialized_body.metadata.clone(),
|
||||
stream: deserialized_body.stream,
|
||||
model: "--".to_string(),
|
||||
stream_options: deserialized_body.stream_options.clone(),
|
||||
tools: Some(tool_calls),
|
||||
};
|
||||
|
||||
self.chat_completions_request = Some(deserialized_body);
|
||||
|
||||
if !prompt_guard_jailbreak_task {
|
||||
debug!("Missing input guard. Making inline call to retrieve embeddings");
|
||||
let callout_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: user_message_str.clone(),
|
||||
prompt_target_name: None,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
};
|
||||
self.get_embeddings(callout_context);
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
let get_prompt_guards_request = PromptGuardRequest {
|
||||
input: self
|
||||
.user_prompt
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone(),
|
||||
task: PromptGuardTask::Jailbreak,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
|
||||
let json_data = match serde_json::to_string(&arch_fc_chat_completion_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
self.send_server_error(ServerError::Serialization(error), None);
|
||||
|
|
@ -184,14 +159,14 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
debug!("archgw => archfc: {}", json_data);
|
||||
|
||||
let mut headers = vec![
|
||||
(ARCH_UPSTREAM_HOST_HEADER, GUARD_INTERNAL_HOST),
|
||||
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),
|
||||
(":method", "POST"),
|
||||
(":path", "/guard"),
|
||||
(":authority", GUARD_INTERNAL_HOST),
|
||||
(":path", "/function_calling"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
(":authority", MODEL_SERVER_NAME),
|
||||
];
|
||||
|
||||
if self.request_id.is_some() {
|
||||
|
|
@ -204,23 +179,25 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
"/guard",
|
||||
"/function_calling",
|
||||
headers,
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
let call_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
response_handler_type: ResponseHandlerType::ArchFC,
|
||||
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
|
||||
prompt_target_name: None,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
|
||||
upstream_cluster_path: Some("/function_calling".to_string()),
|
||||
};
|
||||
|
||||
if let Err(e) = self.http_call(call_args, call_context) {
|
||||
debug!("http_call failed: {:?}", e);
|
||||
self.send_server_error(ServerError::HttpDispatch(e), None);
|
||||
}
|
||||
|
||||
|
|
@ -324,7 +301,7 @@ impl HttpContext for StreamContext {
|
|||
),
|
||||
];
|
||||
|
||||
let mut response_str = to_server_events(chunks);
|
||||
let mut response_str = open_ai::to_server_events(chunks);
|
||||
// append the original response from the model to the stream
|
||||
response_str.push_str(&body_utf8);
|
||||
self.set_http_response_body(0, body_size, response_str.as_bytes());
|
||||
|
|
@ -339,9 +316,11 @@ impl HttpContext for StreamContext {
|
|||
let mut data = match serde_json::from_str(&body_utf8) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
warn!("could not deserialize response: {}", e);
|
||||
self.send_server_error(ServerError::Deserialization(e), None);
|
||||
return Action::Pause;
|
||||
warn!(
|
||||
"could not deserialize response, sending data as it is: {}",
|
||||
e
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
// use serde::Value to manipulate the json object and ensure that we don't lose any data
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ use proxy_wasm::types::*;
|
|||
|
||||
mod context;
|
||||
mod filter_context;
|
||||
mod hallucination;
|
||||
mod http_context;
|
||||
mod metrics;
|
||||
mod stream_context;
|
||||
|
||||
proxy_wasm::main! {{
|
||||
|
|
|
|||
14
crates/prompt_gateway/src/metrics.rs
Normal file
14
crates/prompt_gateway/src/metrics.rs
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
use common::stats::Gauge;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Metrics {
|
||||
pub active_http_calls: Gauge,
|
||||
}
|
||||
|
||||
impl Metrics {
|
||||
pub fn new() -> Metrics {
|
||||
Metrics {
|
||||
active_http_calls: Gauge::new(String::from("active_http_calls")),
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,11 +1,7 @@
|
|||
use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
|
||||
use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
|
||||
use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
|
||||
use common::embeddings::{
|
||||
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
|
||||
Embedding,
|
||||
use common::api::open_ai::{
|
||||
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
|
||||
};
|
||||
use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
|
||||
use common::configuration::Configuration;
|
||||
use http::StatusCode;
|
||||
use proxy_wasm_test_framework::tester::{self, Tester};
|
||||
use proxy_wasm_test_framework::types::{
|
||||
|
|
@ -80,13 +76,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "guard"),
|
||||
("x-arch-upstream", "model_server"),
|
||||
(":method", "POST"),
|
||||
(":path", "/guard"),
|
||||
(":authority", "guard"),
|
||||
(":path", "/function_calling"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
(":authority", "model_server"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
|
|
@ -94,139 +88,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
)
|
||||
.returning(Some(1))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
|
||||
let prompt_guard_response = PromptGuardResponse {
|
||||
toxic_prob: None,
|
||||
toxic_verdict: None,
|
||||
jailbreak_prob: None,
|
||||
jailbreak_verdict: None,
|
||||
};
|
||||
let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
1,
|
||||
0,
|
||||
prompt_guard_response_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&prompt_guard_response_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "embeddings"),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddings"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let embedding_response = CreateEmbeddingResponse {
|
||||
data: vec![Embedding {
|
||||
index: 0,
|
||||
embedding: vec![],
|
||||
object: embedding::Object::default(),
|
||||
}],
|
||||
model: String::from("test"),
|
||||
object: create_embedding_response::Object::default(),
|
||||
usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)),
|
||||
};
|
||||
let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
2,
|
||||
0,
|
||||
embeddings_response_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embeddings_response_buffer))
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "zeroshot"),
|
||||
(":method", "POST"),
|
||||
(":path", "/zeroshot"),
|
||||
(":authority", "zeroshot"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(3))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let zero_shot_response = ZeroShotClassificationResponse {
|
||||
predicted_class: "weather_forecast".to_string(),
|
||||
predicted_class_score: 0.1,
|
||||
scores: HashMap::new(),
|
||||
model: "test-model".to_string(),
|
||||
};
|
||||
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
3,
|
||||
0,
|
||||
zeroshot_intent_detection_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&zeroshot_intent_detection_buffer))
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
("x-arch-upstream", "arch_fc"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", "arch_fc"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "120000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
||||
|
|
@ -245,69 +111,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
|||
.execute_and_expect(ReturnType::Bool(true))
|
||||
.unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_tick(filter_context)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "embeddings"),
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddings"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(101))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_set_tick_period_millis(Some(5000))
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let embedding_response = CreateEmbeddingResponse {
|
||||
data: vec![Embedding {
|
||||
embedding: vec![],
|
||||
index: 0,
|
||||
object: embedding::Object::default(),
|
||||
}],
|
||||
model: String::from("test"),
|
||||
object: create_embedding_response::Object::default(),
|
||||
usage: Box::new(CreateEmbeddingResponseUsage {
|
||||
prompt_tokens: 0,
|
||||
total_tokens: 0,
|
||||
}),
|
||||
};
|
||||
let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
filter_context,
|
||||
101,
|
||||
0,
|
||||
embedding_response_str.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_log(
|
||||
Some(LogLevel::Trace),
|
||||
Some(
|
||||
format!(
|
||||
"filter_context: on_http_call_response called with token_id: {:?}",
|
||||
101
|
||||
)
|
||||
.as_str(),
|
||||
),
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embedding_response_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
filter_context
|
||||
}
|
||||
|
||||
|
|
@ -432,6 +235,7 @@ fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
|
|||
.returning(Some(chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(Some("arch_internal"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
|
|
@ -535,8 +339,8 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
finish_reason: Some("test".to_string()),
|
||||
index: Some(0),
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: None,
|
||||
|
|
@ -561,7 +365,7 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&arch_fc_resp_str))
|
||||
|
|
@ -569,47 +373,7 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
Some(vec![
|
||||
("x-arch-upstream", "hallucination"),
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", "hallucination"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// hallucination should return that parameters were not halliucinated
|
||||
// prompt: str
|
||||
// parameters: dict
|
||||
// model: str
|
||||
|
||||
let hallucatination_body = HallucinationClassificationResponse {
|
||||
params_scores: HashMap::from([("city".to_string(), 0.99)]),
|
||||
model: "nli-model".to_string(),
|
||||
};
|
||||
|
||||
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_http_call(
|
||||
Some("arch_internal"),
|
||||
|
|
@ -625,14 +389,14 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(6))
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
|
|
@ -640,6 +404,10 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
|
||||
.returning(Some("200"))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
|
@ -649,8 +417,8 @@ fn prompt_gateway_request_to_llm_gateway() {
|
|||
completion_tokens: 0,
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
finish_reason: Some("test".to_string()),
|
||||
index: Some(0),
|
||||
message: Message {
|
||||
role: "assistant".to_string(),
|
||||
content: Some("hello from fake llm gateway".to_string()),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue