mirror of
https://github.com/katanemo/plano.git
synced 2026-06-23 15:38:07 +02:00
Merge branch 'main' into adil/agent_format
This commit is contained in:
commit
c1757bec88
26 changed files with 864 additions and 188 deletions
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::ModelUsagePreference;
|
||||
use common::configuration::{ModelAlias, ModelUsagePreference};
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
|
|
@ -28,6 +28,7 @@ pub async fn chat(
|
|||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
|
|
@ -35,6 +36,7 @@ pub async fn chat(
|
|||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
|
|
@ -46,6 +48,24 @@ pub async fn chat(
|
|||
}
|
||||
};
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = model_aliases.get(&model_from_request) {
|
||||
debug!(
|
||||
"Model Alias: 'From {}' -> 'To{}'",
|
||||
model_from_request, model_alias.target
|
||||
);
|
||||
model_alias.target.clone()
|
||||
} else {
|
||||
model_from_request.clone()
|
||||
}
|
||||
} else {
|
||||
model_from_request.clone()
|
||||
};
|
||||
client_request.set_model(resolved_model.clone());
|
||||
|
||||
// Clone metadata for routing and remove archgw_preference_config from original
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
|
|
@ -77,7 +97,7 @@ pub async fn chat(
|
|||
};
|
||||
|
||||
debug!(
|
||||
"[BRIGHTSTAFF -> ARCH_ROUTER] REQ: {}",
|
||||
"[ARCH_ROUTER REQ]: {}",
|
||||
&serde_json::to_string(&chat_completions_request_for_arch_router).unwrap()
|
||||
);
|
||||
|
||||
|
|
@ -132,11 +152,12 @@ pub async fn chat(
|
|||
Ok(route) => match route {
|
||||
Some((_, model_name)) => model_name,
|
||||
None => {
|
||||
debug!(
|
||||
debug!(
|
||||
"No route determined, using default model from request: {}",
|
||||
chat_completions_request_for_arch_router.model
|
||||
);
|
||||
chat_completions_request_for_arch_router.model.clone()
|
||||
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
|
|
@ -148,7 +169,7 @@ pub async fn chat(
|
|||
};
|
||||
|
||||
debug!(
|
||||
"[BRIGHTSTAFF -> ARCH_ROUTER] URL: {}, Model Hint: {}",
|
||||
"[ARCH_ROUTER] URL: {}, Resolved Model: {}",
|
||||
full_qualified_llm_provider_url, model_name
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -97,12 +97,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
routing_llm_provider,
|
||||
));
|
||||
|
||||
let model_aliases = Arc::new(arch_config.model_aliases.clone());
|
||||
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
|
|
@ -114,6 +118,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let agents_list = agents_list.clone();
|
||||
let listeners = listeners.clone();
|
||||
|
||||
|
|
@ -121,7 +126,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
|
||||
chat(req, router_service, fully_qualified_url)
|
||||
chat(req, router_service, fully_qualified_url, model_aliases)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,11 @@ pub struct Routing {
|
|||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelAlias {
|
||||
pub target: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Agent {
|
||||
pub name: String,
|
||||
|
|
@ -41,6 +46,7 @@ pub struct Configuration {
|
|||
pub version: String,
|
||||
pub endpoints: Option<HashMap<String, Endpoint>>,
|
||||
pub llm_providers: Vec<LlmProvider>,
|
||||
pub model_aliases: Option<HashMap<String, ModelAlias>>,
|
||||
pub overrides: Option<Overrides>,
|
||||
pub system_prompt: Option<String>,
|
||||
pub prompt_guards: Option<PromptGuards>,
|
||||
|
|
|
|||
|
|
@ -104,6 +104,20 @@ pub struct ChatCompletionsRequest {
|
|||
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
|
||||
}
|
||||
|
||||
impl ChatCompletionsRequest {
|
||||
/// Suppress max_tokens if the model is o3, o3-*, openrouter/o3, or openrouter/o3-*
|
||||
pub fn suppress_max_tokens_if_o3(&mut self) {
|
||||
let model = self.model.as_str();
|
||||
let is_o3 = model == "o3"
|
||||
|| model.starts_with("o3-")
|
||||
|| model == "openrouter/o3"
|
||||
|| model.starts_with("openrouter/o3-");
|
||||
if is_o3 {
|
||||
self.max_tokens = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// CHAT COMPLETIONS API TYPES
|
||||
// ============================================================================
|
||||
|
|
@ -148,6 +162,20 @@ pub struct ResponseMessage {
|
|||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
impl Default for ResponseMessage {
|
||||
fn default() -> Self {
|
||||
ResponseMessage {
|
||||
role: Role::Assistant,
|
||||
content: None,
|
||||
refusal: None,
|
||||
annotations: None,
|
||||
audio: None,
|
||||
function_call: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseMessage {
|
||||
/// Convert ResponseMessage to Message for internal processing
|
||||
/// This is useful for transformations that need to work with the request Message type
|
||||
|
|
@ -353,6 +381,21 @@ pub struct ChatCompletionsResponse {
|
|||
pub service_tier: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for ChatCompletionsResponse {
|
||||
fn default() -> Self {
|
||||
ChatCompletionsResponse {
|
||||
id: String::new(),
|
||||
object: String::new(),
|
||||
created: 0,
|
||||
model: String::new(),
|
||||
choices: vec![],
|
||||
usage: Usage::default(),
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Finish reason for completion
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
|
|
@ -375,6 +418,18 @@ pub struct Usage {
|
|||
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
||||
}
|
||||
|
||||
impl Default for Usage {
|
||||
fn default() -> Self {
|
||||
Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
prompt_tokens_details: None,
|
||||
completion_tokens_details: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detailed breakdown of prompt tokens
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
|
|
@ -403,6 +458,16 @@ pub struct Choice {
|
|||
pub logprobs: Option<Value>,
|
||||
}
|
||||
|
||||
impl Default for Choice {
|
||||
fn default() -> Self {
|
||||
Choice {
|
||||
index: 0,
|
||||
message: ResponseMessage::default(),
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// STREAMING API TYPES
|
||||
|
|
@ -530,7 +595,10 @@ impl TryFrom<&[u8]> for ChatCompletionsRequest {
|
|||
type Error = OpenAIStreamError;
|
||||
|
||||
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
|
||||
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
|
||||
let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
|
||||
// Use the centralized suppression logic
|
||||
req.suppress_max_tokens_if_o3();
|
||||
Ok(req)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
|
|||
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
|
||||
let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice);
|
||||
|
||||
Ok(ChatCompletionsRequest {
|
||||
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: req.model,
|
||||
messages: openai_messages,
|
||||
temperature: req.temperature,
|
||||
|
|
@ -109,7 +109,9 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
|
|||
tool_choice: openai_tool_choice,
|
||||
parallel_tool_calls,
|
||||
..Default::default()
|
||||
})
|
||||
};
|
||||
_chat_completions_req.suppress_max_tokens_if_o3();
|
||||
Ok(_chat_completions_req)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue