diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index 53ec8e74..4212674c 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -239,6 +239,8 @@ properties: routing: type: object properties: + llm_provider: + type: string model: type: string additionalProperties: false diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 37c7af9c..4f4249fb 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -140,14 +140,20 @@ def validate_and_render_schema(): llm_provider["protocol"] = protocol llms_with_endpoint.append(llm_provider) - if ( - len(llms_with_usage) > 0 - and config_yaml.get("routing", {}).get("model", None) == None - ): - llms_with_usage_names = ", ".join(llms_with_usage) - raise Exception( - f"LLMs with usage found ({llms_with_usage_names}), please provide model in routing section in your arch_config.yaml file" - ) + if len(llms_with_usage) > 0: + routing_llm_provider = config_yaml.get("routing", {}).get("llm_provider", None) + if routing_llm_provider and routing_llm_provider not in llm_provider_name_set: + raise Exception( + f"Routing llm_provider {routing_llm_provider} is not defined in llm_providers" + ) + if routing_llm_provider is None and "arch-router" not in llm_provider_name_set: + updated_llm_providers.append( + { + "name": "arch-router", + "provider_interface": "arch", + "model": config_yaml.get("routing", {}).get("model", "Arch-Router"), + } + ) config_yaml["llm_providers"] = updated_llm_providers diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 25ea72ff..05944a5f 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -23,6 +23,8 @@ use tracing::{debug, info, warn}; pub mod router; const BIND_ADDRESS: &str = "0.0.0.0:9091"; +const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router"; +const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router"; // Utility function to extract the context from the incoming request headers fn extract_context_from_request(req: &Request) -> Context { @@ -69,16 +71,23 @@ async fn main() -> Result<(), Box> { info!("listening on http://{}", bind_address); let listener = TcpListener::bind(bind_address).await?; - let model = arch_config + let routing_model_name: String = arch_config .routing .as_ref() - .map(|r| r.model.clone()) - .unwrap_or_else(|| "none".to_string()); + .and_then(|r| r.model.clone()) + .unwrap_or_else(|| DEFAULT_ROUTING_MODEL_NAME.to_string()); + + let routing_llm_provider = arch_config + .routing + .as_ref() + .and_then(|r| r.llm_provider.clone()) + .unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string()); let router_service: Arc = Arc::new(RouterService::new( arch_config.llm_providers.clone(), llm_provider_endpoint.clone(), - model, + routing_model_name, + routing_llm_provider, )); loop { diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index 0dab0a18..d4173b01 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -17,7 +17,7 @@ pub struct RouterService { router_url: String, client: reqwest::Client, router_model: Arc, - routing_model_name: String, + routing_provider_name: String, llm_usage_defined: bool, llm_provider_map: HashMap, } @@ -41,6 +41,7 @@ impl RouterService { providers: Vec, router_url: String, routing_model_name: String, + routing_provider_name: String, ) -> Self { let providers_with_usage = providers .iter() @@ -65,7 +66,7 @@ impl RouterService { router_url, client: reqwest::Client::new(), router_model, - routing_model_name, + routing_provider_name, llm_usage_defined: !providers_with_usage.is_empty(), llm_provider_map, } @@ -104,7 +105,7 @@ impl RouterService { llm_route_request_headers.insert( header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), - header::HeaderValue::from_str(&self.routing_model_name).unwrap(), + header::HeaderValue::from_str(&self.routing_provider_name).unwrap(), ); if let Some(trace_parent) = trace_parent { diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 80ec98bb..d92f38fb 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -10,7 +10,8 @@ use crate::api::open_ai::{ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Routing { - pub model: String, + pub llm_provider: Option, + pub model: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 6eebb398..2fa29496 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -247,7 +247,10 @@ impl HttpContext for StreamContext { } if let Err(error) = self.modify_auth_headers() { // ensure that the provider has an endpoint if the access key is missing else return a bad request - if self.llm_provider.as_ref().unwrap().endpoint.is_none() && !use_agent_orchestrator + if self.llm_provider.as_ref().unwrap().endpoint.is_none() + && !use_agent_orchestrator + && self.llm_provider.as_ref().unwrap().provider_interface + != LlmProviderType::Arch { self.send_server_error(error, Some(StatusCode::BAD_REQUEST)); } diff --git a/demos/use_cases/preference_based_routing/arch_config.yaml b/demos/use_cases/preference_based_routing/arch_config.yaml index 75a43d77..f8521811 100644 --- a/demos/use_cases/preference_based_routing/arch_config.yaml +++ b/demos/use_cases/preference_based_routing/arch_config.yaml @@ -1,8 +1,5 @@ version: v0.1.0 -routing: - model: arch-router - listeners: egress_traffic: address: 0.0.0.0 @@ -12,11 +9,6 @@ listeners: llm_providers: - - name: arch-router - access_key: $ARCH_API_KEY - provider_interface: arch - model: Arch-Router - - name: gpt-4o-mini provider_interface: openai access_key: $OPENAI_API_KEY diff --git a/demos/use_cases/preference_based_routing/arch_config_local.yaml b/demos/use_cases/preference_based_routing/arch_config_local.yaml index 607d180a..029918d0 100644 --- a/demos/use_cases/preference_based_routing/arch_config_local.yaml +++ b/demos/use_cases/preference_based_routing/arch_config_local.yaml @@ -1,7 +1,8 @@ version: v0.1.0 routing: - model: arch-router + model: Arch-Router + llm_provider: arch-router listeners: egress_traffic: @@ -13,7 +14,6 @@ listeners: llm_providers: - name: arch-router - access_key: $ARCH_API_KEY provider_interface: arch model: hf.co/katanemo/Arch-Router-1.5B.gguf:Q4_K_M endpoint: host.docker.internal:11434 diff --git a/demos/use_cases/preference_based_routing/test_router_endpoint.rest b/demos/use_cases/preference_based_routing/test_router_endpoint.rest index f141822a..bb0efcc2 100644 --- a/demos/use_cases/preference_based_routing/test_router_endpoint.rest +++ b/demos/use_cases/preference_based_routing/test_router_endpoint.rest @@ -22,8 +22,8 @@ Content-Type: application/json ### get model list from arch-function GET https://archfc.katanemo.dev/v1/models HTTP/1.1 -model: arch-router +model: Arch-Router -### get model list from arch-router (notice model header) +### get model list from Arch-Router (notice model header) GET https://archfc.katanemo.dev/v1/models HTTP/1.1 -model: arch-router +model: Arch-Router