diff --git a/arch/tools/cli/docker_cli.py b/arch/tools/cli/docker_cli.py
index 0cb058ad..e8a12a13 100644
--- a/arch/tools/cli/docker_cli.py
+++ b/arch/tools/cli/docker_cli.py
@@ -52,6 +52,7 @@ def docker_start_archgw_detached(
     port_mappings = [
         f"{prompt_gateway_port}:{prompt_gateway_port}",
         f"{llm_gateway_port}:{llm_gateway_port}",
+        f"{llm_gateway_port+1}:{llm_gateway_port+1}",
         "19901:9901",
     ]
     port_mappings_args = [item for port in port_mappings for item in ("-p", port)]
diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs
index 50e65915..cb3094fc 100644
--- a/crates/brightstaff/src/handlers/chat_completions.rs
+++ b/crates/brightstaff/src/handlers/chat_completions.rs
@@ -3,7 +3,6 @@ use std::sync::Arc;
 use bytes::Bytes;
 use common::api::open_ai::ChatCompletionsRequest;
 use common::consts::ARCH_PROVIDER_HINT_HEADER;
-use common::utils::shorten_string;
 use http_body_util::combinators::BoxBody;
 use http_body_util::{BodyExt, Full, StreamBody};
 use hyper::body::Frame;
@@ -39,7 +38,7 @@ pub async fn chat_completions(
             let v: Value = serde_json::from_slice(&chat_request_bytes).unwrap();
             let err_msg = format!("Failed to parse request body: {}", err);
             warn!("{}", err_msg);
-            warn!("request body: {}", v.to_string());
+            warn!("arch-router request body: {}", v.to_string());
             let mut bad_request = Response::new(full(err_msg));
             *bad_request.status_mut() = StatusCode::BAD_REQUEST;
             return Ok(bad_request);
@@ -47,8 +46,8 @@
     };

     debug!(
-        "request body: {}",
-        shorten_string(&serde_json::to_string(&chat_completion_request).unwrap())
+        "arch-router request body: {}",
+        &serde_json::to_string(&chat_completion_request).unwrap()
     );

     let trace_parent = request_headers
@@ -56,7 +55,7 @@
        .find(|(ty, _)| ty.as_str() == "traceparent")
        .map(|(_, value)| value.to_str().unwrap_or_default().to_string());

-    let selected_llm = match router_service
+    let mut selected_llm = match router_service
        .determine_route(&chat_completion_request.messages, trace_parent.clone())
        .await
    {
@@ -69,6 +68,11 @@
        }
    };

+    if selected_llm.is_none() {
+        debug!("No LLM model selected, using default from request");
+        selected_llm = Some(chat_completion_request.model.clone());
+    }
+
    info!(
        "sending request to llm provider: {} with llm model: {:?}",
        llm_provider_endpoint, selected_llm
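The new fallback keeps the router's `Option<String>` result and patches it in place. For reference, the same behavior can be written without the mutable binding; a minimal equivalent sketch, not part of the diff:

```rust
// When the router yields no route, fall back to the model named in the
// incoming request -- mirrors the if/is_none() block added above.
let selected_llm = selected_llm.or_else(|| Some(chat_completion_request.model.clone()));
```

Either form works; the PR's explicit `if` has the advantage of emitting a debug log on the fallback path.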
.header("Content-Type", "application/json") + .body(body) + .unwrap() + } + Err(_) => { + let body = Full::new(Bytes::from_static( + b"{\"error\":\"Failed to serialize models\"}", + )) + .map_err(|never| match never {}) + .boxed(); + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header("Content-Type", "application/json") + .body(body) + .unwrap() + } + } +} diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index c19e86db..8eb2d7e2 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,4 +1,5 @@ use brightstaff::handlers::chat_completions::chat_completions; +use brightstaff::handlers::models::list_models; use brightstaff::router::llm_router::RouterService; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; @@ -52,6 +53,8 @@ async fn main() -> Result<(), Box> { let arch_config = Arc::new(config); + let llm_providers = Arc::new(arch_config.llm_providers.clone()); + debug!( "arch_config: {:?}", &serde_json::to_string(arch_config.as_ref()).unwrap() @@ -84,10 +87,12 @@ async fn main() -> Result<(), Box> { let router_service = Arc::clone(&router_service); let llm_provider_endpoint = llm_provider_endpoint.clone(); + let llm_providers = llm_providers.clone(); let service = service_fn(move |req| { let router_service = Arc::clone(&router_service); let parent_cx = extract_context_from_request(&req); let llm_provider_endpoint = llm_provider_endpoint.clone(); + let llm_providers = llm_providers.clone(); async move { match (req.method(), req.uri().path()) { @@ -96,6 +101,35 @@ async fn main() -> Result<(), Box> { .with_context(parent_cx) .await } + (&Method::GET, "/v1/models") => { + Ok(list_models(llm_providers).await) + } + (&Method::OPTIONS, "/v1/models") => { + let mut response = Response::new(empty()); + *response.status_mut() = StatusCode::NO_CONTENT; + response.headers_mut().insert( + "Allow", + "GET, OPTIONS".parse().unwrap(), + ); + response.headers_mut().insert( + "Access-Control-Allow-Origin", + "*".parse().unwrap(), + ); + response.headers_mut().insert( + "Access-Control-Allow-Headers", + "Authorization, Content-Type".parse().unwrap(), + ); + response.headers_mut().insert( + "Access-Control-Allow-Methods", + "GET, POST, OPTIONS".parse().unwrap(), + ); + response.headers_mut().insert( + "Content-Type", + "application/json".parse().unwrap(), + ); + + Ok(response) + } _ => { let mut not_found = Response::new(empty()); *not_found.status_mut() = StatusCode::NOT_FOUND; diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs index b059ecad..7c492798 100644 --- a/crates/common/src/api/open_ai.rs +++ b/crates/common/src/api/open_ai.rs @@ -1,7 +1,10 @@ -use crate::consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE}; +use crate::{ + configuration::LlmProvider, + consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE}, +}; +use core::{panic, str}; use serde::{ser::SerializeMap, Deserialize, Serialize}; use serde_yaml::Value; -use core::panic; use std::{ collections::{HashMap, VecDeque}, fmt::Display, @@ -420,6 +423,45 @@ pub fn to_server_events(chunks: Vec) -> String { response_str } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelDetail { + pub id: String, + pub object: String, + pub created: usize, + pub owned_by: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ModelObject { + #[serde(rename = "list")] + List, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Models { + pub object: ModelObject, + pub data: Vec, +} + +impl From> for Models { + fn 
diff --git a/demos/use_cases/llm_routing/arch_config.yaml b/demos/use_cases/llm_routing/arch_config.yaml
index d0074642..896fb795 100644
--- a/demos/use_cases/llm_routing/arch_config.yaml
+++ b/demos/use_cases/llm_routing/arch_config.yaml
@@ -8,6 +8,7 @@ listeners:
     timeout: 30s

 llm_providers:
+
   - name: gpt-4o-mini
     access_key: $OPENAI_API_KEY
     provider_interface: openai
@@ -17,6 +18,7 @@ llm_providers:
     access_key: $OPENAI_API_KEY
     provider_interface: openai
     model: gpt-4o
+    default: true

   - name: ministral-3b
     access_key: $MISTRAL_API_KEY
@@ -27,7 +29,6 @@ llm_providers:
     access_key: $ANTHROPY_API_KEY
     provider_interface: claude
     model: claude-3-7-sonnet-latest
-    default: true

   - name: claude-sonnet-4
     access_key: $ANTHROPY_API_KEY
diff --git a/demos/use_cases/llm_routing/docker-compose.yaml b/demos/use_cases/llm_routing/docker-compose.yaml
index c2d794c6..6c5d2b68 100644
--- a/demos/use_cases/llm_routing/docker-compose.yaml
+++ b/demos/use_cases/llm_routing/docker-compose.yaml
@@ -1,17 +1,15 @@
 services:
-  chatbot_ui:
-    build:
-      context: ../../shared/chatbot_ui
-      dockerfile: Dockerfile
+
+  open-web-ui:
+    image: ghcr.io/open-webui/open-webui:main
+    restart: always
     ports:
-      - "18080:8080"
+      - "8080:8080"
     environment:
-      - CHAT_COMPLETION_ENDPOINT=http://host.docker.internal:12000/v1
-    extra_hosts:
-      - "host.docker.internal:host-gateway"
-    volumes:
-      - ./arch_config.yaml:/app/arch_config.yaml
+      - DEFAULT_MODEL=gpt-4o-mini
+      - ENABLE_OPENAI_API=true
+      - OPENAI_API_BASE_URL=http://host.docker.internal:12000/v1

   jaeger:
     build:
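With the demo config above, the gateway's LLM listener is on port 12000 (the `OPENAI_API_BASE_URL` that open-webui points at), so the new endpoint can be smoke-tested end to end. A minimal sketch, assuming a locally running gateway and the `tokio`/`reqwest` crates; not part of the diff:

```rust
// Fetch the model list the same way open-webui will on startup.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = reqwest::get("http://localhost:12000/v1/models")
        .await?
        .text()
        .await?;
    // Expected shape: {"object":"list","data":[{"id":"gpt-4o-mini",...},...]}
    println!("{body}");
    Ok(())
}
```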