Add support for updating model preferences (#510)

Adil Hafeez 2025-07-02 14:08:19 -07:00 committed by GitHub
parent 1963020c21
commit 00dc95e034
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 437 additions and 53 deletions
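In short: a chat-completions request can now carry per-request routing preferences as a YAML string under metadata.archgw_preference_config, and the gateway gains GET/PUT /v1/router/preferences to read and update its defaults at runtime. A sketch of such a request body follows; the model and route names are illustrative, and the YAML shape is assumed to mirror ModelUsagePreference (name, model, usage):

// Hypothetical client-side payload; not code from this commit.
use serde_json::json;

fn example_request_body() -> serde_json::Value {
    json!({
        "model": "arch-router",
        "messages": [{ "role": "user", "content": "write a quicksort in rust" }],
        "metadata": {
            // YAML string parsed into Vec<ModelUsagePreference> by the handler
            "archgw_preference_config":
                "- name: code-generation\n  model: claude/claude-3-7-sonnet\n  usage: generating new code snippets\n"
        }
    })
}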

View file

@@ -1,6 +1,7 @@
use std::sync::Arc;
use bytes::Bytes;
use common::configuration::ModelUsagePreference;
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use hermesllm::providers::openai::types::ChatCompletionsRequest;
use http_body_util::combinators::BoxBody;
@@ -11,7 +12,7 @@ use hyper::{Request, Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{debug, info, warn};
use tracing::{debug, info, trace, warn};
use crate::router::llm_router::RouterService;
@@ -30,23 +31,57 @@ pub async fn chat_completions(
let chat_request_bytes = request.collect().await?.to_bytes();
let chat_completion_request: ChatCompletionsRequest =
match ChatCompletionsRequest::try_from(chat_request_bytes.as_ref()) {
Ok(request) => request,
Err(err) => {
warn!(
"arch-router request body string: {}",
String::from_utf8_lossy(&chat_request_bytes)
);
let err_msg = format!("Failed to parse request body: {}", err);
warn!("{}", err_msg);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
let chat_request_parsed = serde_json::from_slice::<serde_json::Value>(&chat_request_bytes)
    .unwrap_or_else(|err| {
        warn!(
            "Failed to parse request body as JSON: err: {}, str: {}",
            err,
            String::from_utf8_lossy(&chat_request_bytes)
        );
        serde_json::Value::Null
    });
debug!(
if chat_request_parsed == serde_json::Value::Null {
    let err_msg = "Request body is not valid JSON".to_string();
    warn!("{}", err_msg);
    let mut bad_request = Response::new(full(err_msg));
    *bad_request.status_mut() = StatusCode::BAD_REQUEST;
    return Ok(bad_request);
}
let chat_completion_request: ChatCompletionsRequest =
    match serde_json::from_value(chat_request_parsed.clone()) {
        Ok(request) => request,
        Err(err) => {
            let err_msg = format!("Failed to parse request body: {}", err);
            warn!("{}", err_msg);
            let mut bad_request = Response::new(full(err_msg));
            *bad_request.status_mut() = StatusCode::BAD_REQUEST;
            return Ok(bad_request);
        }
    };
// remove Arch-specific metadata from the request before forwarding it upstream
let mut chat_request_user_preferences_removed = chat_request_parsed;
if let Some(metadata) = chat_request_user_preferences_removed.get_mut("metadata") {
    if let Some(m) = metadata.as_object_mut() {
        if m.remove("archgw_preference_config").is_some() {
            info!("Removed archgw_preference_config from metadata");
        }
    }
    // if metadata is now empty, remove the key entirely
    if metadata.as_object().map_or(false, |m| m.is_empty()) {
        info!("Removing empty metadata from request");
        if let Some(m) = chat_request_user_preferences_removed.as_object_mut() {
            m.remove("metadata");
        }
    }
}
trace!(
"arch-router request body: {}",
&serde_json::to_string(&chat_completion_request).unwrap()
);
@@ -56,8 +91,25 @@ pub async fn chat_completions(
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let usage_preferences_str: Option<String> =
chat_completion_request.metadata.and_then(|metadata| {
metadata
.get("archgw_preference_config")
.and_then(|value| value.as_str().map(String::from))
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
    .as_ref()
    .and_then(|s| match serde_yaml::from_str(s) {
        Ok(prefs) => Some(prefs),
        Err(err) => {
            warn!("Failed to parse archgw_preference_config as YAML: {}", err);
            None
        }
    });
debug!("usage preferences: {:?}", usage_preferences);
let mut selected_llm = match router_service
.determine_route(&chat_completion_request.messages, trace_parent.clone())
.determine_route(
&chat_completion_request.messages,
trace_parent.clone(),
usage_preferences,
)
.await
{
Ok(route) => route,
@@ -93,10 +145,16 @@
);
}
let chat_request_parsed_bytes =
serde_json::to_string(&chat_request_user_preferences_removed).unwrap();
// remove content-length header if it exists
request_headers.remove(header::CONTENT_LENGTH);
let llm_response = match reqwest::Client::new()
.post(llm_provider_endpoint)
.headers(request_headers)
.body(chat_request_bytes)
.body(chat_request_parsed_bytes)
.send()
.await
{
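The metadata-stripping logic above, distilled into a standalone helper for readability (a hypothetical equivalent, not code from this commit):

// Drop archgw_preference_config, and drop metadata entirely if that leaves it
// empty, so upstream providers never see Arch-internal configuration.
fn strip_preference_config(mut body: serde_json::Value) -> serde_json::Value {
    if let Some(metadata) = body.get_mut("metadata") {
        if let Some(m) = metadata.as_object_mut() {
            m.remove("archgw_preference_config");
        }
        if metadata.as_object().map_or(false, |m| m.is_empty()) {
            if let Some(obj) = body.as_object_mut() {
                obj.remove("metadata");
            }
        }
    }
    body
}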

View file

@@ -1,2 +1,3 @@
pub mod chat_completions;
pub mod models;
pub mod preferences;

View file

@@ -7,10 +7,10 @@ use serde_json;
use std::sync::Arc;
pub async fn list_models(
llm_providers: Arc<Vec<LlmProvider>>,
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let prov = llm_providers.clone();
let providers = (*prov).clone();
let prov = llm_providers.read().await;
let providers = prov.clone();
let openai_models: Models = providers.into_models();
match serde_json::to_string(&openai_models) {

View file

@@ -0,0 +1,135 @@
use bytes::Bytes;
use common::configuration::{LlmProvider, ModelUsagePreference};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use serde_json;
use std::{collections::HashMap, sync::Arc};
use tracing::{info, warn};
pub async fn list_preferences(
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let prov = llm_providers.read().await;
// convert each LlmProvider into a ModelUsagePreference for the response
let providers_with_usage = prov
.iter()
.map(|provider| ModelUsagePreference {
name: provider.name.clone(),
model: provider.model.clone().unwrap_or_default(),
usage: provider.usage.clone(),
})
.collect::<Vec<ModelUsagePreference>>();
match serde_json::to_string(&providers_with_usage) {
Ok(json) => {
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap()
}
Err(_) => {
let body = Full::new(Bytes::from_static(
b"{\"error\":\"Failed to serialize models\"}",
))
.map_err(|never| match never {})
.boxed();
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("Content-Type", "application/json")
.body(body)
.unwrap()
}
}
}
pub async fn update_preferences(
request: Request<hyper::body::Incoming>,
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_body = request.collect().await?.to_bytes();
let usage: Vec<ModelUsagePreference> = match serde_json::from_slice(&request_body) {
Ok(usage) => usage,
Err(err) => {
    let err_msg = format!("Invalid request body: {}", err);
    warn!("updating preferences: {}", err_msg);
    let response_body = Full::new(Bytes::from(err_msg))
        .map_err(|never| match never {})
        .boxed();
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap());
}
};
// key the map by provider name: both the validation and the update below
// match entries against provider.name
let usage_model_map: HashMap<String, ModelUsagePreference> =
    usage.into_iter().map(|u| (u.name.clone(), u)).collect();
info!(
    "Updating usage preferences for providers: {:?}",
    usage_model_map.keys()
);
let mut llm_providers = llm_providers.write().await;
// ensure that the providers named in the request actually exist
let llm_provider_names: Vec<String> = llm_providers
    .iter()
    .map(|provider| provider.name.clone())
    .collect();
for name in usage_model_map.keys() {
    if !llm_provider_names.contains(name) {
        let provider_not_found = format!("provider not found: {}", name);
        warn!("updating preferences: {}", provider_not_found);
        let response_body = Full::new(provider_not_found.into())
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap());
}
}
let mut updated_models_list = Vec::new();
for provider in llm_providers.iter_mut() {
if let Some(usage_provider) = usage_model_map.get(&provider.name) {
provider.usage = usage_provider.usage.clone();
updated_models_list.push(ModelUsagePreference {
name: provider.name.clone(),
model: provider.model.clone().unwrap_or_default(),
usage: provider.usage.clone(),
});
}
}
if !updated_models_list.is_empty() {
// return list of updated models
let response_body = Full::new(Bytes::from(format!(
"{{\"updated_models\": {}}}",
serde_json::to_string(&updated_models_list).unwrap()
)))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(response_body)
.unwrap())
} else {
let response_body = Full::new(Bytes::from_static(b"Provider not found"))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap())
}
}
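Exercising the two new endpoints from a client, sketched with reqwest; the listen address is an assumption for illustration:

async fn preferences_roundtrip() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();

    // Read current preferences: a JSON array of {name, model, usage} objects.
    let current: serde_json::Value = client
        .get("http://localhost:12000/v1/router/preferences")
        .send()
        .await?
        .json()
        .await?;
    println!("current preferences: {current}");

    // Update usage descriptions; unknown provider names are rejected with 400.
    let response = client
        .put("http://localhost:12000/v1/router/preferences")
        .json(&serde_json::json!([{
            "name": "code-generation",
            "model": "claude/claude-3-7-sonnet",
            "usage": "generating new code snippets"
        }]))
        .send()
        .await?;
    println!("update status: {}", response.status());
    Ok(())
}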

View file

@@ -1,5 +1,6 @@
use brightstaff::handlers::chat_completions::chat_completions;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::preferences::{list_preferences, update_preferences};
use brightstaff::router::llm_router::RouterService;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
@@ -16,7 +17,8 @@ use opentelemetry_http::HeaderExtractor;
use std::sync::Arc;
use std::{env, fs};
use tokio::net::TcpListener;
use tracing::{debug, info};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
pub mod router;
@@ -53,7 +55,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let arch_config = Arc::new(config);
let llm_providers = Arc::new(arch_config.llm_providers.clone());
let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone()));
debug!(
"arch_config: {:?}",
@@ -101,6 +103,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/router/preferences") => {
Ok(list_preferences(llm_providers).await)
}
(&Method::PUT, "/v1/router/preferences") => {
update_preferences(req, llm_providers).await
}
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
(&Method::OPTIONS, "/v1/models") => {
let mut response = Response::new(empty());
@@ -141,7 +149,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.serve_connection(io, service)
.await
{
info!("Error serving connection: {:?}", err);
warn!("Error serving connection: {:?}", err);
}
});
}
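The provider list moves behind Arc<RwLock<...>> so the PUT handler can mutate what the read-only routes observe. A toy sketch of the pattern (names illustrative, not code from this commit):

use std::sync::Arc;
use tokio::sync::RwLock;

#[tokio::main]
async fn main() {
    let providers = Arc::new(RwLock::new(vec!["gpt-4o".to_string()]));

    // Each handler gets its own clone of the Arc.
    let writer = providers.clone();
    tokio::spawn(async move {
        // Write lock is held only for the duration of the update.
        writer.write().await.push("claude-3-7-sonnet".to_string());
    })
    .await
    .unwrap();

    // Readers observe the update; many read locks can be held concurrently.
    assert_eq!(providers.read().await.len(), 2);
}

tokio's RwLock (rather than std::sync::RwLock) fits here because waiting for the lock yields to the async executor instead of blocking the thread.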

View file

@@ -1,7 +1,7 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};
use common::{
configuration::{LlmProvider, LlmRoute},
configuration::{LlmProvider, LlmRoute, ModelUsagePreference},
consts::ARCH_PROVIDER_HINT_HEADER,
};
use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
@@ -19,6 +19,7 @@ pub struct RouterService {
router_model: Arc<dyn RouterModel>,
routing_model_name: String,
llm_usage_defined: bool,
llm_provider_map: HashMap<String, LlmProvider>,
}
#[derive(Debug, Error)]
@@ -55,12 +56,18 @@ impl RouterService {
router_model_v1::MAX_TOKEN_LEN,
));
let llm_provider_map: HashMap<String, LlmProvider> = providers
.into_iter()
.map(|provider| (provider.name.clone(), provider))
.collect();
RouterService {
router_url,
client: reqwest::Client::new(),
router_model,
routing_model_name,
llm_usage_defined: !providers_with_usage.is_empty(),
llm_provider_map,
}
}
@@ -68,12 +75,15 @@
&self,
messages: &[Message],
trace_parent: Option<String>,
usage_preferences: Option<Vec<ModelUsagePreference>>,
) -> Result<Option<String>> {
if !self.llm_usage_defined {
return Ok(None);
}
let router_request = self.router_model.generate_request(messages);
let router_request = self
.router_model
.generate_request(messages, &usage_preferences);
info!(
"sending request to arch-router model: {}, endpoint: {}",
@@ -144,13 +154,40 @@ impl RouterService {
if let Some(ContentType::Text(content)) =
&chat_completion_response.choices[0].message.content
{
let mut selected_model: Option<String> = None;
if let Some(selected_llm_name) = self.router_model.parse_response(content)? {
if selected_llm_name != "other" {
if let Some(usage_preferences) = usage_preferences {
for usage in usage_preferences {
if usage.name == selected_llm_name {
selected_model = Some(usage.model);
break;
}
}
if selected_model.is_none() {
warn!(
"Selected LLM model not found in usage preferences: {}",
selected_llm_name
);
}
} else if let Some(provider) = self.llm_provider_map.get(&selected_llm_name) {
selected_model = provider.model.clone();
} else {
warn!(
"Selected LLM model not found in provider map: {}",
selected_llm_name
);
}
}
}
info!(
"router response: {}, response time: {}ms",
"router response: {}, selected_model: {:?}, response time: {}ms",
content.replace("\n", "\\n"),
selected_model,
router_response_time.as_millis()
);
let selected_llm = self.router_model.parse_response(content)?;
Ok(selected_llm)
Ok(selected_model)
} else {
Ok(None)
}
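The new selection flow in determine_route, condensed into a pure function (a hypothetical distillation, not code from this commit):

use std::collections::HashMap;

use common::configuration::{LlmProvider, ModelUsagePreference};

// Map the route name returned by the router model to a concrete model id,
// preferring per-request preferences over the static provider map.
fn resolve_model(
    selected_llm_name: &str,
    usage_preferences: Option<&[ModelUsagePreference]>,
    llm_provider_map: &HashMap<String, LlmProvider>,
) -> Option<String> {
    if selected_llm_name == "other" {
        return None; // "other" means no route matched the latest user intent
    }
    match usage_preferences {
        Some(prefs) => prefs
            .iter()
            .find(|pref| pref.name == selected_llm_name)
            .map(|pref| pref.model.clone()),
        None => llm_provider_map
            .get(selected_llm_name)
            .and_then(|provider| provider.model.clone()),
    }
}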

View file

@@ -1,3 +1,4 @@
use common::configuration::ModelUsagePreference;
use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};
use thiserror::Error;
@@ -10,7 +11,11 @@ pub enum RoutingModelError {
pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub trait RouterModel: Send + Sync {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest;
fn generate_request(
&self,
messages: &[Message],
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest;
fn parse_response(&self, content: &str) -> Result<Option<String>>;
fn get_model_name(&self) -> String;
}

View file

@@ -1,5 +1,5 @@
use common::{
configuration::LlmRoute,
configuration::{LlmRoute, ModelUsagePreference},
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
};
use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
@@ -55,7 +55,11 @@ struct LlmRouterResponse {
const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters
impl RouterModel for RouterModelV1 {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
fn generate_request(
&self,
messages: &[Message],
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest {
// remove the system prompt, tool calls, tool call responses, and messages without content
// if content is empty it's likely a tool call
// when role == tool it's a tool call response
@@ -131,8 +135,22 @@ impl RouterModel for RouterModelV1 {
})
.collect::<Vec<Message>>();
let llm_route_json = usage_preferences
.as_ref()
.map(|prefs| {
let llm_route: Vec<LlmRoute> = prefs
.iter()
.map(|pref| LlmRoute {
name: pref.name.clone(),
description: pref.usage.clone().unwrap_or_default(),
})
.collect();
serde_json::to_string(&llm_route).unwrap_or_default()
})
.unwrap_or_else(|| self.llm_route_json_str.clone());
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_route_json_str)
.replace("{routes}", &llm_route_json)
.replace(
"{conversation}",
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
@@ -204,8 +222,6 @@ impl std::fmt::Debug for dyn RouterModel {
#[cfg(test)]
mod tests {
use crate::utils::tracing::init_tracer;
use super::*;
use pretty_assertions::assert_eq;
@@ -261,7 +277,71 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
assert_eq!(expected_prompt, prompt.to_string());
}
#[test]
fn test_system_prompt_format_usage_preferences() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"code-generation","description":"generating new code snippets, functions, or boilerplate based on user prompts or requirements"}]
</routes>
<conversation>
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
let conversation_str = r#"
[
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
}
]
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let usage_preferences = Some(vec![ModelUsagePreference {
name: "code-generation".to_string(),
model: "claude/claude-3-7-sonnet".to_string(),
usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()),
}]);
let req = router.generate_request(&conversation, &usage_preferences);
let prompt = req.messages[0].content.as_ref().unwrap();
@@ -270,7 +350,6 @@ Based on your analysis, provide your response in the following JSON formats if y
#[test]
fn test_conversation_exceed_token_count() {
let _tracer = init_tracer();
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
@@ -323,7 +402,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
@@ -332,7 +411,6 @@ Based on your analysis, provide your response in the following JSON formats if y
#[test]
fn test_conversation_exceed_token_count_large_single_message() {
let _tracer = init_tracer();
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
@@ -385,7 +463,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
@@ -394,7 +472,6 @@ Based on your analysis, provide your response in the following JSON formats if y
#[test]
fn test_conversation_trim_upto_user_message() {
let _tracer = init_tracer();
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
@@ -455,7 +532,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
@@ -525,7 +602,7 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
@@ -621,7 +698,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();