mirror of
https://github.com/katanemo/plano.git
synced 2026-05-08 07:12:42 +02:00
adding support for model aliases in archgw (#566)
* adding support for model aliases in archgw * fixed PR based on feedback * removing README. Not relevant for PR --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-136.local>
This commit is contained in:
parent
1e8c81d8f6
commit
4eb2b410c5
12 changed files with 634 additions and 14 deletions
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::ModelUsagePreference;
|
||||
use common::configuration::{ModelAlias, ModelUsagePreference};
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
|
|
@ -28,6 +28,7 @@ pub async fn chat(
|
|||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
|
|
@ -35,6 +36,7 @@ pub async fn chat(
|
|||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
|
||||
|
||||
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
|
|
@ -46,6 +48,24 @@ pub async fn chat(
|
|||
}
|
||||
};
|
||||
|
||||
// Model alias resolution: update model field in client_request immediately
|
||||
// This ensures all downstream objects use the resolved model
|
||||
let model_from_request = client_request.model().to_string();
|
||||
let resolved_model = if let Some(model_aliases) = model_aliases.as_ref() {
|
||||
if let Some(model_alias) = model_aliases.get(&model_from_request) {
|
||||
debug!(
|
||||
"Model Alias: 'From {}' -> 'To{}'",
|
||||
model_from_request, model_alias.target
|
||||
);
|
||||
model_alias.target.clone()
|
||||
} else {
|
||||
model_from_request.clone()
|
||||
}
|
||||
} else {
|
||||
model_from_request.clone()
|
||||
};
|
||||
client_request.set_model(resolved_model.clone());
|
||||
|
||||
// Clone metadata for routing and remove archgw_preference_config from original
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
|
|
@ -77,7 +97,7 @@ pub async fn chat(
|
|||
};
|
||||
|
||||
debug!(
|
||||
"[BRIGHTSTAFF -> ARCH_ROUTER] REQ: {}",
|
||||
"[ARCH_ROUTER REQ]: {}",
|
||||
&serde_json::to_string(&chat_completions_request_for_arch_router).unwrap()
|
||||
);
|
||||
|
||||
|
|
@ -132,11 +152,12 @@ pub async fn chat(
|
|||
Ok(route) => match route {
|
||||
Some((_, model_name)) => model_name,
|
||||
None => {
|
||||
debug!(
|
||||
debug!(
|
||||
"No route determined, using default model from request: {}",
|
||||
chat_completions_request_for_arch_router.model
|
||||
);
|
||||
chat_completions_request_for_arch_router.model.clone()
|
||||
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
|
|
@ -148,7 +169,7 @@ pub async fn chat(
|
|||
};
|
||||
|
||||
debug!(
|
||||
"[BRIGHTSTAFF -> ARCH_ROUTER] URL: {}, Model Hint: {}",
|
||||
"[ARCH_ROUTER] URL: {}, Resolved Model: {}",
|
||||
full_qualified_llm_provider_url, model_name
|
||||
);
|
||||
|
||||
|
|
|
|||
|
|
@ -94,12 +94,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
routing_llm_provider,
|
||||
));
|
||||
|
||||
let model_aliases = Arc::new(arch_config.model_aliases.clone());
|
||||
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await?;
|
||||
let peer_addr = stream.peer_addr()?;
|
||||
let io = TokioIo::new(stream);
|
||||
|
||||
let router_service: Arc<RouterService> = Arc::clone(&router_service);
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
|
||||
let llm_providers = llm_providers.clone();
|
||||
|
|
@ -109,12 +113,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let parent_cx = extract_context_from_request(&req);
|
||||
let llm_provider_url = llm_provider_url.clone();
|
||||
let llm_providers = llm_providers.clone();
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
|
||||
async move {
|
||||
match (req.method(), req.uri().path()) {
|
||||
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
|
||||
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
|
||||
chat(req, router_service, fully_qualified_url)
|
||||
chat(req, router_service, fully_qualified_url, model_aliases)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue