diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 5f31abf6..37c7af9c 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -81,7 +81,10 @@ def validate_and_render_schema(): updated_llm_providers = [] llm_provider_name_set = set() + llms_with_usage = [] for llm_provider in config_yaml["llm_providers"]: + if llm_provider.get("usage", None): + llms_with_usage.append(llm_provider["name"]) if llm_provider.get("name") in llm_provider_name_set: raise Exception( f"Duplicate llm_provider name {llm_provider.get('name')}, please provide unique name for each llm_provider" @@ -137,6 +140,15 @@ def validate_and_render_schema(): llm_provider["protocol"] = protocol llms_with_endpoint.append(llm_provider) + if ( + len(llms_with_usage) > 0 + and config_yaml.get("routing", {}).get("model", None) == None + ): + llms_with_usage_names = ", ".join(llms_with_usage) + raise Exception( + f"LLMs with usage found ({llms_with_usage_names}), please provide model in routing section in your arch_config.yaml file" + ) + config_yaml["llm_providers"] = updated_llm_providers arch_config_string = yaml.dump(config_yaml) diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index b758bdde..2c041648 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -84,12 +84,7 @@ async fn main() -> Result<(), Box> { info!("listening on http://{}", bind_address); let listener = TcpListener::bind(bind_address).await?; - // if routing is null then return gpt-4o as model name - //TODO: fail if routing is null - let model = arch_config - .routing - .as_ref() - .map_or_else(|| "gpt-4o".to_string(), |routing| routing.model.clone()); + let model = arch_config.routing.as_ref().unwrap().model.clone(); let router_service: Arc = Arc::new(RouterService::new( arch_config.llm_providers.clone(), diff --git a/crates/common/src/api/hallucination.rs b/crates/common/src/api/hallucination.rs index a7caba67..e90ea165 100644 --- a/crates/common/src/api/hallucination.rs +++ b/crates/common/src/api/hallucination.rs @@ -6,7 +6,6 @@ use crate::{ }; use serde::{Deserialize, Serialize}; -use super::open_ai::ContentType; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HallucinationClassificationRequest {