Add support for updating model preferences (#510)

This commit is contained in:
Adil Hafeez 2025-07-02 14:08:19 -07:00 committed by GitHub
parent 1963020c21
commit 00dc95e034
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 437 additions and 53 deletions

View file

@ -1,6 +1,7 @@
use std::sync::Arc;
use bytes::Bytes;
use common::configuration::ModelUsagePreference;
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use hermesllm::providers::openai::types::ChatCompletionsRequest;
use http_body_util::combinators::BoxBody;
@ -11,7 +12,7 @@ use hyper::{Request, Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{debug, info, warn};
use tracing::{debug, info, trace, warn};
use crate::router::llm_router::RouterService;
@ -30,23 +31,57 @@ pub async fn chat_completions(
let chat_request_bytes = request.collect().await?.to_bytes();
let chat_completion_request: ChatCompletionsRequest =
match ChatCompletionsRequest::try_from(chat_request_bytes.as_ref()) {
Ok(request) => request,
Err(err) => {
warn!(
"arch-router request body string: {}",
String::from_utf8_lossy(&chat_request_bytes)
);
let err_msg = format!("Failed to parse request body: {}", err);
warn!("{}", err_msg);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
let chat_request_parsed = serde_json::from_slice::<serde_json::Value>(&chat_request_bytes)
.inspect_err(|err| {
warn!(
"Failed to parse request body as JSON: err: {}, str: {}",
err,
String::from_utf8_lossy(&chat_request_bytes)
)
})
.unwrap_or_else(|_| {
warn!(
"Failed to parse request body as JSON: {}",
String::from_utf8_lossy(&chat_request_bytes)
);
serde_json::Value::Null
});
debug!(
if chat_request_parsed == serde_json::Value::Null {
warn!("Request body is not valid JSON");
let err_msg = "Request body is not valid JSON".to_string();
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
let chat_completion_request: ChatCompletionsRequest =
serde_json::from_value(chat_request_parsed.clone()).unwrap();
// remove metadata from the request
let mut chat_request_user_preferences_removed = chat_request_parsed;
if let Some(metadata) = chat_request_user_preferences_removed.get_mut("metadata") {
info!("Removing metadata from request");
if let Some(m) = metadata.as_object_mut() {
m.remove("archgw_preference_config");
info!("Removed archgw_preference_config from metadata");
}
// metadata.as_object_mut().map(|m| {
// m.remove("archgw_preference_config");
// info!("Removed archgw_preference_config from metadata");
// });
// if metadata is empty, remove it
if metadata.as_object().map_or(false, |m| m.is_empty()) {
info!("Removing empty metadata from request");
chat_request_user_preferences_removed
.as_object_mut()
.map(|m| m.remove("metadata"));
}
}
trace!(
"arch-router request body: {}",
&serde_json::to_string(&chat_completion_request).unwrap()
);
@ -56,8 +91,25 @@ pub async fn chat_completions(
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let usage_preferences_str: Option<String> =
chat_completion_request.metadata.and_then(|metadata| {
metadata
.get("archgw_preference_config")
.and_then(|value| value.as_str().map(String::from))
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
debug!("usage preferences: {:?}", usage_preferences);
let mut selected_llm = match router_service
.determine_route(&chat_completion_request.messages, trace_parent.clone())
.determine_route(
&chat_completion_request.messages,
trace_parent.clone(),
usage_preferences,
)
.await
{
Ok(route) => route,
@ -93,10 +145,16 @@ pub async fn chat_completions(
);
}
let chat_request_parsed_bytes =
serde_json::to_string(&chat_request_user_preferences_removed).unwrap();
// remove content-length header if it exists
request_headers.remove(header::CONTENT_LENGTH);
let llm_response = match reqwest::Client::new()
.post(llm_provider_endpoint)
.headers(request_headers)
.body(chat_request_bytes)
.body(chat_request_parsed_bytes)
.send()
.await
{

View file

@ -1,2 +1,3 @@
pub mod chat_completions;
pub mod models;
pub mod preferences;

View file

@ -7,10 +7,10 @@ use serde_json;
use std::sync::Arc;
pub async fn list_models(
llm_providers: Arc<Vec<LlmProvider>>,
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let prov = llm_providers.clone();
let providers = (*prov).clone();
let prov = llm_providers.read().await;
let providers = prov.clone();
let openai_models: Models = providers.into_models();
match serde_json::to_string(&openai_models) {

View file

@ -0,0 +1,135 @@
use bytes::Bytes;
use common::configuration::{LlmProvider, ModelUsagePreference};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use serde_json;
use std::{collections::HashMap, sync::Arc};
use tracing::{info, warn};
/// GET handler: returns the current per-model usage preferences as a JSON array.
///
/// Takes a read lock on the shared provider list, projects each `LlmProvider`
/// into a `ModelUsagePreference` (provider name, model id, usage), and responds
/// `200 OK` with the serialized array, or `500` if serialization fails.
pub async fn list_preferences(
    llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
    // Read lock only; this endpoint never mutates the provider list.
    let prov = llm_providers.read().await;
    // convert the LlmProvider to UsageBasedProvider
    let providers_with_usage = prov
        .iter()
        .map(|provider| ModelUsagePreference {
            name: provider.name.clone(),
            // Providers with no explicit model fall back to an empty string.
            model: provider.model.clone().unwrap_or_default(),
            usage: provider.usage.clone(),
        })
        .collect::<Vec<ModelUsagePreference>>();
    match serde_json::to_string(&providers_with_usage) {
        Ok(json) => {
            let body = Full::new(Bytes::from(json))
                .map_err(|never| match never {})
                .boxed();
            Response::builder()
                .status(StatusCode::OK)
                .header("Content-Type", "application/json")
                .body(body)
                .unwrap()
        }
        Err(_) => {
            // Bug fix: the message previously said "Failed to serialize models"
            // (copy-pasted from the models endpoint); this serializes preferences.
            let body = Full::new(Bytes::from_static(
                b"{\"error\":\"Failed to serialize preferences\"}",
            ))
            .map_err(|never| match never {})
            .boxed();
            Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .header("Content-Type", "application/json")
                .body(body)
                .unwrap()
        }
    }
}
/// POST handler: updates the `usage` field of matching providers from a JSON
/// array of `ModelUsagePreference` in the request body.
///
/// Flow: parse body -> validate every referenced model exists -> apply updates
/// under a write lock -> respond with the list of updated models.
///
/// Responses:
/// - `400` if the body is not a valid `Vec<ModelUsagePreference>` or references
///   an unknown model,
/// - `200` with `{"updated_models": [...]}` on success,
/// - `404` if nothing matched (e.g. an empty preference list was posted).
pub async fn update_preferences(
    request: Request<hyper::body::Incoming>,
    llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    let request_body = request.collect().await?.to_bytes();
    let usage: Vec<ModelUsagePreference> = match serde_json::from_slice(&request_body) {
        Ok(usage) => usage,
        Err(err) => {
            // Bug fix: the original reply was the dangling literal
            // "Invalid request body: " — the parse error was never appended.
            let err_msg = format!("Invalid request body: {}", err);
            warn!("updating preferences: {}", err_msg);
            let response_body = Full::new(Bytes::from(err_msg))
                .map_err(|never| match never {})
                .boxed();
            return Ok(Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .header("Content-Type", "text/plain")
                .body(response_body)
                .unwrap());
        }
    };
    // NOTE(review): map is keyed by `u.model` but later looked up with
    // `provider.name`; this only works if the "model" field carries provider
    // names — confirm against the client contract before changing.
    let usage_model_map: HashMap<String, ModelUsagePreference> =
        usage.into_iter().map(|u| (u.model.clone(), u)).collect();
    info!(
        "Updating usage preferences for models: {:?}",
        usage_model_map.keys()
    );
    let mut llm_providers = llm_providers.write().await;
    // ensure that models coming in the request are valid
    let llm_provider_names: Vec<String> = llm_providers
        .iter()
        .map(|provider| provider.name.clone())
        .collect();
    for model in usage_model_map.keys() {
        if !llm_provider_names.contains(model) {
            // Reject the whole request before mutating anything.
            let model_not_found = format!("model not found: {}", model);
            warn!("updating preferences: {}", model_not_found);
            let response_body = Full::new(model_not_found.into())
                .map_err(|never| match never {})
                .boxed();
            return Ok(Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .header("Content-Type", "text/plain")
                .body(response_body)
                .unwrap());
        }
    }
    let mut updated_models_list = Vec::new();
    for provider in llm_providers.iter_mut() {
        if let Some(usage_provider) = usage_model_map.get(&provider.name) {
            provider.usage = usage_provider.usage.clone();
            updated_models_list.push(ModelUsagePreference {
                name: provider.name.clone(),
                model: provider.model.clone().unwrap_or_default(),
                usage: provider.usage.clone(),
            });
        }
    }
    if !updated_models_list.is_empty() {
        // return list of updated models
        let response_body = Full::new(Bytes::from(format!(
            "{{\"updated_models\": {}}}",
            serde_json::to_string(&updated_models_list).unwrap()
        )))
        .map_err(|never| match never {})
        .boxed();
        Ok(Response::builder()
            .status(StatusCode::OK)
            .header("Content-Type", "application/json")
            .body(response_body)
            .unwrap())
    } else {
        // Reached only when the validated request matched no provider
        // (e.g. an empty preference array).
        let response_body = Full::new(Bytes::from_static(b"Provider not found"))
            .map_err(|never| match never {})
            .boxed();
        Ok(Response::builder()
            .status(StatusCode::NOT_FOUND)
            .header("Content-Type", "text/plain")
            .body(response_body)
            .unwrap())
    }
}