mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
add e2e tests
This commit is contained in:
parent
a5cbd2a978
commit
1064301f45
19 changed files with 1147 additions and 78 deletions
|
|
@ -185,12 +185,16 @@ pub mod open_ai {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
|
@ -244,8 +248,9 @@ pub mod open_ai {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionChunkResponse {
|
||||
pub model: String,
|
||||
pub struct ChatCompletionStreamResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
|
|
@ -259,11 +264,11 @@ pub mod open_ai {
|
|||
NoChunks,
|
||||
}
|
||||
|
||||
pub struct ChatCompletionChunkResponseServerEvents {
|
||||
pub events: Vec<ChatCompletionChunkResponse>,
|
||||
pub struct ChatCompletionStreamResponseServerEvents {
|
||||
pub events: Vec<ChatCompletionStreamResponse>,
|
||||
}
|
||||
|
||||
impl Display for ChatCompletionChunkResponseServerEvents {
|
||||
impl Display for ChatCompletionStreamResponseServerEvents {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let tokens_str = self
|
||||
.events
|
||||
|
|
@ -285,19 +290,19 @@ pub mod open_ai {
|
|||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for ChatCompletionChunkResponseServerEvents {
|
||||
impl TryFrom<&str> for ChatCompletionStreamResponseServerEvents {
|
||||
type Error = ChatCompletionChunkResponseError;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
let response_chunks: VecDeque<ChatCompletionChunkResponse> = value
|
||||
let response_chunks: VecDeque<ChatCompletionStreamResponse> = value
|
||||
.lines()
|
||||
.filter(|line| line.starts_with("data: "))
|
||||
.map(|line| line.get(6..).unwrap())
|
||||
.filter(|data_chunk| *data_chunk != "[DONE]")
|
||||
.map(serde_json::from_str::<ChatCompletionChunkResponse>)
|
||||
.collect::<Result<VecDeque<ChatCompletionChunkResponse>, _>>()?;
|
||||
.map(serde_json::from_str::<ChatCompletionStreamResponse>)
|
||||
.collect::<Result<VecDeque<ChatCompletionStreamResponse>, _>>()?;
|
||||
|
||||
Ok(ChatCompletionChunkResponseServerEvents {
|
||||
Ok(ChatCompletionStreamResponseServerEvents {
|
||||
events: response_chunks.into(),
|
||||
})
|
||||
}
|
||||
|
|
@ -312,7 +317,20 @@ pub mod open_ai {
|
|||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Delta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -370,7 +388,7 @@ pub struct PromptGuardResponse {
|
|||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::common_types::open_ai::{ChatCompletionChunkResponseServerEvents, Message};
|
||||
use crate::common_types::open_ai::{ChatCompletionStreamResponseServerEvents, Message};
|
||||
use pretty_assertions::{assert_eq, assert_ne};
|
||||
use std::collections::HashMap;
|
||||
|
||||
|
|
@ -508,7 +526,7 @@ mod test {
|
|||
|
||||
#[test]
|
||||
fn stream_chunk_parse() {
|
||||
use super::open_ai::{ChatCompletionChunkResponse, ChunkChoice, Delta};
|
||||
use super::open_ai::{ChatCompletionStreamResponse, ChunkChoice, Delta};
|
||||
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.chunk","created":1729755226,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
|
@ -524,7 +542,7 @@ data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.c
|
|||
"#;
|
||||
|
||||
let sever_events =
|
||||
ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 5);
|
||||
assert_eq!(
|
||||
sever_events.events[0].choices[0]
|
||||
|
|
@ -571,7 +589,7 @@ data: {"id":"chatcmpl-ALmdmtKulBMEq3fRLbrnxJwcKOqvS","object":"chat.completion.c
|
|||
|
||||
#[test]
|
||||
fn stream_chunk_parse_done() {
|
||||
use super::open_ai::{ChatCompletionChunkResponse, ChunkChoice, Delta};
|
||||
use super::open_ai::{ChatCompletionStreamResponse, ChunkChoice, Delta};
|
||||
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.chunk","created":1729756748,"model":"gpt-3.5-turbo-0125","system_fingerprint":null,"choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
|
|
@ -588,8 +606,8 @@ data: {"id":"chatcmpl-ALn2KTfmrIpYd9N3Un4Kyg08WIIP6","object":"chat.completion.c
|
|||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let sever_events: ChatCompletionChunkResponseServerEvents =
|
||||
ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
let sever_events: ChatCompletionStreamResponseServerEvents =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 6);
|
||||
assert_eq!(
|
||||
sever_events.events[0].choices[0]
|
||||
|
|
@ -638,7 +656,7 @@ data: [DONE]
|
|||
|
||||
#[test]
|
||||
fn stream_chunk_parse_mistral() {
|
||||
use super::open_ai::{ChatCompletionChunkResponse, ChunkChoice, Delta};
|
||||
use super::open_ai::{ChatCompletionStreamResponse, ChunkChoice, Delta};
|
||||
|
||||
const CHUNK_RESPONSE: &str = r#"data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk","created":1729805261,"model":"ministral-8b-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
|
||||
|
||||
|
|
@ -665,8 +683,8 @@ data: {"id":"e1ebce16de5443b79613512c2d757936","object":"chat.completion.chunk",
|
|||
data: [DONE]
|
||||
"#;
|
||||
|
||||
let sever_events: ChatCompletionChunkResponseServerEvents =
|
||||
ChatCompletionChunkResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
let sever_events: ChatCompletionStreamResponseServerEvents =
|
||||
ChatCompletionStreamResponseServerEvents::try_from(CHUNK_RESPONSE).unwrap();
|
||||
assert_eq!(sever_events.events.len(), 11);
|
||||
|
||||
assert_eq!(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
|||
pub const MESSAGES_KEY: &str = "messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const HEALTHZ_PATH: &str = "/healthz";
|
||||
pub const ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B";
|
||||
pub const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use crate::filter_context::WasmMetrics;
|
||||
use common::common_types::open_ai::{
|
||||
ChatCompletionChunkResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
|
||||
ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
|
||||
StreamOptions,
|
||||
};
|
||||
use common::configuration::LlmProvider;
|
||||
|
|
@ -13,7 +13,7 @@ use common::llm_providers::LlmProviders;
|
|||
use common::ratelimit::Header;
|
||||
use common::{ratelimit, routing, tokenizer};
|
||||
use http::StatusCode;
|
||||
use log::{debug, warn};
|
||||
use log::{debug, trace, warn};
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use std::num::NonZero;
|
||||
|
|
@ -222,6 +222,12 @@ impl HttpContext for StreamContext {
|
|||
.clone_from(&self.llm_provider.as_ref().unwrap().model);
|
||||
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
||||
|
||||
trace!(
|
||||
"arch => {:?}, body: {}",
|
||||
deserialized_body.model,
|
||||
chat_completion_request_str
|
||||
);
|
||||
|
||||
if deserialized_body.stream {
|
||||
self.streaming_response = Some(StreamingResponse::new());
|
||||
}
|
||||
|
|
@ -243,10 +249,6 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"arch => {:?}, body: {}",
|
||||
deserialized_body.model, chat_completion_request_str
|
||||
);
|
||||
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
|
||||
|
||||
Action::Continue
|
||||
|
|
@ -318,7 +320,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
if self.streaming_response.is_some() {
|
||||
let chat_completions_chunk_response_events =
|
||||
match ChatCompletionChunkResponseServerEvents::try_from(body_utf8.as_str()) {
|
||||
match ChatCompletionStreamResponseServerEvents::try_from(body_utf8.as_str()) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
|
|
@ -343,16 +345,20 @@ impl HttpContext for StreamContext {
|
|||
let tokens_str = chat_completions_chunk_response_events.to_string();
|
||||
//HACK: add support for tokenizing mistral and other models
|
||||
//filed issue https://github.com/katanemo/arch/issues/222
|
||||
if model.starts_with("mistral") || model.starts_with("ministral") {
|
||||
model = "gpt-4".to_string();
|
||||
if model.as_ref().unwrap().starts_with("mistral")
|
||||
|| model.as_ref().unwrap().starts_with("ministral")
|
||||
{
|
||||
model = Some("gpt-4".to_string());
|
||||
}
|
||||
let token_count = match tokenizer::token_count(model.as_str(), tokens_str.as_str()) {
|
||||
Ok(token_count) => token_count,
|
||||
Err(e) => {
|
||||
debug!("could not get token count: {:?}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
let token_count =
|
||||
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())
|
||||
{
|
||||
Ok(token_count) => token_count,
|
||||
Err(e) => {
|
||||
debug!("could not get token count: {:?}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
self.response_tokens += token_count;
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
|
|
|
|||
|
|
@ -3,14 +3,14 @@ use std::{collections::HashMap, time::Duration};
|
|||
use common::{
|
||||
common_types::{
|
||||
open_ai::{
|
||||
ArchState, ChatCompletionChunkResponseServerEvents, ChatCompletionsRequest, Message,
|
||||
ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, ChunkChoice, Delta,
|
||||
},
|
||||
PromptGuardRequest, PromptGuardTask,
|
||||
},
|
||||
consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_STATE_HEADER,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, GUARD_INTERNAL_HOST,
|
||||
REQUEST_ID_HEADER, TOOL_ROLE, USER_ROLE,
|
||||
HEALTHZ_PATH, REQUEST_ID_HEADER, TOOL_ROLE, USER_ROLE,
|
||||
},
|
||||
errors::ServerError,
|
||||
http::{CallArgs, Client},
|
||||
|
|
@ -33,6 +33,15 @@ impl HttpContext for StreamContext {
|
|||
// manipulate the body in benign ways e.g., compression.
|
||||
self.set_http_request_header("content-length", None);
|
||||
|
||||
if self.get_http_request_header(":path").unwrap_or_default() == HEALTHZ_PATH {
|
||||
if self.embeddings_store.is_none() {
|
||||
self.send_http_response(503, vec![], None);
|
||||
} else {
|
||||
self.send_http_response(200, vec![], None);
|
||||
}
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.is_chat_completions_request =
|
||||
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
|
||||
|
||||
|
|
@ -279,21 +288,46 @@ impl HttpContext for StreamContext {
|
|||
if self.streaming_response {
|
||||
trace!("streaming response");
|
||||
|
||||
let chat_completions_chunk_response_events =
|
||||
match ChatCompletionChunkResponseServerEvents::try_from(body_utf8.as_str()) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"invalid streaming response: body str: {}, {:?}",
|
||||
body_utf8, e
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
if self.tool_calls.is_some() {
|
||||
let tool_call_chunk = ChatCompletionStreamResponse {
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role: Some(ASSISTANT_ROLE.to_string()),
|
||||
tool_calls: self.tool_calls.to_owned(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
debug!(
|
||||
"parsed events: {}",
|
||||
chat_completions_chunk_response_events.to_string()
|
||||
);
|
||||
|
||||
let tool_call_chunk_str = serde_json::to_string(&tool_call_chunk).unwrap();
|
||||
|
||||
let api_call_chunk = ChatCompletionStreamResponse {
|
||||
model: None,
|
||||
choices: vec![ChunkChoice {
|
||||
delta: Delta {
|
||||
role: Some(TOOL_ROLE.to_string()),
|
||||
tool_calls: None,
|
||||
content: self.tool_call_response.clone(),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_call_id: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
|
||||
let api_call_chunk_str = serde_json::to_string(&api_call_chunk).unwrap();
|
||||
let chunk_str = format!(
|
||||
"data: {}\n\ndata: {}\n\n{}",
|
||||
tool_call_chunk_str, api_call_chunk_str, body_utf8
|
||||
);
|
||||
|
||||
self.set_http_response_body(0, body_size, chunk_str.as_bytes());
|
||||
self.tool_calls = None;
|
||||
}
|
||||
} else if let Some(tool_calls) = self.tool_calls.as_ref() {
|
||||
if !tool_calls.is_empty() {
|
||||
if self.arch_state.is_none() {
|
||||
|
|
@ -311,24 +345,9 @@ impl HttpContext for StreamContext {
|
|||
*metadata = Value::Object(serde_json::Map::new());
|
||||
}
|
||||
|
||||
// since arch gateway generates tool calls (using arch-fc) and calls upstream api to
|
||||
// get response, we will send these back to developer so they can see the api response
|
||||
// and tool call arch-fc generated
|
||||
let fc_messages = vec![
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
},
|
||||
Message {
|
||||
role: TOOL_ROLE.to_string(),
|
||||
content: self.tool_call_response.clone(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
|
||||
},
|
||||
self.generate_toll_call_message(),
|
||||
self.generate_api_response_message(),
|
||||
];
|
||||
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
|
||||
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ pub struct StreamCallContext {
|
|||
pub struct StreamContext {
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
pub embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
pub metrics: Rc<WasmMetrics>,
|
||||
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||
|
|
@ -309,7 +309,11 @@ impl StreamContext {
|
|||
match serde_json::from_str(boyd_str.as_str()) {
|
||||
Ok(hallucination_response) => hallucination_response,
|
||||
Err(e) => {
|
||||
warn!("error deserializing hallucination response: {}", e);
|
||||
warn!(
|
||||
"error deserializing hallucination response: {}, body: {}",
|
||||
e,
|
||||
boyd_str.as_str()
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -1015,6 +1019,26 @@ impl StreamContext {
|
|||
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
||||
pub fn generate_toll_call_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_api_response_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: TOOL_ROLE.to_string(),
|
||||
content: self.tool_call_response.clone(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for StreamContext {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue