mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Use the passed-in model name in the chat completion request (#445)
This commit is contained in:
parent
bd8004d1ae
commit
eb48f3d5bb
20 changed files with 364 additions and 89 deletions
|
|
@ -89,7 +89,7 @@ impl StreamContext {
|
|||
provider_hint,
|
||||
));
|
||||
|
||||
debug!(
|
||||
trace!(
|
||||
"request received: llm provider hint: {}, selected llm: {}, model: {}",
|
||||
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.unwrap_or_default(),
|
||||
|
|
@ -167,7 +167,7 @@ impl StreamContext {
|
|||
|
||||
// Check if rate limiting needs to be applied.
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
log::debug!("Applying ratelimit for model: {}", model);
|
||||
log::trace!("Applying ratelimit for model: {}", model);
|
||||
ratelimit::ratelimits(None).read().unwrap().check_limit(
|
||||
model.to_owned(),
|
||||
selector,
|
||||
|
|
@ -300,22 +300,45 @@ impl HttpContext for StreamContext {
|
|||
.cloned();
|
||||
|
||||
let model_name = match self.llm_provider.as_ref() {
|
||||
Some(llm_provider) => match llm_provider.model.as_ref() {
|
||||
Some(model) => model,
|
||||
None => "--",
|
||||
},
|
||||
None => "--",
|
||||
Some(llm_provider) => llm_provider.model.as_ref(),
|
||||
None => None,
|
||||
};
|
||||
|
||||
deserialized_body.model = model_name.to_string();
|
||||
let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||
None => false,
|
||||
};
|
||||
|
||||
let model_requested = deserialized_body.model.clone();
|
||||
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
|
||||
deserialized_body.model = match model_name {
|
||||
Some(model_name) => model_name.clone(),
|
||||
None => {
|
||||
if use_agent_orchestrator {
|
||||
"agent_orchestrator".to_string()
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
"provider: {:?}, model requested: {}, model selected: {:?}",
|
||||
self.llm_provider().name,
|
||||
model_requested,
|
||||
model_name,
|
||||
);
|
||||
|
||||
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
|
||||
|
||||
trace!(
|
||||
"arch => {:?}, body: {}",
|
||||
deserialized_body.model,
|
||||
chat_completion_request_str
|
||||
);
|
||||
trace!("request body: {}", chat_completion_request_str);
|
||||
|
||||
if deserialized_body.stream {
|
||||
self.streaming_response = true;
|
||||
|
|
@ -528,22 +551,13 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
let mut model = chat_completions_chunk_response_events
|
||||
let model = chat_completions_chunk_response_events
|
||||
.events
|
||||
.first()
|
||||
.unwrap()
|
||||
.model
|
||||
.clone();
|
||||
let tokens_str = chat_completions_chunk_response_events.to_string();
|
||||
//HACK: add support for tokenizing mistral and other models
|
||||
//filed issue https://github.com/katanemo/arch/issues/222
|
||||
if !model.as_ref().unwrap().starts_with("gpt") {
|
||||
trace!(
|
||||
"tiktoken_rs: unsupported model: {}, using gpt-4 to compute token count",
|
||||
model.as_ref().unwrap()
|
||||
);
|
||||
}
|
||||
model = Some("gpt-4".to_string());
|
||||
|
||||
let token_count =
|
||||
match tokenizer::token_count(model.as_ref().unwrap().as_str(), tokens_str.as_str())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue