Use passed-in model name in chat completion request (#445)

This commit is contained in:
Adil Hafeez 2025-03-21 15:56:17 -07:00 committed by GitHub
parent bd8004d1ae
commit eb48f3d5bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 364 additions and 89 deletions

View file

@@ -89,7 +89,7 @@ impl StreamContext {
provider_hint,
));
debug!(
trace!(
"request received: llm provider hint: {}, selected llm: {}, model: {}",
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.unwrap_or_default(),
@@ -167,7 +167,7 @@ impl StreamContext {
// Check if rate limiting needs to be applied.
if let Some(selector) = self.ratelimit_selector.take() {
log::debug!("Applying ratelimit for model: {}", model);
log::trace!("Applying ratelimit for model: {}", model);
ratelimit::ratelimits(None).read().unwrap().check_limit(
model.to_owned(),
selector,
@@ -300,22 +300,45 @@ impl HttpContext for StreamContext {
.cloned();
let model_name = match self.llm_provider.as_ref() {
Some(llm_provider) => match llm_provider.model.as_ref() {
Some(model) => model,
None => "--",
},
None => "--",
Some(llm_provider) => llm_provider.model.as_ref(),
None => None,
};
deserialized_body.model = model_name.to_string();
let use_agent_orchestrator = match self.overrides.as_ref() {
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
None => false,
};
let model_requested = deserialized_body.model.clone();
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
deserialized_body.model = match model_name {
Some(model_name) => model_name.clone(),
None => {
if use_agent_orchestrator {
"agent_orchestrator".to_string()
} else {
self.send_server_error(
ServerError::BadRequest {
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
},
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
}
}
}
debug!(
"provider: {:?}, model requested: {}, model selected: {:?}",
self.llm_provider().name,
model_requested,
model_name,
);
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
trace!(
"arch => {:?}, body: {}",
deserialized_body.model,
chat_completion_request_str
);
trace!("request body: {}", chat_completion_request_str);
if deserialized_body.stream {
self.streaming_response = true;
@@ -528,22 +551,13 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
let mut model = chat_completions_chunk_response_events
let model = chat_completions_chunk_response_events
.events
.first()
.unwrap()
.model
.clone();
let tokens_str = chat_completions_chunk_response_events.to_string();
//HACK: add support for tokenizing mistral and other models
//filed issue https://github.com/katanemo/arch/issues/222
if !model.as_ref().unwrap().starts_with("gpt") {
trace!(
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
model.as_ref().unwrap()
);
}
model = Some("gpt-4".to_string());
let token_count =
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())