Merge branch 'main' into adil/add_acm_demo

This commit is contained in:
Adil Hafeez 2025-01-08 16:55:07 -08:00
commit 68097fde07
166 changed files with 9507 additions and 11803 deletions

View file

@ -1,7 +1,23 @@
use common::{
common_types::open_ai::Message,
use std::collections::HashMap;
use crate::{
api::open_ai::Message,
consts::{ARCH_MODEL_PREFIX, HALLUCINATION_TEMPLATE, USER_ROLE},
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationRequest {
pub prompt: String,
pub parameters: HashMap<String, String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationResponse {
pub params_scores: HashMap<String, f64>,
pub model: String,
}
pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
let mut arch_assistant = false;
@ -42,7 +58,7 @@ pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec<String> {
#[cfg(test)]
mod test {
use common::common_types::open_ai::Message;
use crate::api::open_ai::Message;
use pretty_assertions::assert_eq;
use super::extract_messages_for_hallucination;

View file

@ -0,0 +1,4 @@
pub mod hallucination;
pub mod open_ai;
pub mod prompt_guard;
pub mod zero_shot;

View file

@ -0,0 +1,672 @@
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsRequest {
#[serde(default)]
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ChatCompletionTool>>,
#[serde(default)]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ToolType {
#[serde(rename = "function")]
Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionTool {
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: FunctionDefinition,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: String,
pub parameters: FunctionParameters,
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameters {
pub properties: HashMap<String, FunctionParameter>,
}
impl Serialize for FunctionParameters {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// select all requried parameters
let required: Vec<&String> = self
.properties
.iter()
.filter(|(_, v)| v.required.unwrap_or(false))
.map(|(k, _)| k)
.collect();
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("properties", &self.properties)?;
if !required.is_empty() {
map.serialize_entry("required", &required)?;
}
map.end()
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameter {
#[serde(rename = "type")]
#[serde(default = "ParameterType::string")]
pub parameter_type: ParameterType,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub format: Option<String>,
}
impl Serialize for FunctionParameter {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(5))?;
map.serialize_entry("type", &self.parameter_type)?;
map.serialize_entry("description", &self.description)?;
if let Some(enum_values) = &self.enum_values {
map.serialize_entry("enum", enum_values)?;
}
if let Some(default) = &self.default {
map.serialize_entry("default", default)?;
}
if let Some(format) = &self.format {
map.serialize_entry("format", format)?;
}
map.end()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ParameterType {
#[serde(rename = "int")]
Int,
#[serde(rename = "float")]
Float,
#[serde(rename = "bool")]
Bool,
#[serde(rename = "str")]
String,
#[serde(rename = "list")]
List,
#[serde(rename = "dict")]
Dict,
}
impl From<String> for ParameterType {
fn from(s: String) -> Self {
match s.as_str() {
"int" => ParameterType::Int,
"integer" => ParameterType::Int,
"float" => ParameterType::Float,
"bool" => ParameterType::Bool,
"boolean" => ParameterType::Bool,
"str" => ParameterType::String,
"string" => ParameterType::String,
"list" => ParameterType::List,
"array" => ParameterType::List,
"dict" => ParameterType::Dict,
"dictionary" => ParameterType::Dict,
_ => ParameterType::String,
}
}
}
impl ParameterType {
pub fn string() -> ParameterType {
ParameterType::String
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
pub include_usage: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub finish_reason: Option<String>,
pub index: Option<usize>,
pub message: Message,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: FunctionCallDetail,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDetail {
pub name: String,
pub arguments: HashMap<String, Value>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolCallState {
pub key: String,
pub message: Option<Message>,
pub tool_call: FunctionCallDetail,
pub tool_response: String,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ArchState {
ToolCall(Vec<ToolCallState>),
}
#[derive(Deserialize, Serialize)]
#[serde(untagged)]
pub enum ModelServerResponse {
ChatCompletionsResponse(ChatCompletionsResponse),
ModelServerErrorResponse(ModelServerErrorResponse),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelServerErrorResponse {
pub result: String,
pub intent_latency: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
pub usage: Option<Usage>,
pub choices: Vec<Choice>,
pub model: String,
pub metadata: Option<HashMap<String, String>>,
}
impl ChatCompletionsResponse {
pub fn new(message: String) -> Self {
ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
role: ASSISTANT_ROLE.to_string(),
content: Some(message),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
},
index: Some(0),
finish_reason: Some("done".to_string()),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
metadata: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub completion_tokens: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
pub choices: Vec<ChunkChoice>,
}
impl ChatCompletionStreamResponse {
pub fn new(
response: Option<String>,
role: Option<String>,
model: Option<String>,
tool_calls: Option<Vec<ToolCall>>,
) -> Self {
ChatCompletionStreamResponse {
model,
choices: vec![ChunkChoice {
delta: Delta {
role,
content: response,
tool_calls,
model: None,
tool_call_id: None,
},
finish_reason: None,
}],
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ChatCompletionChunkResponseError {
#[error("failed to deserialize")]
Deserialization(#[from] serde_json::Error),
#[error("empty content in data chunk")]
EmptyContent,
#[error("no chunks present")]
NoChunks,
}
pub struct ChatCompletionStreamResponseServerEvents {
pub events: Vec<ChatCompletionStreamResponse>,
}
impl Display for ChatCompletionStreamResponseServerEvents {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let tokens_str = self
.events
.iter()
.map(|response_chunk| {
if response_chunk.choices.is_empty() {
return "".to_string();
}
response_chunk.choices[0]
.delta
.content
.clone()
.unwrap_or("".to_string())
})
.collect::<Vec<String>>()
.join("");
write!(f, "{}", tokens_str)
}
}
impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
type Error = ChatCompletionChunkResponseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
.lines()
.filter(|line| line.starts_with("data: "))
.map(|line| line.get(6..).unwrap())
.filter(|data_chunk| *data_chunk != "[DONE]")
.map(serde_json::from_str::<ChatCompletionStreamResponse>)
.collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
Ok(ChatCompletionStreamResponseServerEvents {
events: response_chunks.into(),
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkChoice {
pub delta: Delta,
// TODO: could this be an enum?
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
}
#[cfg(test)]
mod test {
use super::{ChatCompletionStreamResponseServerEvents, Message};
use pretty_assertions::assert_eq;
use std::collections::HashMap;
const TOOL_SERIALIZED: &str = r#"{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "What city do you want to know the weather for?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "weather_forecast",
"description": "function to retrieve weather forecast",
"parameters": {
"properties": {
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
},
"required": [
"city"
]
}
}
}
],
"stream": true,
"stream_options": {
"include_usage": true
}
}"#;
#[test]
fn test_tool_type_request() {
use super::{
ChatCompletionTool, ChatCompletionsRequest, FunctionDefinition, FunctionParameter,
FunctionParameters, ParameterType, StreamOptions, ToolType,
};
let mut properties = HashMap::new();
properties.insert(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
format: None,
},
);
let function_definition = FunctionDefinition {
name: "weather_forecast".to_string(),
description: "function to retrieve weather forecast".to_string(),
parameters: FunctionParameters { properties },
};
let chat_completions_request = ChatCompletionsRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: Some("What city do you want to know the weather for?".to_string()),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: Some(vec![ChatCompletionTool {
tool_type: ToolType::Function,
function: function_definition,
}]),
stream: true,
stream_options: Some(StreamOptions {
include_usage: true,
}),
metadata: None,
};
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
println!("{}", serialized);
assert_eq!(TOOL_SERIALIZED, serialized);
}
#[test]
fn test_parameter_types() {
use super::{FunctionParameter, ParameterType};
const PARAMETER_SERIALZIED: &str = r#"{
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
}"#;
let properties = HashMap::from([(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
format: None,
},
)]);
let serialized = serde_json::to_string_pretty(&properties).unwrap();
assert_eq!(PARAMETER_SERIALZIED, serialized);
// ensure that if type is missing it is set to string
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
{
"city": {
"description": "city for weather forecast"
}
}"#;
let missing_type_deserialized: HashMap<String, FunctionParameter> =
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
println!("{:?}", missing_type_deserialized);
assert_eq!(
missing_type_deserialized
.get("city")
.unwrap()
.parameter_type,
ParameterType::String
);
}
#[test]
fn stream_chunk_parse() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]}
"#;
let sever_events =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 5);
assert_eq!(
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
""
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"Hello"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"!"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" How"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" can"
);
assert_eq!(sever_events.to_string(), "Hello! How can");
}
#[test]
fn stream_chunk_parse_done() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
data: [DONE]
"#;
let sever_events: ChatCompletionStreamResponseServerEvents =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 6);
assert_eq!(
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" I"
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" assist"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" you"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" today"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"?"
);
assert_eq!(sever_events.events[5].choices[0].delta.content, None);
assert_eq!(sever_events.to_string(), " I assist you today?");
}
#[test]
fn stream_chunk_parse_mistral() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
data: [DONE]
"#;
let sever_events: ChatCompletionStreamResponseServerEvents =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 11);
assert_eq!(
sever_events.to_string(),
"Hello! How can I assist you today?"
);
}
}

View file

@ -0,0 +1,25 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PromptGuardTask {
#[serde(rename = "jailbreak")]
Jailbreak,
#[serde(rename = "toxicity")]
Toxicity,
#[serde(rename = "both")]
Both,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptGuardRequest {
pub input: String,
pub task: PromptGuardTask,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptGuardResponse {
pub toxic_prob: Option<f64>,
pub jailbreak_prob: Option<f64>,
pub toxic_verdict: Option<bool>,
pub jailbreak_verdict: Option<bool>,
}

View file

@ -0,0 +1,18 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationRequest {
pub input: String,
pub labels: Vec<String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationResponse {
pub predicted_class: String,
pub predicted_class_score: f64,
pub scores: HashMap<String, f64>,
pub model: String,
}

View file

@ -1,743 +0,0 @@
use crate::configuration::PromptTarget;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub prompt_target: PromptTarget,
}
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum EmbeddingType {
Name,
Description,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPoint {
pub id: String,
pub payload: HashMap<String, String>,
pub vector: Vec<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoreVectorEmbeddingsRequest {
pub points: Vec<VectorPoint>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchPointResult {
pub id: String,
pub version: i32,
pub score: f64,
pub payload: HashMap<String, String>,
}
pub mod open_ai {
use std::{
collections::{HashMap, VecDeque},
fmt::Display,
};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use serde_yaml::Value;
use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsRequest {
#[serde(default)]
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ChatCompletionTool>>,
#[serde(default)]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolType {
#[serde(rename = "function")]
Function,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionTool {
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: FunctionDefinition,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: String,
pub parameters: FunctionParameters,
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameters {
pub properties: HashMap<String, FunctionParameter>,
}
impl Serialize for FunctionParameters {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
// select all requried parameters
let required: Vec<&String> = self
.properties
.iter()
.filter(|(_, v)| v.required.unwrap_or(false))
.map(|(k, _)| k)
.collect();
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("properties", &self.properties)?;
if !required.is_empty() {
map.serialize_entry("required", &required)?;
}
map.end()
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct FunctionParameter {
#[serde(rename = "type")]
#[serde(default = "ParameterType::string")]
pub parameter_type: ParameterType,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
}
impl Serialize for FunctionParameter {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut map = serializer.serialize_map(Some(5))?;
map.serialize_entry("type", &self.parameter_type)?;
map.serialize_entry("description", &self.description)?;
if let Some(enum_values) = &self.enum_values {
map.serialize_entry("enum", enum_values)?;
}
if let Some(default) = &self.default {
map.serialize_entry("default", default)?;
}
map.end()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ParameterType {
#[serde(rename = "int")]
Int,
#[serde(rename = "float")]
Float,
#[serde(rename = "bool")]
Bool,
#[serde(rename = "str")]
String,
#[serde(rename = "list")]
List,
#[serde(rename = "dict")]
Dict,
}
impl From<String> for ParameterType {
fn from(s: String) -> Self {
match s.as_str() {
"int" => ParameterType::Int,
"integer" => ParameterType::Int,
"float" => ParameterType::Float,
"bool" => ParameterType::Bool,
"boolean" => ParameterType::Bool,
"str" => ParameterType::String,
"string" => ParameterType::String,
"list" => ParameterType::List,
"array" => ParameterType::List,
"dict" => ParameterType::Dict,
"dictionary" => ParameterType::Dict,
_ => ParameterType::String,
}
}
}
impl ParameterType {
pub fn string() -> ParameterType {
ParameterType::String
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
pub include_usage: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub role: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub finish_reason: String,
pub index: usize,
pub message: Message,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: ToolType,
pub function: FunctionCallDetail,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDetail {
pub name: String,
pub arguments: HashMap<String, Value>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolCallState {
pub key: String,
pub message: Option<Message>,
pub tool_call: FunctionCallDetail,
pub tool_response: String,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ArchState {
ToolCall(Vec<ToolCallState>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
pub usage: Option<Usage>,
pub choices: Vec<Choice>,
pub model: String,
pub metadata: Option<HashMap<String, String>>,
}
impl ChatCompletionsResponse {
pub fn new(message: String) -> Self {
ChatCompletionsResponse {
choices: vec![Choice {
message: Message {
role: ASSISTANT_ROLE.to_string(),
content: Some(message),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
},
index: 0,
finish_reason: "done".to_string(),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
metadata: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub completion_tokens: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
pub choices: Vec<ChunkChoice>,
}
impl ChatCompletionStreamResponse {
pub fn new(
response: Option<String>,
role: Option<String>,
model: Option<String>,
tool_calls: Option<Vec<ToolCall>>,
) -> Self {
ChatCompletionStreamResponse {
model,
choices: vec![ChunkChoice {
delta: Delta {
role,
content: response,
tool_calls,
model: None,
tool_call_id: None,
},
finish_reason: None,
}],
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum ChatCompletionChunkResponseError {
#[error("failed to deserialize")]
Deserialization(#[from] serde_json::Error),
#[error("empty content in data chunk")]
EmptyContent,
#[error("no chunks present")]
NoChunks,
}
pub struct ChatCompletionStreamResponseServerEvents {
pub events: Vec<ChatCompletionStreamResponse>,
}
impl Display for ChatCompletionStreamResponseServerEvents {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let tokens_str = self
.events
.iter()
.map(|response_chunk| {
if response_chunk.choices.is_empty() {
return "".to_string();
}
response_chunk.choices[0]
.delta
.content
.clone()
.unwrap_or("".to_string())
})
.collect::<Vec<String>>()
.join("");
write!(f, "{}", tokens_str)
}
}
impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
type Error = ChatCompletionChunkResponseError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
.lines()
.filter(|line| line.starts_with("data: "))
.map(|line| line.get(6..).unwrap())
.filter(|data_chunk| *data_chunk != "[DONE]")
.map(serde_json::from_str::<ChatCompletionStreamResponse>)
.collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
Ok(ChatCompletionStreamResponseServerEvents {
events: response_chunks.into(),
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkChoice {
pub delta: Delta,
// TODO: could this be an enum?
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Delta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationRequest {
pub input: String,
pub labels: Vec<String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZeroShotClassificationResponse {
pub predicted_class: String,
pub predicted_class_score: f64,
pub scores: HashMap<String, f64>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationRequest {
pub prompt: String,
pub parameters: HashMap<String, String>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationResponse {
pub params_scores: HashMap<String, f64>,
pub model: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PromptGuardTask {
#[serde(rename = "jailbreak")]
Jailbreak,
#[serde(rename = "toxicity")]
Toxicity,
#[serde(rename = "both")]
Both,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptGuardRequest {
pub input: String,
pub task: PromptGuardTask,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptGuardResponse {
pub toxic_prob: Option<f64>,
pub jailbreak_prob: Option<f64>,
pub toxic_verdict: Option<bool>,
pub jailbreak_verdict: Option<bool>,
}
#[cfg(test)]
mod test {
use crate::common_types::open_ai::{ChatCompletionStreamResponseServerEvents, Message};
use pretty_assertions::assert_eq;
use std::collections::HashMap;
const TOOL_SERIALIZED: &str = r#"{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "What city do you want to know the weather for?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "weather_forecast",
"description": "function to retrieve weather forecast",
"parameters": {
"properties": {
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
},
"required": [
"city"
]
}
}
}
],
"stream": true,
"stream_options": {
"include_usage": true
}
}"#;
#[test]
fn test_tool_type_request() {
use super::open_ai::{
ChatCompletionsRequest, FunctionDefinition, FunctionParameter, ParameterType, ToolType,
};
let mut properties = HashMap::new();
properties.insert(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
},
);
let function_definition = FunctionDefinition {
name: "weather_forecast".to_string(),
description: "function to retrieve weather forecast".to_string(),
parameters: super::open_ai::FunctionParameters { properties },
};
let chat_completions_request = ChatCompletionsRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: Some("What city do you want to know the weather for?".to_string()),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: Some(vec![super::open_ai::ChatCompletionTool {
tool_type: ToolType::Function,
function: function_definition,
}]),
stream: true,
stream_options: Some(super::open_ai::StreamOptions {
include_usage: true,
}),
metadata: None,
};
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
println!("{}", serialized);
assert_eq!(TOOL_SERIALIZED, serialized);
}
#[test]
fn test_parameter_types() {
use super::open_ai::{FunctionParameter, ParameterType};
const PARAMETER_SERIALZIED: &str = r#"{
"city": {
"type": "str",
"description": "city for weather forecast",
"default": "test"
}
}"#;
let properties = HashMap::from([(
"city".to_string(),
FunctionParameter {
parameter_type: ParameterType::String,
description: "city for weather forecast".to_string(),
required: Some(true),
enum_values: None,
default: Some("test".to_string()),
},
)]);
let serialized = serde_json::to_string_pretty(&properties).unwrap();
assert_eq!(PARAMETER_SERIALZIED, serialized);
// ensure that if type is missing it is set to string
const PARAMETER_SERIALZIED_MISSING_TYPE: &str = r#"
{
"city": {
"description": "city for weather forecast"
}
}"#;
let missing_type_deserialized: HashMap<String, FunctionParameter> =
serde_json::from_str(PARAMETER_SERIALZIED_MISSING_TYPE).unwrap();
println!("{:?}", missing_type_deserialized);
assert_eq!(
missing_type_deserialized
.get("city")
.unwrap()
.parameter_type,
ParameterType::String
);
}
#[test]
fn stream_chunk_parse() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}]}
"#;
let sever_events =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 5);
assert_eq!(
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
""
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"Hello"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"!"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" How"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" can"
);
assert_eq!(sever_events.to_string(), "Hello! How can");
}
#[test]
fn stream_chunk_parse_done() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]}
data: [DONE]
"#;
let sever_events: ChatCompletionStreamResponseServerEvents =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 6);
assert_eq!(
sever_events.events[0].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" I"
);
assert_eq!(
sever_events.events[1].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" assist"
);
assert_eq!(
sever_events.events[2].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" you"
);
assert_eq!(
sever_events.events[3].choices[0]
.delta
.content
.as_ref()
.unwrap(),
" today"
);
assert_eq!(
sever_events.events[4].choices[0]
.delta
.content
.as_ref()
.unwrap(),
"?"
);
assert_eq!(sever_events.events[5].choices[0].delta.content, None);
assert_eq!(sever_events.to_string(), " I assist you today?");
}
#[test]
fn stream_chunk_parse_mistral() {
const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" How"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" can"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" I"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" assist"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" you"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":" today"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":"?"},"finish_reason":null}]}
data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"total_tokens":13,"completion_tokens":9}}
data: [DONE]
"#;
let sever_events: ChatCompletionStreamResponseServerEvents =
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
assert_eq!(sever_events.events.len(), 11);
assert_eq!(
sever_events.to_string(),
"Hello! How can I assist you today?"
);
}
}

View file

@ -2,6 +2,26 @@ use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
use crate::api::open_ai::{
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
pub listener: Listener,
pub endpoints: Option<HashMap<String, Endpoint>>,
pub llm_providers: Vec<LlmProvider>,
pub overrides: Option<Overrides>,
pub system_prompt: Option<String>,
pub prompt_guards: Option<PromptGuards>,
pub prompt_targets: Option<Vec<PromptTarget>>,
pub error_target: Option<ErrorTargetDetail>,
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Overrides {
pub prompt_target_intent_matching_threshold: Option<f64>,
@ -22,22 +42,6 @@ pub enum GatewayMode {
Prompt,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
pub listener: Listener,
pub endpoints: Option<HashMap<String, Endpoint>>,
pub llm_providers: Vec<LlmProvider>,
pub overrides: Option<Overrides>,
pub system_prompt: Option<String>,
pub prompt_guards: Option<PromptGuards>,
pub prompt_targets: Option<Vec<PromptTarget>>,
pub error_target: Option<ErrorTargetDetail>,
pub ratelimits: Option<Vec<Ratelimit>>,
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorTargetDetail {
pub endpoint: Option<EndpointDetails>,
@ -192,6 +196,7 @@ pub struct Parameter {
pub enum_values: Option<Vec<String>>,
pub default: Option<String>,
pub in_path: Option<bool>,
pub format: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
@ -233,11 +238,47 @@ pub struct PromptTarget {
pub auto_llm_dispatch_on_response: Option<bool>,
}
// convert PromptTarget to ChatCompletionTool
impl From<&PromptTarget> for ChatCompletionTool {
fn from(val: &PromptTarget) -> Self {
let properties: HashMap<String, FunctionParameter> = match val.parameters {
Some(ref entities) => {
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = FunctionParameter {
parameter_type: ParameterType::from(
entity.parameter_type.clone().unwrap_or("str".to_string()),
),
description: entity.description.clone(),
required: entity.required,
enum_values: entity.enum_values.clone(),
default: entity.default.clone(),
format: entity.format.clone(),
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};
ChatCompletionTool {
tool_type: crate::api::open_ai::ToolType::Function,
function: FunctionDefinition {
name: val.name.clone(),
description: val.description.clone(),
parameters: FunctionParameters { properties },
},
}
}
}
#[cfg(test)]
mod test {
use pretty_assertions::assert_eq;
use std::fs;
use crate::configuration::GuardType;
use crate::{api::open_ai::ToolType, configuration::GuardType};
#[test]
fn test_deserialize_configuration() {
@ -309,4 +350,76 @@ mod test {
let mode = config.mode.as_ref().unwrap_or(&super::GatewayMode::Prompt);
assert_eq!(*mode, super::GatewayMode::Prompt);
}
#[test]
fn test_tool_conversion() {
let ref_config = fs::read_to_string(
"../../docs/source/resources/includes/arch_config_full_reference.yaml",
)
.expect("reference config file not found");
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
let prompt_targets = &config.prompt_targets;
let prompt_target = prompt_targets
.as_ref()
.unwrap()
.iter()
.find(|p| p.name == "reboot_network_device")
.unwrap();
let chat_completion_tool: super::ChatCompletionTool = prompt_target.into();
assert_eq!(chat_completion_tool.tool_type, ToolType::Function);
assert_eq!(chat_completion_tool.function.name, "reboot_network_device");
assert_eq!(
chat_completion_tool.function.description,
"Reboot a specific network device"
);
assert_eq!(chat_completion_tool.function.parameters.properties.len(), 2);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.contains_key("device_id"),
true
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap()
.parameter_type,
crate::api::open_ai::ParameterType::String
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap()
.description,
"Identifier of the network device to reboot.".to_string()
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("device_id")
.unwrap()
.required,
Some(true)
);
assert_eq!(
chat_completion_tool
.function
.parameters
.properties
.get("confirmation")
.unwrap()
.parameter_type,
crate::api::open_ai::ParameterType::Bool
);
}
}

View file

@ -1,7 +1,3 @@
pub const DEFAULT_EMBEDDING_MODEL: &str = "katanemo/bge-large-en-v1.5";
pub const DEFAULT_INTENT_MODEL: &str = "katanemo/bart-large-mnli";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
@ -9,11 +5,6 @@ pub const TOOL_ROLE: &str = "tool";
pub const ASSISTANT_ROLE: &str = "assistant";
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const MODEL_SERVER_NAME: &str = "model_server";
pub const ZEROSHOT_INTERNAL_HOST: &str = "zeroshot";
pub const ARCH_FC_INTERNAL_HOST: &str = "arch_fc";
pub const HALLUCINATION_INTERNAL_HOST: &str = "hallucination";
pub const EMBEDDINGS_INTERNAL_HOST: &str = "embeddings";
pub const GUARD_INTERNAL_HOST: &str = "guard";
pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const MESSAGES_KEY: &str = "messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
@ -25,7 +16,6 @@ pub const REQUEST_ID_HEADER: &str = "x-request-id";
pub const TRACE_PARENT_HEADER: &str = "traceparent";
pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal";
pub const ARCH_UPSTREAM_HOST_HEADER: &str = "x-arch-upstream";
pub const ARCH_LLM_UPSTREAM_LISTENER: &str = "arch_llm_listener";
pub const ARCH_MODEL_PREFIX: &str = "Arch";
pub const HALLUCINATION_TEMPLATE: &str =
"It seems I'm missing some information. Could you provide the following details ";

View file

@ -1,59 +0,0 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use crate::embeddings;
use serde::{Deserialize, Serialize};
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEmbeddingRequest {
#[serde(rename = "input")]
pub input: Box<embeddings::CreateEmbeddingRequestInput>,
/// ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.
#[serde(rename = "model")]
pub model: String,
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
#[serde(rename = "encoding_format", skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<EncodingFormat>,
/// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.
#[serde(rename = "dimensions", skip_serializing_if = "Option::is_none")]
pub dimensions: Option<i32>,
/// A unique identifier representing your end-user, which can help to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).
#[serde(rename = "user", skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}
impl CreateEmbeddingRequest {
pub fn new(
input: embeddings::CreateEmbeddingRequestInput,
model: String,
) -> CreateEmbeddingRequest {
CreateEmbeddingRequest {
input: Box::new(input),
model,
encoding_format: None,
dimensions: None,
user: None,
}
}
}
/// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub enum EncodingFormat {
#[serde(rename = "float")]
Float,
#[serde(rename = "base64")]
Base64,
}
impl Default for EncodingFormat {
fn default() -> EncodingFormat {
Self::Float
}
}

View file

@ -1,28 +0,0 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use serde::{Deserialize, Serialize};
/// CreateEmbeddingRequestInput : Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
/// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. for counting tokens.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CreateEmbeddingRequestInput {
/// The string that will be turned into an embedding.
String(String),
/// The array of integers that will be turned into an embedding.
Array(Vec<i32>),
}
impl Default for CreateEmbeddingRequestInput {
fn default() -> Self {
Self::String(Default::default())
}
}

View file

@ -1,55 +0,0 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use crate::embeddings;
use serde::{Deserialize, Serialize};
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEmbeddingResponse {
/// The list of embeddings generated by the model.
#[serde(rename = "data")]
pub data: Vec<embeddings::Embedding>,
/// The name of the model used to generate the embedding.
#[serde(rename = "model")]
pub model: String,
/// The object type, which is always \"list\".
#[serde(rename = "object")]
pub object: Object,
#[serde(rename = "usage")]
pub usage: Box<embeddings::CreateEmbeddingResponseUsage>,
}
impl CreateEmbeddingResponse {
pub fn new(
data: Vec<embeddings::Embedding>,
model: String,
object: Object,
usage: embeddings::CreateEmbeddingResponseUsage,
) -> CreateEmbeddingResponse {
CreateEmbeddingResponse {
data,
model,
object,
usage: Box::new(usage),
}
}
}
/// The object type, which is always \"list\".
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub enum Object {
#[serde(rename = "list")]
List,
}
impl Default for Object {
fn default() -> Object {
Self::List
}
}

View file

@ -1,32 +0,0 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use serde::{Deserialize, Serialize};
/// CreateEmbeddingResponseUsage : The usage information for the request.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct CreateEmbeddingResponseUsage {
/// The number of tokens used by the prompt.
#[serde(rename = "prompt_tokens")]
pub prompt_tokens: i32,
/// The total number of tokens used by the request.
#[serde(rename = "total_tokens")]
pub total_tokens: i32,
}
impl CreateEmbeddingResponseUsage {
/// The usage information for the request.
pub fn new(prompt_tokens: i32, total_tokens: i32) -> CreateEmbeddingResponseUsage {
CreateEmbeddingResponseUsage {
prompt_tokens,
total_tokens,
}
}
}

View file

@ -1,48 +0,0 @@
/*
* OMF Embeddings
*
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 1.0.0
*
* Generated by: https://openapi-generator.tech
*/
use serde::{Deserialize, Serialize};
/// Embedding : Represents an embedding vector returned by embedding endpoint.
#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)]
pub struct Embedding {
/// The index of the embedding in the list of embeddings.
#[serde(rename = "index")]
pub index: i32,
/// The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).
#[serde(rename = "embedding")]
pub embedding: Vec<f64>,
/// The object type, which is always \"embedding\"
#[serde(rename = "object")]
pub object: Object,
}
impl Embedding {
/// Represents an embedding vector returned by embedding endpoint.
pub fn new(index: i32, embedding: Vec<f64>, object: Object) -> Embedding {
Embedding {
index,
embedding,
object,
}
}
}
/// The object type, which is always \"embedding\"
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub enum Object {
#[serde(rename = "embedding")]
Embedding,
}
impl Default for Object {
fn default() -> Object {
Self::Embedding
}
}

View file

@ -1,10 +0,0 @@
pub mod create_embedding_request;
pub use self::create_embedding_request::CreateEmbeddingRequest;
pub mod create_embedding_request_input;
pub use self::create_embedding_request_input::CreateEmbeddingRequestInput;
pub mod create_embedding_response;
pub use self::create_embedding_response::CreateEmbeddingResponse;
pub mod create_embedding_response_usage;
pub use self::create_embedding_response_usage::CreateEmbeddingResponseUsage;
pub mod embedding;
pub use self::embedding::Embedding;

View file

@ -1,6 +1,6 @@
use proxy_wasm::types::Status;
use crate::{common_types::open_ai::ChatCompletionChunkResponseError, ratelimit};
use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit};
#[derive(thiserror::Error, Debug)]
pub enum ClientError {

View file

@ -1,14 +1,13 @@
pub mod common_types;
pub mod api;
pub mod configuration;
pub mod consts;
pub mod embeddings;
pub mod errors;
pub mod http;
pub mod llm_providers;
pub mod path;
pub mod pii;
pub mod ratelimit;
pub mod routing;
pub mod stats;
pub mod tokenizer;
pub mod tracing;
pub mod path;

View file

@ -1,6 +1,9 @@
use std::collections::HashMap;
pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> Result<String, String> {
pub fn replace_params_in_path(
path: &str,
params: &HashMap<String, String>,
) -> Result<String, String> {
let mut result = String::new();
let mut in_param = false;
let mut current_param = String::new();
@ -17,12 +20,10 @@ pub fn replace_params_in_path(path: &str, params: &HashMap<String, String>) -> R
return Err(format!("Missing value for parameter `{}`", param_name));
}
current_param.clear();
} else if in_param {
current_param.push(c);
} else {
if in_param {
current_param.push(c);
} else {
result.push(c);
}
result.push(c);
}
}

View file

@ -2,7 +2,6 @@ use log::error;
use proxy_wasm::hostcalls;
use proxy_wasm::types::*;
#[allow(unused)]
pub trait Metric {
fn id(&self) -> u32;
fn value(&self) -> Result<u64, String> {
@ -14,7 +13,6 @@ pub trait Metric {
}
}
#[allow(unused)]
pub trait IncrementingMetric: Metric {
fn increment(&self, offset: i64) {
match hostcalls::increment_metric(self.id(), offset) {
@ -24,7 +22,6 @@ pub trait IncrementingMetric: Metric {
}
}
#[allow(unused)]
pub trait RecordingMetric: Metric {
fn record(&self, value: u64) {
match hostcalls::record_metric(self.id(), value) {
@ -39,7 +36,6 @@ pub struct Counter {
id: u32,
}
#[allow(unused)]
impl Counter {
pub fn new(name: String) -> Counter {
let returned_id = hostcalls::define_metric(MetricType::Counter, &name)
@ -85,7 +81,6 @@ pub struct Histogram {
id: u32,
}
#[allow(unused)]
impl Histogram {
pub fn new(name: String) -> Histogram {
let returned_id = hostcalls::define_metric(MetricType::Histogram, &name)

View file

@ -1,3 +1,4 @@
use crate::metrics::Metrics;
use crate::stream_context::StreamContext;
use common::configuration::Configuration;
use common::consts::OTEL_COLLECTOR_HTTP;
@ -6,9 +7,7 @@ use common::http::CallArgs;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::ratelimit;
use common::stats::Counter;
use common::stats::Gauge;
use common::stats::Histogram;
use common::tracing::TraceData;
use log::debug;
use log::warn;
@ -22,39 +21,12 @@ use std::time::Duration;
use std::sync::{Arc, Mutex};
#[derive(Copy, Clone, Debug)]
pub struct WasmMetrics {
pub active_http_calls: Gauge,
pub ratelimited_rq: Counter,
pub time_to_first_token: Histogram,
pub time_per_output_token: Histogram,
pub tokens_per_second: Histogram,
pub request_latency: Histogram,
pub output_sequence_length: Histogram,
pub input_sequence_length: Histogram,
}
impl WasmMetrics {
fn new() -> WasmMetrics {
WasmMetrics {
active_http_calls: Gauge::new(String::from("active_http_calls")),
ratelimited_rq: Counter::new(String::from("ratelimited_rq")),
time_to_first_token: Histogram::new(String::from("time_to_first_token")),
time_per_output_token: Histogram::new(String::from("time_per_output_token")),
tokens_per_second: Histogram::new(String::from("tokens_per_second")),
request_latency: Histogram::new(String::from("request_latency")),
output_sequence_length: Histogram::new(String::from("output_sequence_length")),
input_sequence_length: Histogram::new(String::from("input_sequence_length")),
}
}
}
#[derive(Debug)]
pub struct CallContext {}
#[derive(Debug)]
pub struct FilterContext {
metrics: Rc<WasmMetrics>,
metrics: Rc<Metrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, CallContext>>,
llm_providers: Option<Rc<LlmProviders>>,
@ -65,7 +37,7 @@ impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()),
metrics: Rc::new(Metrics::new()),
llm_providers: None,
traces_queue: Arc::new(Mutex::new(VecDeque::new())),
}

View file

@ -3,6 +3,7 @@ use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod filter_context;
mod metrics;
mod stream_context;
proxy_wasm::main! {{

View file

@ -0,0 +1,28 @@
use common::stats::{Counter, Gauge, Histogram};
#[derive(Copy, Clone, Debug)]
pub struct Metrics {
pub active_http_calls: Gauge,
pub ratelimited_rq: Counter,
pub time_to_first_token: Histogram,
pub time_per_output_token: Histogram,
pub tokens_per_second: Histogram,
pub request_latency: Histogram,
pub output_sequence_length: Histogram,
pub input_sequence_length: Histogram,
}
impl Metrics {
pub fn new() -> Metrics {
Metrics {
active_http_calls: Gauge::new(String::from("active_http_calls")),
ratelimited_rq: Counter::new(String::from("ratelimited_rq")),
time_to_first_token: Histogram::new(String::from("time_to_first_token")),
time_per_output_token: Histogram::new(String::from("time_per_output_token")),
tokens_per_second: Histogram::new(String::from("tokens_per_second")),
request_latency: Histogram::new(String::from("request_latency")),
output_sequence_length: Histogram::new(String::from("output_sequence_length")),
input_sequence_length: Histogram::new(String::from("input_sequence_length")),
}
}
}

View file

@ -1,5 +1,5 @@
use crate::filter_context::WasmMetrics;
use common::common_types::open_ai::{
use crate::metrics::Metrics;
use common::api::open_ai::{
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
Message, StreamOptions,
};
@ -12,25 +12,23 @@ use common::errors::ServerError;
use common::llm_providers::LlmProviders;
use common::pii::obfuscate_auth_header;
use common::ratelimit::Header;
use common::stats::{IncrementingMetric, RecordingMetric};
use common::tracing::{Event, Span, TraceData, Traceparent};
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
use log::{debug, trace, warn};
use proxy_wasm::hostcalls::get_current_time;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::collections::VecDeque;
use std::num::NonZero;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use common::stats::{IncrementingMetric, RecordingMetric};
use proxy_wasm::hostcalls::get_current_time;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
pub struct StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
metrics: Rc<Metrics>,
ratelimit_selector: Option<Header>,
streaming_response: bool,
response_tokens: usize,
@ -50,7 +48,7 @@ pub struct StreamContext {
impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
metrics: Rc<Metrics>,
llm_providers: Rc<LlmProviders>,
traces_queue: Arc<Mutex<VecDeque<TraceData>>>,
) -> Self {

View file

@ -1,5 +1,9 @@
use std::str::FromStr;
use common::errors::ServerError;
use common::stats::IncrementingMetric;
use http::StatusCode;
use log::{debug, warn};
use proxy_wasm::traits::Context;
use crate::stream_context::{ResponseHandlerType, StreamContext};
@ -19,76 +23,34 @@ impl Context for StreamContext {
.expect("invalid token_id");
self.metrics.active_http_calls.increment(-1);
/*
state transition
let body = self
.get_http_call_response_body(0, body_size)
.unwrap_or(vec![]);
graph LR
on_http_request_body --> prompt received
prompt received --> get embeddings & arch guard
arch guard --> get embeddings
get embeddings --> zeroshot intent
on_http_request_body prompt received get embeddings zeroshot intent
arch guard
continue from zeroshot intent
graph LR
zeroshot intent --> arch_fc
zeroshot intent --> default prompt target
arch_fc --> developer api call & hallucination check
hallucination check --> parameter gathering & developer api call
developer api call --> resume request to llm
zeroshot intent arch_fc developer api call resume request to llm
default prompt target hallucination check parameter gathering
using https://mermaid-ascii.art/
*/
if let Some(body) = self.get_http_call_response_body(0, body_size) {
#[cfg_attr(any(), rustfmt::skip)]
match callout_context.response_handler_type {
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
ResponseHandlerType::Embeddings => self.embeddings_handler(body, callout_context),
ResponseHandlerType::ZeroShotIntent => self.zero_shot_intent_detection_resp_handler(body, callout_context),
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
ResponseHandlerType::Hallucination => self.hallucination_classification_resp_handler(body, callout_context),
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
}
} else {
self.send_server_error(
ServerError::LogicError(String::from("No response body in inline HTTP request")),
None,
let http_status = self
.get_http_call_response_header(":status")
.unwrap_or(StatusCode::OK.as_str().to_string());
debug!("http call response code: {}", http_status);
if http_status != StatusCode::OK.as_str() {
let server_error = ServerError::Upstream {
host: callout_context.upstream_cluster.unwrap(),
path: callout_context.upstream_cluster_path.unwrap(),
status: http_status.clone(),
body: String::from_utf8(body).unwrap(),
};
warn!("filter received non 2xx code: {:?}", server_error);
return self.send_server_error(
server_error,
Some(StatusCode::from_str(http_status.as_str()).unwrap()),
);
}
debug!("http call response handler type: {:?}", callout_context.response_handler_type);
#[cfg_attr(any(), rustfmt::skip)]
match callout_context.response_handler_type {
ResponseHandlerType::ArchFC => self.arch_fc_response_handler(body, callout_context),
ResponseHandlerType::FunctionCall => self.api_call_response_handler(body, callout_context),
ResponseHandlerType::DefaultTarget =>self.default_target_handler(body, callout_context),
}
}
}

View file

@ -1,60 +1,27 @@
use crate::metrics::Metrics;
use crate::stream_context::StreamContext;
use common::common_types::EmbeddingType;
use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget, Tracing};
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::DEFAULT_EMBEDDING_MODEL;
use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use common::http::CallArgs;
use common::http::Client;
use common::stats::Gauge;
use common::stats::IncrementingMetric;
use http::StatusCode;
use log::{debug, info, trace, warn};
use log::debug;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::rc::Rc;
use std::time::Duration;
#[derive(Copy, Clone, Debug)]
pub struct WasmMetrics {
pub active_http_calls: Gauge,
}
impl WasmMetrics {
fn new() -> WasmMetrics {
WasmMetrics {
active_http_calls: Gauge::new(String::from("active_http_calls")),
}
}
}
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
pub type EmbeddingsStore = HashMap<String, EmbeddingTypeMap>;
#[derive(Debug)]
pub struct FilterCallContext {
pub prompt_target_name: String,
pub embedding_type: EmbeddingType,
}
pub struct FilterCallContext {}
#[derive(Debug)]
pub struct FilterContext {
metrics: Rc<WasmMetrics>,
metrics: Rc<Metrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>,
overrides: Rc<Option<Overrides>>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
temp_embeddings_store: EmbeddingsStore,
active_embedding_calls_count: u32,
tracing: Rc<Option<Tracing>>,
}
@ -62,136 +29,14 @@ impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()),
metrics: Rc::new(Metrics::new()),
system_prompt: Rc::new(None),
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
embeddings_store: Some(Rc::new(HashMap::new())),
temp_embeddings_store: HashMap::new(),
active_embedding_calls_count: 0,
tracing: Rc::new(None),
}
}
fn process_prompt_targets(&mut self) {
let prompt_target_description: Vec<(String, String)> = self
.prompt_targets
.iter()
.map(|(k, v)| (k.clone(), v.description.clone()))
.collect();
prompt_target_description
.iter()
.for_each(|(name, description)| {
self.schedule_embeddings_call(name, description, EmbeddingType::Description);
});
}
fn schedule_embeddings_call(
&mut self,
prompt_target_name: &str,
input: &str,
embedding_type: EmbeddingType,
) {
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let json_data = serde_json::to_string(&embeddings_input).unwrap();
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/embeddings",
vec![
(ARCH_UPSTREAM_HOST_HEADER, EMBEDDINGS_INTERNAL_HOST),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", EMBEDDINGS_INTERNAL_HOST),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(60),
);
let call_context = crate::filter_context::FilterCallContext {
prompt_target_name: String::from(prompt_target_name),
embedding_type,
};
self.active_embedding_calls_count += 1;
if let Err(error) = self.http_call(call_args, call_context) {
panic!("{error}")
}
}
fn embedding_response_handler(
&mut self,
embedding_type: EmbeddingType,
prompt_target_name: String,
body: Vec<u8>,
) {
let prompt_target = self
.prompt_targets
.get(&prompt_target_name)
.unwrap_or_else(|| {
panic!(
"Received embeddings response for unknown prompt target name={}",
prompt_target_name
)
});
if !body.is_empty() {
let mut embedding_response: CreateEmbeddingResponse =
match serde_json::from_slice(&body) {
Ok(response) => response,
Err(e) => {
panic!(
"Error deserializing embedding response. body: {:?}: {:?}",
String::from_utf8(body).unwrap(),
e
);
}
};
let embeddings = embedding_response.data.remove(0).embedding;
debug!(
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
prompt_target.name,
prompt_target.description,
embedding_type
);
let entry = self.temp_embeddings_store.entry(prompt_target_name);
match entry {
Entry::Occupied(_) => {
entry.and_modify(|e| {
if let Entry::Vacant(e) = e.entry(embedding_type) {
e.insert(embeddings);
} else {
panic!(
"Duplicate {:?} for prompt target with name=\"{}\"",
&embedding_type, prompt_target.name
)
}
});
}
Entry::Vacant(_) => {
entry.or_insert(HashMap::from([(embedding_type, embeddings)]));
}
}
if self.prompt_targets.len() == self.temp_embeddings_store.len() {
self.embeddings_store =
Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store)))
}
}
}
}
impl Client for FilterContext {
@ -206,46 +51,7 @@ impl Client for FilterContext {
}
}
impl Context for FilterContext {
fn on_http_call_response(
&mut self,
token_id: u32,
_num_headers: usize,
body_size: usize,
_num_trailers: usize,
) {
trace!(
"filter_context: on_http_call_response called with token_id: {:?}",
token_id
);
let callout_data = self
.callouts
.borrow_mut()
.remove(&token_id)
.expect("invalid token_id");
self.active_embedding_calls_count -= 1;
self.metrics.active_http_calls.increment(-1);
let body_bytes = self.get_http_call_response_body(0, body_size).unwrap();
if let Some(status_code) = self.get_http_call_response_header(":status") {
if status_code == StatusCode::OK.as_str() {
self.embedding_response_handler(
callout_data.embedding_type,
callout_data.prompt_target_name,
body_bytes,
);
} else {
warn!(
"Received non-200 status code: {} for callout with token_id: {}: body_str: {}",
status_code,
token_id,
String::from_utf8(body_bytes).unwrap()
);
}
}
}
}
impl Context for FilterContext {}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
@ -283,15 +89,12 @@ impl RootContext for FilterContext {
context_id
);
let embedding_store = self.embeddings_store.as_ref().map(Rc::clone);
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.system_prompt),
Rc::clone(&self.prompt_targets),
Rc::clone(&self.prompt_guards),
Rc::clone(&self.overrides),
embedding_store,
Rc::clone(&self.tracing),
)))
}
@ -301,25 +104,6 @@ impl RootContext for FilterContext {
}
fn on_vm_start(&mut self, _: usize) -> bool {
self.set_tick_period(Duration::from_secs(1));
true
}
fn on_tick(&mut self) {
if self.embeddings_store.is_some()
&& self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len()
{
info!("embeddings store initialized");
self.set_tick_period(Duration::from_secs(0));
} else {
if self.active_embedding_calls_count == 0 {
info!("retrieving embeddings from embedding server");
self.process_prompt_targets();
} else {
info!("waiting for embeddings store to be initialized");
}
self.set_tick_period(Duration::from_secs(5));
}
}
}

View file

@ -1,15 +1,12 @@
use crate::stream_context::{ResponseHandlerType, StreamCallContext, StreamContext};
use common::{
common_types::{
open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
},
PromptGuardRequest, PromptGuardTask,
api::open_ai::{
self, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
},
consts::{
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, GUARD_INTERNAL_HOST,
HEALTHZ_PATH, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
MODEL_SERVER_NAME, REQUEST_ID_HEADER, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
},
errors::ServerError,
http::{CallArgs, Client},
@ -37,11 +34,7 @@ impl HttpContext for StreamContext {
let request_path = self.get_http_request_header(":path").unwrap_or_default();
if request_path == HEALTHZ_PATH {
if self.is_embedding_store_initialized() {
self.send_http_response(200, vec![], None);
} else {
self.send_http_response(503, vec![], None);
}
self.send_http_response(200, vec![], None);
return Action::Continue;
}
@ -140,43 +133,25 @@ impl HttpContext for StreamContext {
self.user_prompt = Some(last_user_prompt.clone());
let user_message_str = self.user_prompt.as_ref().unwrap().content.clone();
// convert prompt targets to ChatCompletionTool
let tool_calls: Vec<ChatCompletionTool> = self
.prompt_targets
.iter()
.map(|(_, pt)| pt.into())
.collect();
let prompt_guard_jailbreak_task = self
.prompt_guards
.input_guards
.contains_key(&common::configuration::GuardType::Jailbreak);
let arch_fc_chat_completion_request = ChatCompletionsRequest {
messages: deserialized_body.messages.clone(),
metadata: deserialized_body.metadata.clone(),
stream: deserialized_body.stream,
model: "--".to_string(),
stream_options: deserialized_body.stream_options.clone(),
tools: Some(tool_calls),
};
self.chat_completions_request = Some(deserialized_body);
if !prompt_guard_jailbreak_task {
debug!("Missing input guard. Making inline call to retrieve embeddings");
let callout_context = StreamCallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: user_message_str.clone(),
prompt_target_name: None,
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
};
self.get_embeddings(callout_context);
return Action::Pause;
}
let get_prompt_guards_request = PromptGuardRequest {
input: self
.user_prompt
.as_ref()
.unwrap()
.content
.as_ref()
.unwrap()
.clone(),
task: PromptGuardTask::Jailbreak,
};
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
let json_data = match serde_json::to_string(&arch_fc_chat_completion_request) {
Ok(json_data) => json_data,
Err(error) => {
self.send_server_error(ServerError::Serialization(error), None);
@ -184,14 +159,14 @@ impl HttpContext for StreamContext {
}
};
debug!("archgw => archfc: {}", json_data);
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, GUARD_INTERNAL_HOST),
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),
(":method", "POST"),
(":path", "/guard"),
(":authority", GUARD_INTERNAL_HOST),
(":path", "/function_calling"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
(":authority", MODEL_SERVER_NAME),
];
if self.request_id.is_some() {
@ -204,23 +179,25 @@ impl HttpContext for StreamContext {
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
"/guard",
"/function_calling",
headers,
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
);
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
response_handler_type: ResponseHandlerType::ArchFC,
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
prompt_target_name: None,
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
upstream_cluster_path: Some("/function_calling".to_string()),
};
if let Err(e) = self.http_call(call_args, call_context) {
debug!("http_call failed: {:?}", e);
self.send_server_error(ServerError::HttpDispatch(e), None);
}
@ -324,7 +301,7 @@ impl HttpContext for StreamContext {
),
];
let mut response_str = to_server_events(chunks);
let mut response_str = open_ai::to_server_events(chunks);
// append the original response from the model to the stream
response_str.push_str(&body_utf8);
self.set_http_response_body(0, body_size, response_str.as_bytes());
@ -339,9 +316,11 @@ impl HttpContext for StreamContext {
let mut data = match serde_json::from_str(&body_utf8) {
Ok(data) => data,
Err(e) => {
warn!("could not deserialize response: {}", e);
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
warn!(
"could not deserialize response, sending data as it is: {}",
e
);
return Action::Continue;
}
};
// use serde::Value to manipulate the json object and ensure that we don't lose any data

View file

@ -4,8 +4,8 @@ use proxy_wasm::types::*;
mod context;
mod filter_context;
mod hallucination;
mod http_context;
mod metrics;
mod stream_context;
proxy_wasm::main! {{

View file

@ -0,0 +1,14 @@
use common::stats::Gauge;
#[derive(Copy, Clone, Debug)]
pub struct Metrics {
pub active_http_calls: Gauge,
}
impl Metrics {
pub fn new() -> Metrics {
Metrics {
active_http_calls: Gauge::new(String::from("active_http_calls")),
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,11 +1,7 @@
use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
use common::embeddings::{
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
Embedding,
use common::api::open_ai::{
ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
};
use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use common::configuration::Configuration;
use http::StatusCode;
use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
@ -80,13 +76,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "guard"),
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/guard"),
(":authority", "guard"),
(":path", "/function_calling"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
(":authority", "model_server"),
]),
None,
None,
@ -94,139 +88,11 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
)
.returning(Some(1))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
let prompt_guard_response = PromptGuardResponse {
toxic_prob: None,
toxic_verdict: None,
jailbreak_prob: None,
jailbreak_verdict: None,
};
let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
1,
0,
prompt_guard_response_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&prompt_guard_response_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "embeddings"),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddings"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let embedding_response = CreateEmbeddingResponse {
data: vec![Embedding {
index: 0,
embedding: vec![],
object: embedding::Object::default(),
}],
model: String::from("test"),
object: create_embedding_response::Object::default(),
usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)),
};
let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
2,
0,
embeddings_response_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "zeroshot"),
(":method", "POST"),
(":path", "/zeroshot"),
(":authority", "zeroshot"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(3))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let zero_shot_response = ZeroShotClassificationResponse {
predicted_class: "weather_forecast".to_string(),
predicted_class_score: 0.1,
scores: HashMap::new(),
model: "test-model".to_string(),
};
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
3,
0,
zeroshot_intent_detection_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&zeroshot_intent_detection_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
(":method", "POST"),
("x-arch-upstream", "arch_fc"),
(":path", "/v1/chat/completions"),
(":authority", "arch_fc"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "120000"),
]),
None,
None,
None,
)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
}
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
@ -245,69 +111,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
module
.call_proxy_on_tick(filter_context)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "embeddings"),
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddings"),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(101))
.expect_metric_increment("active_http_calls", 1)
.expect_set_tick_period_millis(Some(5000))
.execute_and_expect(ReturnType::None)
.unwrap();
let embedding_response = CreateEmbeddingResponse {
data: vec![Embedding {
embedding: vec![],
index: 0,
object: embedding::Object::default(),
}],
model: String::from("test"),
object: create_embedding_response::Object::default(),
usage: Box::new(CreateEmbeddingResponseUsage {
prompt_tokens: 0,
total_tokens: 0,
}),
};
let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
module
.call_proxy_on_http_call_response(
filter_context,
101,
0,
embedding_response_str.len() as i32,
0,
)
.expect_log(
Some(LogLevel::Trace),
Some(
format!(
"filter_context: on_http_call_response called with token_id: {:?}",
101
)
.as_str(),
),
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embedding_response_str))
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
filter_context
}
@ -432,6 +235,7 @@ fn prompt_gateway_successful_request_to_open_ai_chat_completions() {
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
@ -535,8 +339,8 @@ fn prompt_gateway_request_to_llm_gateway() {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "system".to_string(),
content: None,
@ -561,7 +365,7 @@ fn prompt_gateway_request_to_llm_gateway() {
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 1, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
@ -569,47 +373,7 @@ fn prompt_gateway_request_to_llm_gateway() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "hallucination"),
(":method", "POST"),
(":path", "/hallucination"),
(":authority", "hallucination"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(5))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
// hallucination should return that parameters were not halliucinated
// prompt: str
// parameters: dict
// model: str
let hallucatination_body = HallucinationClassificationResponse {
params_scores: HashMap::from([("city".to_string(), 0.99)]),
model: "nli-model".to_string(),
};
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
module
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
@ -625,14 +389,14 @@ fn prompt_gateway_request_to_llm_gateway() {
None,
None,
)
.returning(Some(6))
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 2, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
@ -640,6 +404,10 @@ fn prompt_gateway_request_to_llm_gateway() {
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
.returning(Some("200"))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -649,8 +417,8 @@ fn prompt_gateway_request_to_llm_gateway() {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
finish_reason: Some("test".to_string()),
index: Some(0),
message: Message {
role: "assistant".to_string(),
content: Some("hello from fake llm gateway".to_string()),