diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py
index 44ceda27..5f31abf6 100644
--- a/arch/tools/cli/config_generator.py
+++ b/arch/tools/cli/config_generator.py
@@ -80,7 +80,17 @@ def validate_and_render_schema():
     llms_with_endpoint = []
 
     updated_llm_providers = []
+    llm_provider_name_set = set()
     for llm_provider in config_yaml["llm_providers"]:
+        if llm_provider.get("name") in llm_provider_name_set:
+            raise Exception(
+                f"Duplicate llm_provider name {llm_provider.get('name')}, please provide a unique name for each llm_provider"
+            )
+        if llm_provider.get("name") is None:
+            raise Exception(
+                "llm_provider name is required, please provide a name for each llm_provider"
+            )
+        llm_provider_name_set.add(llm_provider.get("name"))
         provider = None
         if llm_provider.get("provider") and llm_provider.get("provider_interface"):
             raise Exception(
diff --git a/arch/tools/cli/docker_cli.py b/arch/tools/cli/docker_cli.py
index b5a31040..8cbf47a6 100644
--- a/arch/tools/cli/docker_cli.py
+++ b/arch/tools/cli/docker_cli.py
@@ -58,7 +58,6 @@ def docker_start_archgw_detached(
 
     volume_mappings = [
         f"{arch_config_file}:/app/arch_config.yaml:ro",
-        # "/Users/adilhafeez/src/intelligent-prompt-gateway/crates/target/wasm32-wasip1/release:/etc/envoy/proxy-wasm-plugins:ro",
     ]
     volume_mappings_args = [
         item for volume in volume_mappings for item in ("-v", volume)
diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs
index d1610d61..50e65915 100644
--- a/crates/brightstaff/src/handlers/chat_completions.rs
+++ b/crates/brightstaff/src/handlers/chat_completions.rs
@@ -9,10 +9,11 @@ use http_body_util::{BodyExt, Full, StreamBody};
 use hyper::body::Frame;
 use hyper::header::{self};
 use hyper::{Request, Response, StatusCode};
+use serde_json::Value;
 use tokio::sync::mpsc;
 use tokio_stream::wrappers::ReceiverStream;
 use tokio_stream::StreamExt;
-use tracing::{info, warn};
+use tracing::{debug, info, warn};
 
 use crate::router::llm_router::RouterService;
 
@@ -30,19 +31,23 @@ pub async fn chat_completions(
     let mut request_headers = request.headers().clone();
     let chat_request_bytes = request.collect().await?.to_bytes();
+
     let chat_completion_request: ChatCompletionsRequest =
         match serde_json::from_slice(&chat_request_bytes) {
             Ok(request) => request,
             Err(err) => {
+                let v: Value = serde_json::from_slice(&chat_request_bytes).unwrap();
                 let err_msg = format!("Failed to parse request body: {}", err);
+                warn!("{}", err_msg);
+                warn!("request body: {}", v.to_string());
                 let mut bad_request = Response::new(full(err_msg));
                 *bad_request.status_mut() = StatusCode::BAD_REQUEST;
                 return Ok(bad_request);
             }
         };
 
-    info!(
-        "request body received: {}",
+    debug!(
+        "request body: {}",
         shorten_string(&serde_json::to_string(&chat_completion_request).unwrap())
     );
diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs
index f683638f..b758bdde 100644
--- a/crates/brightstaff/src/main.rs
+++ b/crates/brightstaff/src/main.rs
@@ -2,38 +2,27 @@ use brightstaff::handlers::chat_completions::chat_completions;
 use brightstaff::router::llm_router::RouterService;
 use bytes::Bytes;
 use common::configuration::Configuration;
-use common::utils::shorten_string;
 use http_body_util::{combinators::BoxBody, BodyExt, Empty};
 use hyper::body::Incoming;
 use hyper::server::conn::http1;
 use hyper::service::service_fn;
 use hyper::{Method, Request, Response, StatusCode};
 use hyper_util::rt::TokioIo;
-use opentelemetry::global::BoxedTracer;
 use opentelemetry::trace::FutureExt;
-use opentelemetry::{
-    global,
-    trace::{SpanKind, Tracer},
-    Context,
-};
+use opentelemetry::{global, Context};
 use opentelemetry_http::HeaderExtractor;
 use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
 use opentelemetry_stdout::SpanExporter;
-use std::sync::{Arc, OnceLock};
+use std::sync::Arc;
 use std::{env, fs};
 use tokio::net::TcpListener;
-use tracing::info;
+use tracing::{debug, info};
 use tracing_subscriber::EnvFilter;
 
 pub mod router;
 
 const BIND_ADDRESS: &str = "0.0.0.0:9091";
 
-fn get_tracer() -> &'static BoxedTracer {
-    static TRACER: OnceLock = OnceLock::new();
-    TRACER.get_or_init(|| global::tracer("archgw/router"))
-}
-
 // Utility function to extract the context from the incoming request headers
 fn extract_context_from_request(req: &Request) -> Context {
     global::get_text_map_propagator(|propagator| {
@@ -83,24 +72,24 @@ async fn main() -> Result<(), Box> {
     let arch_config = Arc::new(config);
 
-    info!(
+    debug!(
         "arch_config: {:?}",
-        shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap())
+        &serde_json::to_string(arch_config.as_ref()).unwrap()
     );
 
     let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT")
        .unwrap_or_else(|_| "http://localhost:12001/v1/chat/completions".to_string());
 
     info!("llm provider endpoint: {}", llm_provider_endpoint);
-    info!("Listening on http://{}", bind_address);
+    info!("listening on http://{}", bind_address);
 
     let listener = TcpListener::bind(bind_address).await?;
 
-    // if routing is null then return gpt-4o as model name
-    let model = arch_config.routing.as_ref().map_or_else(
-        || "gpt-4o".to_string(),
-        |routing| routing.model.clone(),
-    );
+    //TODO: fail if routing is null
+    let model = arch_config
+        .routing
+        .as_ref()
+        .map_or_else(|| "gpt-4o".to_string(), |routing| routing.model.clone());
 
     let router_service: Arc = Arc::new(RouterService::new(
         arch_config.llm_providers.clone(),
@@ -119,12 +108,6 @@ async fn main() -> Result<(), Box> {
     let service = service_fn(move |req| {
         let router_service = Arc::clone(&router_service);
         let parent_cx = extract_context_from_request(&req);
-        info!("parent_cx: {:?}", parent_cx);
-        let tracer = get_tracer();
-        let _span = tracer
-            .span_builder("request")
-            .with_kind(SpanKind::Server)
-            .start_with_context(tracer, &parent_cx);
 
         let llm_provider_endpoint = llm_provider_endpoint.clone();
         async move {
diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs
index 47f2b41c..6fe644a8 100644
--- a/crates/brightstaff/src/router/llm_router.rs
+++ b/crates/brightstaff/src/router/llm_router.rs
@@ -1,14 +1,13 @@
 use std::sync::Arc;
 
 use common::{
-    api::open_ai::{ChatCompletionsResponse, Message},
+    api::open_ai::{ChatCompletionsResponse, ContentType, Message},
     configuration::LlmProvider,
     consts::ARCH_PROVIDER_HINT_HEADER,
-    utils::shorten_string,
 };
 use hyper::header;
 use thiserror::Error;
-use tracing::{info, warn};
+use tracing::{debug, info, warn};
 
 use super::router_model::RouterModel;
 
@@ -59,9 +58,9 @@ impl RouterService {
             .collect::>()
             .join("\n");
 
-        info!(
+        debug!(
             "llm_providers from config with usage: {}...",
-            shorten_string(&llm_providers_with_usage_yaml.replace("\n", "\\n"))
+            llm_providers_with_usage_yaml.replace("\n", "\\n")
         );
 
         let router_model = Arc::new(super::router_model_v1::RouterModelV1::new(
@@ -83,7 +82,6 @@ impl RouterService {
         messages: &[Message],
         trace_parent: Option,
     ) -> Result> {
-
         if !self.llm_usage_defined {
             return Ok(None);
         }
@@ -91,8 +89,14 @@ impl RouterService {
         let router_request = self.router_model.generate_request(messages);
 
         info!(
-            "router_request: {}",
-            shorten_string(&serde_json::to_string(&router_request).unwrap()),
+            "sending request to arch-router model: {}, endpoint: {}",
+            self.router_model.get_model_name(),
+            self.router_url
+        );
+
+        debug!(
+            "arch request body: {}",
+            &serde_json::to_string(&router_request).unwrap(),
         );
 
         let mut llm_route_request_headers = header::HeaderMap::new();
@@ -113,6 +117,7 @@ impl RouterService {
             );
         }
 
+        let start_time = std::time::Instant::now();
         let res = self
             .client
             .post(&self.router_url)
@@ -122,6 +127,7 @@ impl RouterService {
             .await?;
 
         let body = res.text().await?;
+        let router_response_time = start_time.elapsed();
 
         let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
             Ok(response) => response,
@@ -138,14 +144,18 @@ impl RouterService {
             }
         };
 
-        let selected_llm = self.router_model.parse_response(
-            chat_completion_response.choices[0]
-                .message
-                .content
-                .as_ref()
-                .unwrap(),
-        )?;
-
-        Ok(selected_llm)
+        if let Some(ContentType::Text(content)) =
+            &chat_completion_response.choices[0].message.content
+        {
+            info!(
+                "router response: {}, response time: {}ms",
+                content.replace("\n", "\\n"),
+                router_response_time.as_millis()
+            );
+            let selected_llm = self.router_model.parse_response(content)?;
+            Ok(selected_llm)
+        } else {
+            Ok(None)
+        }
     }
 }
diff --git a/crates/brightstaff/src/router/router_model.rs b/crates/brightstaff/src/router/router_model.rs
index e9f5e256..6e591e4c 100644
--- a/crates/brightstaff/src/router/router_model.rs
+++ b/crates/brightstaff/src/router/router_model.rs
@@ -12,4 +12,5 @@ pub type Result = std::result::Result;
 pub trait RouterModel: Send + Sync {
     fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest;
     fn parse_response(&self, content: &str) -> Result>;
+    fn get_model_name(&self) -> String;
 }
diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs
index 836c79e5..8b0b0ecf 100644
--- a/crates/brightstaff/src/router/router_model_v1.rs
+++ b/crates/brightstaff/src/router/router_model_v1.rs
@@ -1,9 +1,8 @@
 use common::{
-    api::open_ai::{ChatCompletionsRequest, Message},
+    api::open_ai::{ChatCompletionsRequest, ContentType, Message},
     consts::{SYSTEM_ROLE, USER_ROLE},
 };
 use serde::{Deserialize, Serialize};
-use tracing::info;
 
 use super::router_model::{RouterModel, RoutingModelError};
 
@@ -68,7 +67,7 @@ impl RouterModel for RouterModelV1 {
         ChatCompletionsRequest {
             model: self.routing_model.clone(),
             messages: vec![Message {
-                content: Some(message),
+                content: Some(ContentType::Text(message)),
                 role: USER_ROLE.to_string(),
                 model: None,
                 tool_calls: None,
@@ -86,10 +85,6 @@ impl RouterModel for RouterModelV1 {
             return Ok(None);
         }
         let router_resp_fixed = fix_json_response(content);
-        info!(
-            "router response (fixed): {}",
-            router_resp_fixed.replace("\n", "\\n")
-        );
         let router_response: LlmRouterResponse =
             serde_json::from_str(router_resp_fixed.as_str())?;
 
         let selected_llm = router_response.route.unwrap_or_default().to_string();
@@ -100,6 +95,10 @@ impl RouterModel for RouterModelV1 {
 
         Ok(Some(selected_llm))
     }
+
+    fn get_model_name(&self) -> String {
+        self.routing_model.clone()
+    }
 }
 
 fn fix_json_response(body: &str) -> String {
@@ -172,22 +171,28 @@ user: "seattle"
         let messages = vec![
             Message {
                 role: "system".to_string(),
-                content: Some("You are a helpful assistant.".to_string()),
+                content: Some(ContentType::Text(
+                    "You are a helpful assistant.".to_string(),
+                )),
                 ..Default::default()
             },
             Message {
                 role: "user".to_string(),
-                content: Some("Hello, I want to book a flight.".to_string()),
+                content: Some(ContentType::Text(
+                    "Hello, I want to book a flight.".to_string(),
+                )),
                 ..Default::default()
             },
             Message {
                 role: "assistant".to_string(),
-                content: Some("Sure, where would you like to go?".to_string()),
+                content: Some(ContentType::Text(
+                    "Sure, where would you like to go?".to_string(),
+                )),
                 ..Default::default()
             },
             Message {
                 role: "user".to_string(),
-                content: Some("seattle".to_string()),
+                content: Some(ContentType::Text("seattle".to_string())),
                 ..Default::default()
             },
         ];
@@ -198,7 +203,7 @@ user: "seattle"
 
         println!("Prompt: {}", prompt);
 
-        assert_eq!(expected_prompt, prompt);
+        assert_eq!(expected_prompt, prompt.to_string());
     }
 }
diff --git a/crates/common/src/api/hallucination.rs b/crates/common/src/api/hallucination.rs
index c0efc198..a7caba67 100644
--- a/crates/common/src/api/hallucination.rs
+++ b/crates/common/src/api/hallucination.rs
@@ -6,6 +6,8 @@ use crate::{
 };
 use serde::{Deserialize, Serialize};
 
+use super::open_ai::ContentType;
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct HallucinationClassificationRequest {
     pub prompt: String,
@@ -21,7 +23,7 @@ pub struct HallucinationClassificationResponse {
 
 pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec {
     let mut arch_assistant = false;
-    let mut user_messages = Vec::new();
+    let mut user_messages: Vec = Vec::new();
     if messages.len() >= 2 {
         let latest_assistant_message = &messages[messages.len() - 2];
         if let Some(model) = latest_assistant_message.model.as_ref() {
@@ -35,7 +37,7 @@ pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec {
             if let Some(model) = message.model.as_ref() {
                 if !model.starts_with(ARCH_MODEL_PREFIX) {
                     if let Some(content) = &message.content {
-                        if !content.starts_with(HALLUCINATION_TEMPLATE) {
+                        if !content.to_string().starts_with(HALLUCINATION_TEMPLATE) {
                             break;
                         }
                     }
@@ -43,13 +45,13 @@ pub fn extract_messages_for_hallucination(messages: &[Message]) -> Vec {
                 }
                 if message.role == USER_ROLE {
                     if let Some(content) = &message.content {
-                        user_messages.push(content.clone());
+                        user_messages.push(content.to_string());
                     }
                 }
             }
         } else if let Some(message) = messages.last() {
             if let Some(content) = &message.content {
-                user_messages.push(content.clone());
+                user_messages.push(content.to_string());
             }
         }
         user_messages.reverse(); // Reverse to maintain the original order
diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs
index 6db30190..b059ecad 100644
--- a/crates/common/src/api/open_ai.rs
+++ b/crates/common/src/api/open_ai.rs
@@ -1,6 +1,7 @@
 use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE};
 use serde::{ser::SerializeMap, Deserialize, Serialize};
 use serde_yaml::Value;
+use core::panic;
 use std::{
     collections::{HashMap, VecDeque},
     fmt::Display,
@@ -154,12 +155,54 @@ pub struct StreamOptions {
     pub include_usage: bool,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub enum MultiPartContentType {
+    #[serde(rename = "text")]
+    Text,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct MultiPartContent {
+    pub text: Option<String>,
+    #[serde(rename = "type")]
+    pub content_type: MultiPartContentType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[serde(untagged)]
+pub enum ContentType {
+    Text(String),
+    MultiPart(Vec<MultiPartContent>),
+}
+
+impl Display for ContentType {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            ContentType::Text(text) => write!(f, "{}", text),
+            ContentType::MultiPart(multi_part) => {
+                let text_parts: Vec<String> = multi_part
+                    .iter()
+                    .filter_map(|part| {
+                        if part.content_type == MultiPartContentType::Text {
+                            part.text.clone()
+                        } else {
+                            panic!("Unsupported content type: {:?}", part.content_type);
+                        }
+                    })
+                    .collect();
+                let combined_text = text_parts.join("\n");
+                write!(f, "{}", combined_text)
+            }
+        }
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Message {
     pub role: String,
 
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub content: Option<String>,
+    pub content: Option<ContentType>,
 
     #[serde(skip_serializing_if = "Option::is_none")]
     pub model: Option,
@@ -237,7 +280,7 @@ impl ChatCompletionsResponse {
             choices: vec![Choice {
                 message: Message {
                     role: ASSISTANT_ROLE.to_string(),
-                    content: Some(message),
+                    content: Some(ContentType::Text(message)),
                     model: Some(ARCH_FC_MODEL_NAME.to_string()),
                     tool_calls: None,
                     tool_call_id: None,
@@ -379,6 +422,8 @@ pub fn to_server_events(chunks: Vec) -> String {
 
 #[cfg(test)]
 mod test {
+    use crate::api::open_ai::{ChatCompletionsRequest, ContentType, MultiPartContentType};
+
     use super::{ChatCompletionStreamResponseServerEvents, Message};
     use pretty_assertions::assert_eq;
     use std::collections::HashMap;
@@ -448,7 +493,9 @@ mod test {
             model: "gpt-3.5-turbo".to_string(),
             messages: vec![Message {
                 role: "user".to_string(),
-                content: Some("What city do you want to know the weather for?".to_string()),
+                content: Some(ContentType::Text(
+                    "What city do you want to know the weather for?".to_string(),
+                )),
                 model: None,
                 tool_calls: None,
                 tool_call_id: None,
@@ -679,6 +726,111 @@ data: [DONE]
         );
     }
 
+    #[test]
+    fn test_chat_completions_request() {
+        const CHAT_COMPLETIONS_REQUEST: &str = r#"
+{
+    "model": "gpt-3.5-turbo",
+    "messages": [
+        {
+            "role": "user",
+            "content": "What city do you want to know the weather for?"
+        }
+    ]
+}"#;
+
+        let chat_completions_request: ChatCompletionsRequest =
+            serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
+        assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
+        assert_eq!(
+            chat_completions_request.messages[0].content,
+            Some(ContentType::Text(
+                "What city do you want to know the weather for?".to_string()
+            ))
+        );
+    }
+
+    #[test]
+    fn test_chat_completions_request_text_type() {
+        const CHAT_COMPLETIONS_REQUEST: &str = r#"
+{
+    "model": "gpt-3.5-turbo",
+    "messages": [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What city do you want to know the weather for?"
+                }
+            ]
+        }
+    ]
+}
+"#;
+
+        let chat_completions_request: ChatCompletionsRequest =
+            serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
+        assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
+        if let Some(ContentType::MultiPart(multi_part_content)) =
+            chat_completions_request.messages[0].content.as_ref()
+        {
+            assert_eq!(multi_part_content[0].content_type, MultiPartContentType::Text);
+            assert_eq!(
+                multi_part_content[0].text,
+                Some("What city do you want to know the weather for?".to_string())
+            );
+        } else {
+            panic!("Expected MultiPartContent");
+        }
+    }
+
+    #[test]
+    fn test_chat_completions_request_text_type_array() {
+        const CHAT_COMPLETIONS_REQUEST: &str = r#"
+{
+    "model": "gpt-3.5-turbo",
+    "messages": [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What city do you want to know the weather for?"
+                },
+                {
+                    "type": "text",
+                    "text": "hello world"
+                }
+            ]
+        }
+    ]
+}
+"#;
+
+        let chat_completions_request: ChatCompletionsRequest =
+            serde_json::from_str(CHAT_COMPLETIONS_REQUEST).unwrap();
+        assert_eq!(chat_completions_request.model, "gpt-3.5-turbo");
+        if let Some(ContentType::MultiPart(multi_part_content)) =
+            chat_completions_request.messages[0].content.as_ref()
+        {
+            assert_eq!(multi_part_content.len(), 2);
+            assert_eq!(multi_part_content[0].content_type, MultiPartContentType::Text);
+            assert_eq!(
+                multi_part_content[0].text,
+                Some("What city do you want to know the weather for?".to_string())
+            );
+            assert_eq!(multi_part_content[1].content_type, MultiPartContentType::Text);
+            assert_eq!(
+                multi_part_content[1].text,
+                Some("hello world".to_string())
+            );
+        } else {
+            panic!("Expected MultiPartContent");
+        }
+    }
+
+
     #[test]
     fn stream_chunk_parse_claude() {
         const CHUNK_RESPONSE: &str = r#"data: {"id":"msg_01DZDMxYSgq8aPQxMQoBv6Kb","choices":[{"index":0,"delta":{"role":"assistant"}}],"created":1747685264,"model":"claude-3-7-sonnet-latest","object":"chat.completion.chunk"}
diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs
index dd86109a..7ca3a99b 100644
--- a/crates/llm_gateway/src/stream_context.rs
+++ b/crates/llm_gateway/src/stream_context.rs
@@ -1,7 +1,7 @@
 use crate::metrics::Metrics;
 use common::api::open_ai::{
     ChatCompletionStreamResponseServerEvents, ChatCompletionsRequest, ChatCompletionsResponse,
-    Message, StreamOptions,
+    ContentType, Message, StreamOptions,
 };
 use common::configuration::{LlmProvider, LlmProviderType, Overrides};
 use common::consts::{
@@ -369,7 +369,12 @@ impl HttpContext for StreamContext {
             .messages
             .iter()
             .fold(String::new(), |acc, m| {
-                acc + " " + m.content.as_ref().unwrap_or(&String::new())
+                acc + " "
+                    + m.content
+                        .as_ref()
+                        .unwrap_or(&ContentType::Text(String::new()))
+                        .to_string()
+                        .as_str()
             });
 
         // enforce ratelimits on ingress
         if let Err(e) = self.enforce_ratelimits(&deserialized_body.model, input_tokens_str.as_str())
diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs
index c0e8df94..bb673208 100644
--- a/crates/prompt_gateway/src/http_context.rs
+++ b/crates/prompt_gateway/src/http_context.rs
@@ -237,22 +237,32 @@ impl HttpContext for StreamContext {
                     Duration::from_secs(5),
                 );
 
-                let call_context = StreamCallContext {
-                    response_handler_type: ResponseHandlerType::ArchFC,
-                    user_message: self.user_prompt.as_ref().unwrap().content.clone(),
-                    prompt_target_name: None,
-                    request_body: self.chat_completions_request.as_ref().unwrap().clone(),
-                    similarity_scores: None,
-                    upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
-                    upstream_cluster_path: Some("/function_calling".to_string()),
-                };
+                if let Some(content) =
+                    self.user_prompt.as_ref().unwrap().content.as_ref()
+                {
+                    let call_context = StreamCallContext {
+                        response_handler_type: ResponseHandlerType::ArchFC,
+                        user_message: Some(content.to_string()),
+                        prompt_target_name: None,
+                        request_body: self.chat_completions_request.as_ref().unwrap().clone(),
+                        similarity_scores: None,
+                        upstream_cluster: Some(ARCH_INTERNAL_CLUSTER_NAME.to_string()),
+                        upstream_cluster_path: Some("/function_calling".to_string()),
+                    };
 
-                if let Err(e) = self.http_call(call_args, call_context) {
-                    warn!("http_call failed: {:?}", e);
+                    if let Err(e) = self.http_call(call_args, call_context) {
+                        warn!("http_call failed: {:?}", e);
+                        self.send_server_error(ServerError::HttpDispatch(e), None);
+                    }
+                } else {
+                    warn!("No content in the last user prompt");
+                    self.send_server_error(
+                        ServerError::LogicError("No content in the last user prompt".to_string()),
+                        None,
+                    );
                 }
 
-        Action::Pause
+    }
 
     fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs
index 0ceb4aa5..96caa378 100644
--- a/crates/prompt_gateway/src/stream_context.rs
+++ b/crates/prompt_gateway/src/stream_context.rs
@@ -2,7 +2,7 @@ use crate::metrics::Metrics;
 use crate::tools::compute_request_path_body;
 use common::api::open_ai::{
     to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
-    ChatCompletionsResponse, Message, ToolCall,
+    ChatCompletionsResponse, ContentType, Message, ToolCall,
 };
 use common::configuration::{Endpoint, Overrides, PromptTarget, Tracing};
 use common::consts::{
@@ -215,7 +215,7 @@ impl StreamContext {
             Some(system_prompt) => {
                 let system_prompt_message = Message {
                     role: SYSTEM_ROLE.to_string(),
-                    content: Some(system_prompt.clone()),
+                    content: Some(ContentType::Text(system_prompt.clone())),
                     model: None,
                     tool_calls: None,
                     tool_call_id: None,
@@ -279,6 +279,13 @@ impl StreamContext {
         //TODO: add resolver name to the response so the client can send the response back to the correct resolver
 
         let direct_response_str = if self.streaming_response {
+            let content = model_server_response.choices[0]
+                .message
+                .content
+                .as_ref()
+                .unwrap()
+                .clone();
+
             let chunks = vec![
                 ChatCompletionStreamResponse::new(
                     self.arch_fc_response.clone(),
@@ -287,14 +294,7 @@ impl StreamContext {
                     None,
                 ),
                 ChatCompletionStreamResponse::new(
-                    Some(
-                        model_server_response.choices[0]
-                            .message
-                            .content
-                            .as_ref()
-                            .unwrap()
-                            .clone(),
-                    ),
+                    Some(content.to_string()),
                     None,
                     Some(format!("{}-Chat", ARCH_FC_MODEL_NAME.to_owned())),
                     None,
@@ -542,7 +542,7 @@ impl StreamContext {
         messages.push({
             Message {
                 role: USER_ROLE.to_string(),
-                content: Some(final_prompt),
+                content: Some(ContentType::Text(final_prompt)),
                 model: None,
                 tool_calls: None,
                 tool_call_id: None,
@@ -612,7 +612,7 @@ impl StreamContext {
         if system_prompt.is_some() {
             let system_prompt_message = Message {
                 role: SYSTEM_ROLE.to_string(),
-                content: system_prompt,
+                content: Some(ContentType::Text(system_prompt.unwrap())),
                 model: None,
                 tool_calls: None,
                 tool_call_id: None,
@@ -639,7 +639,9 @@ impl StreamContext {
         } else {
             Message {
                 role: ASSISTANT_ROLE.to_string(),
-                content: self.arch_fc_response.as_ref().cloned(),
+                content: Some(ContentType::Text(
+                    self.arch_fc_response.as_ref().unwrap().clone(),
+                )),
                 model: Some(ARCH_FC_MODEL_NAME.to_string()),
                 tool_calls: None,
                 tool_call_id: None,
@@ -650,7 +652,9 @@ impl StreamContext {
     pub fn generate_api_response_message(&mut self) -> Message {
         Message {
             role: TOOL_ROLE.to_string(),
-            content: self.tool_call_response.clone(),
+            content: Some(ContentType::Text(
+                self.tool_call_response.as_ref().unwrap().clone(),
+            )),
             model: None,
             tool_calls: None,
             tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
@@ -688,7 +692,14 @@ impl StreamContext {
                 None,
             ),
             ChatCompletionStreamResponse::new(
-                chat_completion_response.choices[0].message.content.clone(),
+                Some(
+                    chat_completion_response.choices[0]
+                        .message
+                        .content
+                        .as_ref()
+                        .unwrap()
+                        .to_string(),
+                ),
                 None,
                 Some(chat_completion_response.model.clone()),
                 None,
@@ -727,7 +738,7 @@ impl StreamContext {
             Some(system_prompt) => {
                 let system_prompt_message = Message {
                     role: SYSTEM_ROLE.to_string(),
-                    content: Some(system_prompt.clone()),
+                    content: Some(ContentType::Text(system_prompt.clone())),
                     model: None,
                     tool_calls: None,
                     tool_call_id: None,
@@ -748,7 +759,7 @@ impl StreamContext {
         let message = format!("{}\ncontext: {}", user_message.content.unwrap(), api_resp);
         messages.push(Message {
             role: USER_ROLE.to_string(),
-            content: Some(message),
+            content: Some(ContentType::Text(message)),
             model: None,
             tool_calls: None,
             tool_call_id: None,
@@ -781,7 +792,7 @@ fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool
         .first()
         .and_then(|choice| choice.message.content.as_ref());
 
-    let content_has_value = content.is_some() && !content.unwrap().is_empty();
+    let content_has_value = content.is_some() && !content.unwrap().to_string().is_empty();
 
     let tool_calls = model_server_response
         .choices
@@ -807,7 +818,7 @@ impl Client for StreamContext {
 
 #[cfg(test)]
 mod test {
-    use common::api::open_ai::{ChatCompletionsResponse, Choice, Message, ToolCall};
+    use common::api::open_ai::{ChatCompletionsResponse, Choice, ContentType, Message, ToolCall};
 
     use crate::stream_context::check_intent_matched;
 
@@ -816,7 +827,7 @@ mod test {
         let model_server_response = ChatCompletionsResponse {
             choices: vec![Choice {
                 message: Message {
-                    content: Some("".to_string()),
+                    content: Some(ContentType::Text("".to_string())),
                     tool_calls: Some(vec![]),
                     role: "assistant".to_string(),
                     model: None,
@@ -835,7 +846,7 @@ mod test {
         let model_server_response = ChatCompletionsResponse {
             choices: vec![Choice {
                 message: Message {
-                    content: Some("hello".to_string()),
+                    content: Some(ContentType::Text("hello".to_string())),
                     tool_calls: Some(vec![]),
                     role: "assistant".to_string(),
                     model: None,
@@ -854,7 +865,7 @@ mod test {
         let model_server_response = ChatCompletionsResponse {
             choices: vec![Choice {
                 message: Message {
-                    content: Some("".to_string()),
+                    content: Some(ContentType::Text("".to_string())),
                     tool_calls: Some(vec![ToolCall {
                         id: "1".to_string(),
                         function: common::api::open_ai::FunctionCallDetail {
diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs
index 91b36c01..563c9393 100644
--- a/crates/prompt_gateway/tests/integration.rs
+++ b/crates/prompt_gateway/tests/integration.rs
@@ -1,5 +1,5 @@
 use common::api::open_ai::{
-    ChatCompletionsResponse, Choice, FunctionCallDetail, Message, ToolCall, ToolType, Usage,
+    ChatCompletionsResponse, Choice, ContentType, FunctionCallDetail, Message, ToolCall, ToolType, Usage
 };
 use common::configuration::Configuration;
 use http::StatusCode;
@@ -431,7 +431,7 @@ fn prompt_gateway_request_to_llm_gateway() {
             index: Some(0),
             message: Message {
                 role: "assistant".to_string(),
-                content: Some("hello from fake llm gateway".to_string()),
+                content: Some(ContentType::Text("hello from fake llm gateway".to_string())),
                 model: None,
                 tool_calls: None,
                 tool_call_id: None,
diff --git a/demos/use_cases/preference_based_routing/arch_config.yaml b/demos/use_cases/preference_based_routing/arch_config.yaml
index eea4ef70..682527ca 100644
--- a/demos/use_cases/preference_based_routing/arch_config.yaml
+++ b/demos/use_cases/preference_based_routing/arch_config.yaml
@@ -14,26 +14,26 @@ llm_providers:
 
   - name: archgw-v1-router-model
     provider_interface: openai
-    model: cotran2/llama-1b-4-26
-    base_url: http://35.192.87.187:8000/v1
-
-  - name: gpt-4o-mini
-    provider_interface: openai
-    access_key: $OPENAI_API_KEY
-    model: gpt-4o-mini
-    default: true
+    model: cotran2/llama-4-epoch
+    base_url: http://34.46.85.85:8000/v1
 
   - name: gpt-4o
     provider_interface: openai
     access_key: $OPENAI_API_KEY
     model: gpt-4o
-    usage: Generating original content such as scripts, articles, or creative materials.
+    default: true
 
-  - name: o4-mini
+  - name: code_generation
     provider_interface: openai
     access_key: $OPENAI_API_KEY
-    model: o4-mini
-    usage: Requesting topic ideas specifically related to personal finance and budgeting.
+    model: gpt-4o
+    usage: Generating new code snippets, functions, or boilerplate based on user prompts or requirements
+
+  - name: code_understanding
+    provider_interface: openai
+    access_key: $OPENAI_API_KEY
+    model: gpt-4.1
+    usage: Understand and explain existing code snippets, functions, or libraries
 
 tracing:
   random_sampling: 100
diff --git a/demos/use_cases/preference_based_routing/hurl_tests/simple.hurl b/demos/use_cases/preference_based_routing/hurl_tests/simple.hurl
index 517767b6..51f4ac69 100644
--- a/demos/use_cases/preference_based_routing/hurl_tests/simple.hurl
+++ b/demos/use_cases/preference_based_routing/hurl_tests/simple.hurl
@@ -12,7 +12,7 @@ Content-Type: application/json
 HTTP 200
 [Asserts]
 header "content-type" == "application/json"
-jsonpath "$.model" matches /^o4-mini/
+jsonpath "$.model" matches /^gpt-4o/
 jsonpath "$.usage" != null
 jsonpath "$.choices[0].message.content" != null
 jsonpath "$.choices[0].message.role" == "assistant"
diff --git a/demos/use_cases/preference_based_routing/hurl_tests/simple_stream.hurl b/demos/use_cases/preference_based_routing/hurl_tests/simple_stream.hurl
index d6b770f6..00cc9385 100644
--- a/demos/use_cases/preference_based_routing/hurl_tests/simple_stream.hurl
+++ b/demos/use_cases/preference_based_routing/hurl_tests/simple_stream.hurl
@@ -13,4 +13,4 @@ Content-Type: application/json
 HTTP 200
 [Asserts]
 header "content-type" matches /text\/event-stream/
-body matches /^data: .*?o4-mini.*?\n/
+body matches /^data: .*?gpt-4o.*?\n/
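
Note for reviewers: the sketch below is a minimal, self-contained illustration of the content model this diff introduces in crates/common/src/api/open_ai.rs. The standalone module layout, the `main` function, and the trimmed `Display` impl are assumptions made only so the snippet compiles on its own with serde/serde_json; the authoritative definitions are the ones in the diff above. It shows the behaviour the new unit tests assert: `#[serde(untagged)]` lets `Message.content` deserialize from either a plain JSON string or an OpenAI-style array of text parts, and `to_string()` flattens multipart content so existing string-based code paths (ratelimit token counting, prompt assembly) keep working.

```rust
use serde::{Deserialize, Serialize};
use std::fmt::Display;

// Trimmed-down copies of the types added in the diff, reproduced here only
// so this sketch is self-contained; the real ones live in common::api::open_ai.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum MultiPartContentType {
    #[serde(rename = "text")]
    Text,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MultiPartContent {
    pub text: Option<String>,
    #[serde(rename = "type")]
    pub content_type: MultiPartContentType,
}

// `untagged` makes serde try each variant in order: a JSON string becomes
// Text, an array of {"type": "text", "text": ...} objects becomes MultiPart.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ContentType {
    Text(String),
    MultiPart(Vec<MultiPartContent>),
}

impl Display for ContentType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ContentType::Text(text) => write!(f, "{}", text),
            // Text parts are joined with newlines, mirroring the Display impl
            // in the diff (simplified: non-text parts are just skipped here).
            ContentType::MultiPart(parts) => {
                let joined: Vec<String> = parts.iter().filter_map(|p| p.text.clone()).collect();
                write!(f, "{}", joined.join("\n"))
            }
        }
    }
}

fn main() {
    // Plain string content, as most chat clients send it.
    let plain: ContentType = serde_json::from_str(r#""hello world""#).unwrap();
    assert_eq!(plain.to_string(), "hello world");

    // OpenAI-style multipart content: an array of text parts.
    let multipart: ContentType = serde_json::from_str(
        r#"[{"type": "text", "text": "part one"}, {"type": "text", "text": "part two"}]"#,
    )
    .unwrap();
    // to_string() flattens the parts, which is what the gateways rely on.
    assert_eq!(multipart.to_string(), "part one\npart two");

    println!("both request shapes parsed into ContentType");
}
```

The untagged-enum approach keeps `Message` backward compatible: callers that already send `"content": "..."` are untouched, while multipart requests stop failing at the parse step in chat_completions.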