diff --git a/crates/common/src/ratelimit.rs b/crates/common/src/ratelimit.rs index 39a79b9d..8825fd01 100644 --- a/crates/common/src/ratelimit.rs +++ b/crates/common/src/ratelimit.rs @@ -101,7 +101,9 @@ impl RatelimitMap { ) -> Result<(), Error> { trace!( "Checking limit for provider={}, with selector={:?}, consuming tokens={:?}", - provider, selector, tokens_used + provider, + selector, + tokens_used ); let provider_limits = match self.datastore.get(&provider) { diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 9f4b9d99..20ca9d62 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -300,25 +300,31 @@ impl HttpContext for StreamContext { .cloned(); let model_name = match self.llm_provider.as_ref() { - Some(llm_provider) => match llm_provider.model.as_ref() { - Some(model) => Some(model), - None => None, - }, + Some(llm_provider) => llm_provider.model.as_ref(), None => None, }; + let use_agent_orchestrator = match self.overrides.as_ref() { + Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(), + None => false, + }; + let model_requested = deserialized_body.model.clone(); if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" { deserialized_body.model = match model_name { Some(model_name) => model_name.clone(), None => { - self.send_server_error( - ServerError::BadRequest { - why: "No model specified in request and couldn't determine model name from arch_config".to_string(), - }, - Some(StatusCode::BAD_REQUEST), - ); - return Action::Continue; + if use_agent_orchestrator { + "agent_orchestrator".to_string() + } else { + self.send_server_error( + ServerError::BadRequest { + why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(), + }, + Some(StatusCode::BAD_REQUEST), + ); + return Action::Continue; + } } } } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index df0d4748..9580e934 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -45,7 +45,8 @@ impl HttpContext for StreamContext { warn!("Need single endpoint when use_agent_orchestrator is set"); self.send_server_error( ServerError::LogicError( - "Need single endpoint when use_agent_orchestrator is set".to_string(), + "Need single endpoint when use_agent_orchestrator is set" + .to_string(), ), None, ); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index fddb3b20..4a968d43 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -427,7 +427,6 @@ impl StreamContext { headers.insert(key.as_str(), value.as_str()); } - let call_args = CallArgs::new( ARCH_INTERNAL_CLUSTER_NAME, &path, @@ -499,10 +498,7 @@ impl StreamContext { } }; - if !prompt_target - .auto_llm_dispatch_on_response - .unwrap_or(true) - { + if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) { let tool_call_response = self.tool_call_response.as_ref().unwrap().clone(); let direct_response_str = if self.streaming_response { @@ -655,10 +651,7 @@ impl StreamContext { .clone(); // check if the default target should be dispatched to the LLM provider - if !prompt_target - .auto_llm_dispatch_on_response - .unwrap_or(true) - { + if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) { let default_target_response_str = if self.streaming_response { let chat_completion_response = match serde_json::from_slice::(&body) { diff --git a/demos/use_cases/orchestrating_agents/main.py b/demos/use_cases/orchestrating_agents/main.py index 27a9598a..b453e9f2 100644 --- a/demos/use_cases/orchestrating_agents/main.py +++ b/demos/use_cases/orchestrating_agents/main.py @@ -31,9 +31,10 @@ openai_client = openai.OpenAI( ) -def call_openai(messages: List[Dict[str, str]], stream: bool): +def call_openai(messages: List[Dict[str, str]], stream: bool, model: str): + logger.info(f"llm agent model: {model}") completion = openai_client.chat.completions.create( - model="None", # archgw picks the default LLM configured in the config file + model=model, messages=messages, stream=stream, ) @@ -53,14 +54,19 @@ def call_openai(messages: List[Dict[str, str]], stream: bool): class Agent: - def __init__(self, role: str, instructions: str): + def __init__(self, role: str, instructions: str, model: str = ""): + self.model = model self.system_prompt = f"You are a {role}.\n{instructions}" def handle(self, req: ChatCompletionsRequest): messages = [{"role": "system", "content": self.get_system_prompt()}] + [ message.model_dump() for message in req.messages ] - return call_openai(messages, req.stream) + + model = req.model + if self.model: + model = self.model + return call_openai(messages, req.stream, model) def get_system_prompt(self) -> str: return self.system_prompt @@ -77,13 +83,16 @@ AGENTS = { "2. Quote ridiculous price\n" "3. Reveal caveat if user agrees." ), + model="gpt-4o-mini", ), "issues_and_repairs": Agent( role="issues and repairs agent", instructions="Propose a solution, offer refund if necessary.", + model="gpt-4o", ), "escalate_to_human": Agent( - role="human escalation agent", instructions="Escalate issues to a human." + role="human escalation agent", + instructions="Escalate issues to a human.", ), "unknown_agent": Agent( role="general assistant", instructions="Assist the user in general queries."