Merge branch 'main' into adil/agent_format

This commit is contained in:
Adil Hafeez 2025-09-16 14:54:43 -07:00
commit c1757bec88
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
26 changed files with 864 additions and 188 deletions

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use std::collections::HashMap;
use bytes::Bytes;
use common::configuration::ModelUsagePreference;
use common::configuration::{ModelAlias, ModelUsagePreference};
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::clients::SupportedAPIs;
@ -28,6 +28,7 @@ pub async fn chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
@ -35,6 +36,7 @@ pub async fn chat(
let chat_request_bytes = request.collect().await?.to_bytes();
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
Ok(request) => request,
Err(err) => {
@ -46,6 +48,24 @@ pub async fn chat(
}
};
// Model alias resolution: update model field in client_request immediately
// This ensures all downstream objects use the resolved model
let model_from_request = client_request.model().to_string();
let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
if let Some(model_alias) = model_aliases.get(&model_from_request) {
debug!(
"Model Alias: 'From {}' -> 'To{}'",
model_from_request, model_alias.target
);
model_alias.target.clone()
} else {
model_from_request.clone()
}
} else {
model_from_request.clone()
};
client_request.set_model(resolved_model.clone());
// Clone metadata for routing and remove archgw_preference_config from original
let routing_metadata = client_request.metadata().clone();
@ -77,7 +97,7 @@ pub async fn chat(
};
debug!(
"[BRIGHTSTAFF -> ARCH_ROUTER] REQ: {}",
"[ARCH_ROUTER REQ]: {}",
&serde_json::to_string(&chat_completions_request_for_arch_router).unwrap()
);
@ -132,11 +152,12 @@ pub async fn chat(
Ok(route) => match route {
Some((_, model_name)) => model_name,
None => {
debug!(
debug!(
"No route determined, using default model from request: {}",
chat_completions_request_for_arch_router.model
);
chat_completions_request_for_arch_router.model.clone()
}
},
Err(err) => {
@ -148,7 +169,7 @@ pub async fn chat(
};
debug!(
"[BRIGHTSTAFF -> ARCH_ROUTER] URL: {}, Model Hint: {}",
"[ARCH_ROUTER] URL: {}, Resolved Model: {}",
full_qualified_llm_provider_url, model_name
);

View file

@ -97,12 +97,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
routing_llm_provider,
));
let model_aliases = Arc::new(arch_config.model_aliases.clone());
loop {
let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?;
let io = TokioIo::new(stream);
let router_service: Arc<RouterService> = Arc::clone(&router_service);
let model_aliases = Arc::clone(&model_aliases);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
@ -114,6 +118,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let listeners = listeners.clone();
@ -121,7 +126,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
match (req.method(), req.uri().path()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
chat(req, router_service, fully_qualified_url)
chat(req, router_service, fully_qualified_url, model_aliases)
.with_context(parent_cx)
.await
}

View file

@ -13,6 +13,11 @@ pub struct Routing {
pub model: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelAlias {
pub target: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Agent {
pub name: String,
@ -41,6 +46,7 @@ pub struct Configuration {
pub version: String,
pub endpoints: Option<HashMap<String, Endpoint>>,
pub llm_providers: Vec<LlmProvider>,
pub model_aliases: Option<HashMap<String, ModelAlias>>,
pub overrides: Option<Overrides>,
pub system_prompt: Option<String>,
pub prompt_guards: Option<PromptGuards>,

View file

@ -104,6 +104,20 @@ pub struct ChatCompletionsRequest {
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
}
impl ChatCompletionsRequest {
/// Suppress max_tokens if the model is o3, o3-*, openrouter/o3, or openrouter/o3-*
pub fn suppress_max_tokens_if_o3(&mut self) {
let model = self.model.as_str();
let is_o3 = model == "o3"
|| model.starts_with("o3-")
|| model == "openrouter/o3"
|| model.starts_with("openrouter/o3-");
if is_o3 {
self.max_tokens = None;
}
}
}
// ============================================================================
// CHAT COMPLETIONS API TYPES
// ============================================================================
@ -148,6 +162,20 @@ pub struct ResponseMessage {
pub tool_calls: Option<Vec<ToolCall>>,
}
impl Default for ResponseMessage {
fn default() -> Self {
ResponseMessage {
role: Role::Assistant,
content: None,
refusal: None,
annotations: None,
audio: None,
function_call: None,
tool_calls: None,
}
}
}
impl ResponseMessage {
/// Convert ResponseMessage to Message for internal processing
/// This is useful for transformations that need to work with the request Message type
@ -353,6 +381,21 @@ pub struct ChatCompletionsResponse {
pub service_tier: Option<String>,
}
impl Default for ChatCompletionsResponse {
fn default() -> Self {
ChatCompletionsResponse {
id: String::new(),
object: String::new(),
created: 0,
model: String::new(),
choices: vec![],
usage: Usage::default(),
system_fingerprint: None,
service_tier: None,
}
}
}
/// Finish reason for completion
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
@ -375,6 +418,18 @@ pub struct Usage {
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
impl Default for Usage {
fn default() -> Self {
Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
prompt_tokens_details: None,
completion_tokens_details: None,
}
}
}
/// Detailed breakdown of prompt tokens
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone)]
@ -403,6 +458,16 @@ pub struct Choice {
pub logprobs: Option<Value>,
}
impl Default for Choice {
fn default() -> Self {
Choice {
index: 0,
message: ResponseMessage::default(),
finish_reason: None,
logprobs: None,
}
}
}
// ============================================================================
// STREAMING API TYPES
@ -530,7 +595,10 @@ impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIStreamError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
// Use the centralized suppression logic
req.suppress_max_tokens_if_o3();
Ok(req)
}
}

View file

@ -97,7 +97,7 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice);
Ok(ChatCompletionsRequest {
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
model: req.model,
messages: openai_messages,
temperature: req.temperature,
@ -109,7 +109,9 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
tool_choice: openai_tool_choice,
parallel_tool_calls,
..Default::default()
})
};
_chat_completions_req.suppress_max_tokens_if_o3();
Ok(_chat_completions_req)
}
}