add more changes

This commit is contained in:
Adil Hafeez 2025-06-24 23:57:28 -07:00
parent 80998d446d
commit 4373aeb00b
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
9 changed files with 78 additions and 41 deletions

View file

@ -1,6 +1,7 @@
use std::sync::Arc;
use bytes::Bytes;
use common::configuration::ModelUsagePreference;
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use hermesllm::providers::openai::types::ChatCompletionsRequest;
use http_body_util::combinators::BoxBody;
@ -56,8 +57,25 @@ pub async fn chat_completions(
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let usage_preferences_str: Option<String> =
chat_completion_request.metadata.and_then(|metadata| {
metadata
.get("archgw_preference_config")
.and_then(|value| value.as_str().map(String::from))
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
debug!("usage preferences: {:?}", usage_preferences);
let mut selected_llm = match router_service
.determine_route(&chat_completion_request.messages, trace_parent.clone())
.determine_route(
&chat_completion_request.messages,
trace_parent.clone(),
usage_preferences,
)
.await
{
Ok(route) => route,

View file

@ -1,19 +1,10 @@
use bytes::Bytes;
use common::configuration::LlmProvider;
use common::configuration::{LlmProvider, ModelUsagePreference};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json;
use tracing::{info, warn};
use std::{collections::HashMap, sync::Arc};
use serde_with::skip_serializing_none;
#[skip_serializing_none]
#[derive(Serialize, Deserialize)]
struct UsageBasedProvider {
model: String,
usage: Option<String>,
}
use tracing::{info, warn};
pub async fn list_preferences(
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
@ -22,11 +13,11 @@ pub async fn list_preferences(
// convert the LlmProvider to UsageBasedProvider
let providers_with_usage = prov
.iter()
.map(|provider| UsageBasedProvider {
.map(|provider| ModelUsagePreference {
model: provider.name.clone(),
usage: provider.usage.clone(),
})
.collect::<Vec<UsageBasedProvider>>();
.collect::<Vec<ModelUsagePreference>>();
match serde_json::to_string(&providers_with_usage) {
Ok(json) => {
@ -60,7 +51,7 @@ pub async fn update_preferences(
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_body = request.collect().await?.to_bytes();
let usage: Vec<UsageBasedProvider> = match serde_json::from_slice(&request_body) {
let usage: Vec<ModelUsagePreference> = match serde_json::from_slice(&request_body) {
Ok(usage) => usage,
Err(_) => {
let response_body = Full::new(Bytes::from_static(b"Invalid request body: "))
@ -74,10 +65,13 @@ pub async fn update_preferences(
}
};
let usage_model_map: HashMap<String, UsageBasedProvider> =
let usage_model_map: HashMap<String, ModelUsagePreference> =
usage.into_iter().map(|u| (u.model.clone(), u)).collect();
info!("Updating usage preferences for models: {:?}", usage_model_map.keys());
info!(
"Updating usage preferences for models: {:?}",
usage_model_map.keys()
);
let mut llm_providers = llm_providers.write().await;
@ -106,7 +100,7 @@ pub async fn update_preferences(
for provider in llm_providers.iter_mut() {
if let Some(usage_provider) = usage_model_map.get(&provider.name) {
provider.usage = usage_provider.usage.clone();
updated_models_list.push(UsageBasedProvider {
updated_models_list.push(ModelUsagePreference {
model: provider.name.clone(),
usage: provider.usage.clone(),
});
@ -121,11 +115,11 @@ pub async fn update_preferences(
)))
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(response_body)
.unwrap());
.unwrap())
} else {
let response_body = Full::new(Bytes::from_static(b"Provider not found"))
.map_err(|never| match never {})

View file

@ -103,10 +103,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/router/preferences") => Ok(list_preferences(llm_providers).await),
(&Method::PUT, "/v1/router/preferences") => {
update_preferences(req, llm_providers).await
},
(&Method::GET, "/v1/router/preferences") => {
Ok(list_preferences(llm_providers).await)
}
(&Method::PUT, "/v1/router/preferences") => {
update_preferences(req, llm_providers).await
}
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
(&Method::OPTIONS, "/v1/models") => {
let mut response = Response::new(empty());

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use common::{
configuration::{LlmProvider, LlmRoute},
configuration::{LlmProvider, LlmRoute, ModelUsagePreference},
consts::ARCH_PROVIDER_HINT_HEADER,
};
use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
@ -68,12 +68,15 @@ impl RouterService {
&self,
messages: &[Message],
trace_parent: Option<String>,
usage_preferences: Option<Vec<ModelUsagePreference>>,
) -> Result<Option<String>> {
if !self.llm_usage_defined {
return Ok(None);
}
let router_request = self.router_model.generate_request(messages);
let router_request = self
.router_model
.generate_request(messages, usage_preferences);
info!(
"sending request to arch-router model: {}, endpoint: {}",

View file

@ -1,3 +1,4 @@
use common::configuration::ModelUsagePreference;
use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};
use thiserror::Error;
@ -10,7 +11,11 @@ pub enum RoutingModelError {
pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub trait RouterModel: Send + Sync {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest;
fn generate_request(
&self,
messages: &[Message],
usage_preferences: Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest;
fn parse_response(&self, content: &str) -> Result<Option<String>>;
fn get_model_name(&self) -> String;
}

View file

@ -1,5 +1,5 @@
use common::{
configuration::LlmRoute,
configuration::{LlmRoute, ModelUsagePreference},
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
};
use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
@ -55,7 +55,11 @@ struct LlmRouterResponse {
const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters
impl RouterModel for RouterModelV1 {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
fn generate_request(
&self,
messages: &[Message],
usage_preferences: Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest {
// remove system prompt, tool calls, tool call response and messages without content
// if content is empty its likely a tool call
// when role == tool its tool call response
@ -131,8 +135,13 @@ impl RouterModel for RouterModelV1 {
})
.collect::<Vec<Message>>();
let llm_route_json = usage_preferences
.as_ref()
.map(|prefs| serde_json::to_string(prefs).unwrap_or_default())
.unwrap_or_else(|| self.llm_route_json_str.clone());
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_route_json_str)
.replace("{routes}", &llm_route_json)
.replace(
"{conversation}",
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
@ -204,8 +213,6 @@ impl std::fmt::Debug for dyn RouterModel {
#[cfg(test)]
mod tests {
use crate::utils::tracing::init_tracer;
use super::*;
use pretty_assertions::assert_eq;
@ -261,7 +268,7 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, None);
let prompt = req.messages[0].content.as_ref().unwrap();
@ -270,7 +277,6 @@ Based on your analysis, provide your response in the following JSON formats if y
#[test]
fn test_conversation_exceed_token_count() {
let _tracer = init_tracer();
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
@ -323,7 +329,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, None);
let prompt = req.messages[0].content.as_ref().unwrap();
@ -332,7 +338,6 @@ Based on your analysis, provide your response in the following JSON formats if y
#[test]
fn test_conversation_exceed_token_count_large_single_message() {
let _tracer = init_tracer();
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
@ -385,7 +390,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, None);
let prompt = req.messages[0].content.as_ref().unwrap();
@ -394,7 +399,6 @@ Based on your analysis, provide your response in the following JSON formats if y
#[test]
fn test_conversation_trim_upto_user_message() {
let _tracer = init_tracer();
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
@ -455,7 +459,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, None);
let prompt = req.messages[0].content.as_ref().unwrap();
@ -525,7 +529,7 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, None);
let prompt = req.messages[0].content.as_ref().unwrap();
@ -621,7 +625,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let req = router.generate_request(&conversation, None);
let prompt = req.messages[0].content.as_ref().unwrap();

View file

@ -2,6 +2,7 @@ use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
use serde_with::skip_serializing_none;
use crate::api::open_ai::{
ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType,
@ -176,6 +177,13 @@ impl Display for LlmProviderType {
}
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug)]
pub struct ModelUsagePreference {
pub model: String,
pub usage: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRoute {
pub name: String,

View file

@ -101,6 +101,7 @@ impl OpenAIRequestBuilder {
frequency_penalty: self.frequency_penalty,
stream_options: self.stream_options,
tools: self.tools,
metadata: None,
};
Ok(request)
}

View file

@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::fmt::Display;
use serde::{Deserialize, Serialize};
@ -109,6 +110,7 @@ pub struct ChatCompletionsRequest {
pub frequency_penalty: Option<f32>,
pub stream_options: Option<StreamOptions>,
pub tools: Option<Vec<Value>>,
pub metadata: Option<HashMap<String, Value>>,
}
impl TryFrom<&[u8]> for ChatCompletionsRequest {