adding support for model aliases in archgw

This commit is contained in:
Salman Paracha 2025-09-14 22:30:57 -07:00
parent 1e8c81d8f6
commit f13b420146
13 changed files with 1438 additions and 8 deletions

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use bytes::Bytes;
use common::configuration::ModelUsagePreference;
use common::configuration::{Configuration, ModelUsagePreference};
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::clients::SupportedAPIs;
@ -28,6 +28,7 @@ pub async fn chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
config: Arc<Configuration>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
@ -35,6 +36,7 @@ pub async fn chat(
let chat_request_bytes = request.collect().await?.to_bytes();
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
Ok(request) => request,
Err(err) => {
@ -46,6 +48,24 @@ pub async fn chat(
}
};
// === Model alias resolution ===
// Rewrite the request's model before anything else reads it, so every
// downstream consumer of `client_request` sees the resolved name.
let original_model = client_request.model().to_string();
let resolved_model = config
    .model_aliases
    .as_ref()
    .and_then(|aliases| aliases.get(&original_model))
    .map(|alias| {
        debug!(
            "[BRIGHTSTAFF] Model Alias: from '{}' -> to '{}'",
            original_model, alias.target
        );
        alias.target.clone()
    })
    // No alias table configured, or no entry for this model: keep the
    // model name exactly as the client sent it.
    .unwrap_or_else(|| original_model.clone());
client_request.set_model(resolved_model.clone());
// Clone metadata for routing and remove archgw_preference_config from original
let routing_metadata = client_request.metadata().clone();
@ -132,11 +152,12 @@ pub async fn chat(
Ok(route) => match route {
Some((_, model_name)) => model_name,
None => {
debug!(
"No route determined, using default model from request: {}",
debug!(
"[BRIGHTSTAFF] No route determined, using default model from request: {}",
chat_completions_request_for_arch_router.model
);
chat_completions_request_for_arch_router.model.clone()
}
},
Err(err) => {
@ -148,7 +169,7 @@ pub async fn chat(
};
debug!(
"[BRIGHTSTAFF -> ARCH_ROUTER] URL: {}, Model Hint: {}",
"[BRIGHTSTAFF -> ARCH_ROUTER] URL: {}, Final Model: {}",
full_qualified_llm_provider_url, model_name
);

View file

@ -101,6 +101,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let router_service: Arc<RouterService> = Arc::clone(&router_service);
let llm_provider_url = llm_provider_url.clone();
let arch_config = Arc::clone(&arch_config);
let llm_providers = llm_providers.clone();
let service = service_fn(move |req| {
@ -109,12 +110,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let arch_config = Arc::clone(&arch_config);
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
chat(req, router_service, fully_qualified_url)
chat(req, router_service, fully_qualified_url, arch_config)
.with_context(parent_cx)
.await
}

View file

@ -13,11 +13,17 @@ pub struct Routing {
pub model: Option<String>,
}
/// A single model-alias entry from the gateway configuration.
///
/// Used as the value type of `Configuration::model_aliases`, where the map
/// key is the alias name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelAlias {
/// Model name that the alias points to.
pub target: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
pub endpoints: Option<HashMap<String, Endpoint>>,
pub llm_providers: Vec<LlmProvider>,
pub model_aliases: Option<HashMap<String, ModelAlias>>,
pub overrides: Option<Overrides>,
pub system_prompt: Option<String>,
pub prompt_guards: Option<PromptGuards>,

View file

@ -104,6 +104,20 @@ pub struct ChatCompletionsRequest {
// pub web_search: Option<bool>, // GOOD FIRST ISSUE: Future support for web search
}
impl ChatCompletionsRequest {
    /// Clears `max_tokens` when the request targets an o3-family model.
    ///
    /// A model is treated as o3-family when its name is exactly `o3` or starts
    /// with `o3-`, either bare or behind a single `openrouter/` prefix. For
    /// every other model name the request is left untouched.
    pub fn suppress_max_tokens_if_o3(&mut self) {
        // Peel off at most one `openrouter/` prefix so both spellings go
        // through the same check.
        let bare = self
            .model
            .strip_prefix("openrouter/")
            .unwrap_or(self.model.as_str());

        if bare == "o3" || bare.starts_with("o3-") {
            self.max_tokens = None;
        }
    }
}
// ============================================================================
// CHAT COMPLETIONS API TYPES
// ============================================================================
@ -530,7 +544,10 @@ impl TryFrom<&[u8]> for ChatCompletionsRequest {
type Error = OpenAIStreamError;
/// Deserializes a JSON byte slice into a `ChatCompletionsRequest`.
///
/// A `serde_json` failure is converted with `OpenAIStreamError::from` and
/// returned as the error. On success, `suppress_max_tokens_if_o3` is applied
/// to the parsed request before it is returned.
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
// NOTE(review): the next line looks like the pre-change body kept by the
// diff view; the statements after it appear to be its replacement — confirm
// against the committed file.
serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)
let mut req: ChatCompletionsRequest = serde_json::from_slice(bytes).map_err(OpenAIStreamError::from)?;
// Apply the shared o3 rule here so the parse path does not carry its own
// copy of the model check.
req.suppress_max_tokens_if_o3();
Ok(req)
}
}

View file

@ -97,7 +97,7 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
let openai_tools = req.tools.map(|tools| convert_anthropic_tools(tools));
let (openai_tool_choice, parallel_tool_calls) = convert_anthropic_tool_choice(req.tool_choice);
Ok(ChatCompletionsRequest {
let mut _chat_completions_req: ChatCompletionsRequest = ChatCompletionsRequest {
model: req.model,
messages: openai_messages,
temperature: req.temperature,
@ -109,7 +109,9 @@ impl TryFrom<AnthropicMessagesRequest> for ChatCompletionsRequest {
tool_choice: openai_tool_choice,
parallel_tool_calls,
..Default::default()
})
};
_chat_completions_req.suppress_max_tokens_if_o3();
Ok(_chat_completions_req)
}
}