adding support for model aliases in archgw (#566)

* adding support for model aliases in archgw

* fixed PR based on feedback

* removing README. Not relevant for PR

---------

Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-136.local>
This commit is contained in:
Salman Paracha 2025-09-16 11:12:08 -07:00 committed by GitHub
parent 1e8c81d8f6
commit 4eb2b410c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 634 additions and 14 deletions

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use std::collections::HashMap;
use bytes::Bytes;
use common::configuration::ModelUsagePreference;
use common::configuration::{ModelAlias, ModelUsagePreference};
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::clients::SupportedAPIs;
@ -28,6 +28,7 @@ pub async fn chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
@ -35,6 +36,7 @@ pub async fn chat(
let chat_request_bytes = request.collect().await?.to_bytes();
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
Ok(request) => request,
Err(err) => {
@ -46,6 +48,24 @@ pub async fn chat(
}
};
// Model alias resolution: update model field in client_request immediately
// This ensures all downstream objects use the resolved model
let model_from_request = client_request.model().to_string();
let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
if let Some(model_alias) = model_aliases.get(&model_from_request) {
debug!(
"Model Alias: 'From {}' -> 'To{}'",
model_from_request, model_alias.target
);
model_alias.target.clone()
} else {
model_from_request.clone()
}
} else {
model_from_request.clone()
};
client_request.set_model(resolved_model.clone());
// Clone metadata for routing and remove archgw_preference_config from original
let routing_metadata = client_request.metadata().clone();
@ -77,7 +97,7 @@ pub async fn chat(
};
debug!(
"[BRIGHTSTAFF -> ARCH_ROUTER] REQ: {}",
"[ARCH_ROUTER REQ]: {}",
&serde_json::to_string(&chat_completions_request_for_arch_router).unwrap()
);
@ -132,11 +152,12 @@ pub async fn chat(
Ok(route) => match route {
Some((_, model_name)) => model_name,
None => {
debug!(
debug!(
"No route determined, using default model from request: {}",
chat_completions_request_for_arch_router.model
);
chat_completions_request_for_arch_router.model.clone()
}
},
Err(err) => {
@ -148,7 +169,7 @@ pub async fn chat(
};
debug!(
"[BRIGHTSTAFF -> ARCH_ROUTER] URL: {}, Model Hint: {}",
"[ARCH_ROUTER] URL: {}, Resolved Model: {}",
full_qualified_llm_provider_url, model_name
);

View file

@ -94,12 +94,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
routing_llm_provider,
));
let model_aliases = Arc::new(arch_config.model_aliases.clone());
loop {
let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?;
let io = TokioIo::new(stream);
let router_service: Arc<RouterService> = Arc::clone(&router_service);
let model_aliases = Arc::clone(&model_aliases);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
@ -109,12 +113,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases);
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
chat(req, router_service, fully_qualified_url)
chat(req, router_service, fully_qualified_url, model_aliases)
.with_context(parent_cx)
.await
}