mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Make model field optional in request types, resolve from default provider (#768)
This commit is contained in:
parent
7b5f1549a5
commit
baeee56f6b
6 changed files with 61 additions and 11 deletions
|
|
@ -2,15 +2,17 @@ use std::sync::Arc;
|
|||
use std::time::Instant;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::llm_providers::LlmProviders;
|
||||
use hermesllm::apis::OpenAIMessage;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
use hermesllm::providers::request::ProviderRequest;
|
||||
use hermesllm::ProviderRequestType;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::{Request, Response};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use opentelemetry::trace::get_active_span;
|
||||
use serde::ser::Error as SerError;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use super::agent_selector::{AgentSelectionError, AgentSelector};
|
||||
|
|
@ -40,6 +42,7 @@ pub async fn agent_chat(
|
|||
_: String,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// Extract request_id from headers or generate a new one
|
||||
let request_id: String = match request
|
||||
|
|
@ -71,6 +74,7 @@ pub async fn agent_chat(
|
|||
orchestrator_service,
|
||||
agents_list,
|
||||
listeners,
|
||||
llm_providers,
|
||||
request_id,
|
||||
)
|
||||
.await
|
||||
|
|
@ -155,6 +159,7 @@ async fn handle_agent_chat_inner(
|
|||
orchestrator_service: Arc<OrchestratorService>,
|
||||
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
|
||||
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
|
||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
request_id: String,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||
// Initialize services
|
||||
|
|
@ -221,16 +226,36 @@ async fn handle_agent_chat_inner(
|
|||
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
|
||||
})?;
|
||||
|
||||
let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
return Err(AgentFilterChainError::RequestParsing(
|
||||
serde_json::Error::custom(err_msg),
|
||||
));
|
||||
let mut client_request =
|
||||
match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("failed to parse request as ProviderRequestType: {}", err);
|
||||
let err_msg = format!("Failed to parse request: {}", err);
|
||||
return Err(AgentFilterChainError::RequestParsing(
|
||||
serde_json::Error::custom(err_msg),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// If the request omits the model (the field is then an empty string), fall back to the default provider's name as the model
|
||||
if client_request.model().is_empty() {
|
||||
match llm_providers.read().await.default() {
|
||||
Some(default_provider) => {
|
||||
let default_model = default_provider.name.clone();
|
||||
info!(default_model = %default_model, "no model specified in request, using default provider");
|
||||
client_request.set_model(default_model);
|
||||
}
|
||||
None => {
|
||||
let err_msg = "No model specified in request and no default provider configured";
|
||||
warn!("{}", err_msg);
|
||||
let mut bad_request =
|
||||
Response::new(ResponseHandler::create_full_body(err_msg.to_string()));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let message: Vec<OpenAIMessage> = client_request.get_messages();
|
||||
|
||||
|
|
|
|||
|
|
@ -150,9 +150,30 @@ async fn llm_chat_inner(
|
|||
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
|
||||
);
|
||||
|
||||
// If the request omits the model (the field is then an empty string), fall back to the default provider's name as the model
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let model_from_request = if model_from_request.is_empty() {
|
||||
match llm_providers.read().await.default() {
|
||||
Some(default_provider) => {
|
||||
let default_model = default_provider.name.clone();
|
||||
info!(default_model = %default_model, "no model specified in request, using default provider");
|
||||
client_request.set_model(default_model.clone());
|
||||
default_model
|
||||
}
|
||||
None => {
|
||||
let err_msg = "No model specified in request and no default provider configured";
|
||||
warn!("{}", err_msg);
|
||||
let mut bad_request = Response::new(full(err_msg.to_string()));
|
||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||
return Ok(bad_request);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
model_from_request
|
||||
};
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let temperature = client_request.get_temperature();
|
||||
let is_streaming_request = client_request.is_streaming();
|
||||
let alias_resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
|
||||
|
|
|
|||
|
|
@ -202,6 +202,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
fully_qualified_url,
|
||||
agents_list,
|
||||
listeners,
|
||||
llm_providers,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await;
|
||||
|
|
|
|||
|
|
@ -102,6 +102,7 @@ pub struct McpServer {
|
|||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct MessagesRequest {
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
pub messages: Vec<MessagesMessage>,
|
||||
pub max_tokens: u32,
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ impl ApiDefinition for OpenAIApi {
|
|||
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
// pub audio: Option<Audio> // GOOD FIRST ISSUE: future support for audio input
|
||||
pub frequency_penalty: Option<f32>,
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ impl TryFrom<&[u8]> for ResponsesAPIResponse {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResponsesAPIRequest {
|
||||
/// The model to use for generating the response
|
||||
#[serde(default)]
|
||||
pub model: String,
|
||||
|
||||
/// Text, image, or file inputs to the model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue