supply per agent model

This commit is contained in:
Adil Hafeez 2025-03-21 15:43:04 -07:00
parent 96e857a682
commit 0ba7d73284
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 38 additions and 27 deletions

View file

@ -101,7 +101,9 @@ impl RatelimitMap {
) -> Result<(), Error> {
trace!(
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
provider, selector, tokens_used
provider,
selector,
tokens_used
);
let provider_limits = match self.datastore.get(&provider) {

View file

@ -300,25 +300,31 @@ impl HttpContext for StreamContext {
.cloned();
let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => match llm_provider.model.as_ref() {
Some(model) => Some(model),
None => None,
},
Some(llm_provider) => llm_provider.model.as_ref(),
None => None,
};
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
};
let model_requested = deserialized_body.model.clone();
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
deserialized_body.model = match model_name {
Some(model_name) => model_name.clone(),
None => {
self.send_server_error(
ServerError::BadRequest {
why: "No model specified in request and couldn't determine model name from arch_config".to_string(),
},
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
if use_agent_orchestrator {
"agent_orchestrator".to_string()
} else {
self.send_server_error(
ServerError::BadRequest {
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
},
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
}
}
}

View file

@ -45,7 +45,8 @@ impl HttpContext for StreamContext {
warn!("Need single endpoint when use_agent_orchestrator is set");
self.send_server_error(
ServerError::LogicError(
"Need single endpoint when use_agent_orchestrator is set".to_string(),
"Need single endpoint when use_agent_orchestrator is set"
.to_string(),
),
None,
);

View file

@ -427,7 +427,6 @@ impl StreamContext {
headers.insert(key.as_str(), value.as_str());
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&path,
@ -499,10 +498,7 @@ impl StreamContext {
}
};
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or(true)
{
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
let direct_response_str = if self.streaming_response {
@ -655,10 +651,7 @@ impl StreamContext {
.clone();
// check if the default target should be dispatched to the LLM provider
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or(true)
{
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
let default_target_response_str = if self.streaming_response {
let chat_completion_response =
match serde_json::from_slice::<ChatCompletionsResponse>(&body) {