mirror of https://github.com/katanemo/plano.git, synced 2026-04-25 00:36:34 +02:00
In request path use same format for usage preferences as arch_config (#533)
This commit is contained in:
parent 79a62fffe8
commit d341f4365b

5 changed files with 83 additions and 187 deletions
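This commit aligns the usage-preference format accepted in the request path with the ModelUsagePreference shape already used by arch_config: each entry now carries routing_preferences (a name plus a description) alongside the model it maps to. A minimal sketch of that shared shape, assuming only the field names visible in the diff below (the standalone structs and main function here are illustrative, not part of the commit):

// Sketch only: simplified stand-ins for the common::configuration types shown in the diff.
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug)]
struct RoutingPreference {
    name: String,
    description: String,
}

#[derive(Serialize, Deserialize, Debug)]
struct ModelUsagePreference {
    name: String,
    model: String,
    usage: Option<String>,
    routing_preferences: Vec<RoutingPreference>,
}

fn main() {
    // Mirrors the test fixture in the diff.
    let pref = ModelUsagePreference {
        name: "code-generation".to_string(),
        model: "claude/claude-3-7-sonnet".to_string(),
        usage: Some("generating new code snippets".to_string()),
        routing_preferences: vec![RoutingPreference {
            name: "code-generation".to_string(),
            description: "generating new code snippets".to_string(),
        }],
    };
    // The same JSON shape is now used by arch_config and by preferences sent in the request path.
    println!("{}", serde_json::to_string_pretty(&pref).unwrap());
}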
@@ -1,3 +1,2 @@
pub mod chat_completions;
pub mod models;
pub mod preferences;
@@ -1,135 +0,0 @@
use bytes::Bytes;
use common::configuration::{LlmProvider, ModelUsagePreference};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use serde_json;
use std::{collections::HashMap, sync::Arc};
use tracing::{info, warn};

pub async fn list_preferences(
    llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
    let prov = llm_providers.read().await;
    // convert the LlmProvider to UsageBasedProvider
    let providers_with_usage = prov
        .iter()
        .map(|provider| ModelUsagePreference {
            name: provider.name.clone(),
            model: provider.model.clone().unwrap_or_default(),
            usage: provider.usage.clone(),
        })
        .collect::<Vec<ModelUsagePreference>>();

    match serde_json::to_string(&providers_with_usage) {
        Ok(json) => {
            let body = Full::new(Bytes::from(json))
                .map_err(|never| match never {})
                .boxed();
            Response::builder()
                .status(StatusCode::OK)
                .header("Content-Type", "application/json")
                .body(body)
                .unwrap()
        }
        Err(_) => {
            let body = Full::new(Bytes::from_static(
                b"{\"error\":\"Failed to serialize models\"}",
            ))
            .map_err(|never| match never {})
            .boxed();
            Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .header("Content-Type", "application/json")
                .body(body)
                .unwrap()
        }
    }
}

pub async fn update_preferences(
    request: Request<hyper::body::Incoming>,
    llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    let request_body = request.collect().await?.to_bytes();

    let usage: Vec<ModelUsagePreference> = match serde_json::from_slice(&request_body) {
        Ok(usage) => usage,
        Err(_) => {
            let response_body = Full::new(Bytes::from_static(b"Invalid request body: "))
                .map_err(|never| match never {})
                .boxed();
            return Ok(Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .header("Content-Type", "text/plain")
                .body(response_body)
                .unwrap());
        }
    };

    let usage_model_map: HashMap<String, ModelUsagePreference> =
        usage.into_iter().map(|u| (u.model.clone(), u)).collect();

    info!(
        "Updating usage preferences for models: {:?}",
        usage_model_map.keys()
    );

    let mut llm_providers = llm_providers.write().await;

    // ensure that models coming in the request are valid
    let llm_provider_names: Vec<String> = llm_providers
        .iter()
        .map(|provider| provider.name.clone())
        .collect();

    for model in usage_model_map.keys() {
        if !llm_provider_names.contains(model) {
            let model_not_found = format!("model not found: {}", model);
            warn!("updating preferences: {}", model_not_found);
            let response_body = Full::new(model_not_found.into())
                .map_err(|never| match never {})
                .boxed();
            return Ok(Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .header("Content-Type", "text/plain")
                .body(response_body)
                .unwrap());
        }
    }

    let mut updated_models_list = Vec::new();
    for provider in llm_providers.iter_mut() {
        if let Some(usage_provider) = usage_model_map.get(&provider.name) {
            provider.usage = usage_provider.usage.clone();
            updated_models_list.push(ModelUsagePreference {
                name: provider.name.clone(),
                model: provider.model.clone().unwrap_or_default(),
                usage: provider.usage.clone(),
            });
        }
    }

    if !updated_models_list.is_empty() {
        // return list of updated models
        let response_body = Full::new(Bytes::from(format!(
            "{{\"updated_models\": {}}}",
            serde_json::to_string(&updated_models_list).unwrap()
        )))
        .map_err(|never| match never {})
        .boxed();
        Ok(Response::builder()
            .status(StatusCode::OK)
            .header("Content-Type", "application/json")
            .body(response_body)
            .unwrap())
    } else {
        let response_body = Full::new(Bytes::from_static(b"Provider not found"))
            .map_err(|never| match never {})
            .boxed();
        Ok(Response::builder()
            .status(StatusCode::NOT_FOUND)
            .header("Content-Type", "text/plain")
            .body(response_body)
            .unwrap())
    }
}
@@ -1,6 +1,5 @@
use brightstaff::handlers::chat_completions::chat_completions;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::preferences::{list_preferences, update_preferences};
use brightstaff::router::llm_router::RouterService;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
@@ -116,12 +115,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                        .with_context(parent_cx)
                        .await
                }
                (&Method::GET, "/v1/router/preferences") => {
                    Ok(list_preferences(llm_providers).await)
                }
                (&Method::PUT, "/v1/router/preferences") => {
                    update_preferences(req, llm_providers).await
                }
                (&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
                (&Method::OPTIONS, "/v1/models") => {
                    let mut response = Response::new(empty());
@@ -73,7 +73,7 @@ impl RouterModel for RouterModelV1 {
    fn generate_request(
        &self,
        messages: &[Message],
        usage_preferences: &Option<Vec<ModelUsagePreference>>,
        usage_preferences_from_request: &Option<Vec<ModelUsagePreference>>,
    ) -> ChatCompletionsRequest {
        // remove system prompt, tool calls, tool call response and messages without content
        // if content is empty its likely a tool call
@@ -150,31 +150,17 @@ impl RouterModel for RouterModelV1 {
            })
            .collect::<Vec<Message>>();

        let llm_route_json = usage_preferences
            .as_ref()
            .map(|prefs| {
                let llm_route: Vec<RoutingPreference> = prefs
                    .iter()
                    .map(|pref| RoutingPreference {
                        name: pref.name.clone(),
                        description: pref.usage.clone().unwrap_or_default(),
                    })
                    .collect();
                serde_json::to_string(&llm_route).unwrap_or_default()
            })
            .unwrap_or_else(|| self.llm_route_json_str.clone());

        let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
            .replace("{routes}", &llm_route_json)
            .replace(
                "{conversation}",
                &serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
            );
        // Generate the router request message based on the usage preferences.
        // If preferences are passed in request then we use them otherwise we use the default routing model preferences.
        let router_message = match convert_to_router_preferences(usage_preferences_from_request) {
            Some(prefs) => generate_router_message(&prefs, &selected_conversation_list),
            None => generate_router_message(&self.llm_route_json_str, &selected_conversation_list),
        };

        ChatCompletionsRequest {
            model: self.routing_model.clone(),
            messages: vec![Message {
                content: Some(ContentType::Text(messages_content)),
                content: Some(ContentType::Text(router_message)),
                role: USER_ROLE.to_string(),
            }],
            temperature: Some(0.01),
@@ -201,12 +187,18 @@ impl RouterModel for RouterModelV1 {

        if let Some(usage_preferences) = usage_preferences {
            // If usage preferences are defined, we need to find the model that matches the selected route
            let matching_preference = usage_preferences
            let model_name: Option<String> = usage_preferences
                .iter()
                .find(|pref| pref.name == selected_route);
                .map(|pref| {
                    pref.routing_preferences
                        .iter()
                        .find(|routing_pref| routing_pref.name == selected_route)
                        .map(|_| pref.model.clone())
                })
                .find_map(|model| model);

            if let Some(preference) = matching_preference {
                return Ok(Some((selected_route, preference.model.clone())));
            if let Some(model_name) = model_name {
                return Ok(Some((selected_route, model_name)));
            } else {
                warn!(
                    "No matching model found for route: {}, usage preferences: {:?}",
@@ -216,7 +208,7 @@ impl RouterModel for RouterModelV1 {
            }
        }

        // If no usage preferences are defined, we return the route with the routing model
        // If no usage preferences are passed in request then use the default routing model preferences
        if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() {
            return Ok(Some((selected_route, model)));
        }
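The two hunks above change how a selected route is resolved back to a model: the request-level preferences are searched by the names inside their routing_preferences, and the config-derived route map remains the fallback. A hedged standalone sketch of that lookup (not from the commit; types are simplified and resolve_model is a hypothetical helper name):

// Simplified stand-ins; only the fields needed for the lookup.
struct RoutingPreference {
    name: String,
}

struct ModelUsagePreference {
    model: String,
    routing_preferences: Vec<RoutingPreference>,
}

// Mirrors the new selection logic: return the model of the first preference whose
// routing_preferences contain the route selected by the router model.
fn resolve_model(prefs: &[ModelUsagePreference], selected_route: &str) -> Option<String> {
    prefs
        .iter()
        .find(|p| p.routing_preferences.iter().any(|r| r.name == selected_route))
        .map(|p| p.model.clone())
}

fn main() {
    let prefs = vec![ModelUsagePreference {
        model: "claude/claude-3-7-sonnet".to_string(),
        routing_preferences: vec![RoutingPreference {
            name: "code-generation".to_string(),
        }],
    }];
    assert_eq!(
        resolve_model(&prefs, "code-generation"),
        Some("claude/claude-3-7-sonnet".to_string())
    );
    // Unknown routes fall through, so the caller can use the default route map.
    assert_eq!(resolve_model(&prefs, "unknown-route"), None);
}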
@@ -234,6 +226,37 @@ impl RouterModel for RouterModelV1 {
    }
}

fn generate_router_message(prefs: &str, selected_conversation_list: &Vec<Message>) -> String {
    ARCH_ROUTER_V1_SYSTEM_PROMPT
        .replace("{routes}", prefs)
        .replace(
            "{conversation}",
            &serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
        )
}

fn convert_to_router_preferences(
    prefs_from_request: &Option<Vec<ModelUsagePreference>>,
) -> Option<String> {
    if let Some(usage_preferences) = prefs_from_request {
        let routing_preferences = usage_preferences
            .iter()
            .flat_map(|pref| {
                pref.routing_preferences
                    .iter()
                    .map(|routing_pref| RoutingPreference {
                        name: routing_pref.name.clone(),
                        description: routing_pref.description.clone(),
                    })
            })
            .collect::<Vec<RoutingPreference>>();

        return Some(serde_json::to_string(&routing_preferences).unwrap_or_default());
    }

    None
}

fn fix_json_response(body: &str) -> String {
    let mut updated_body = body.to_string();
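To make the flow of the new convert_to_router_preferences and generate_router_message helpers concrete, here is a minimal standalone sketch (not part of the commit; the types are simplified and to_routes_json is a hypothetical name) of how request-level preferences are flattened and serialized for the {routes} placeholder of the router system prompt:

use serde::Serialize;

// Simplified stand-ins for the configuration types referenced in the diff.
#[derive(Serialize, Clone)]
struct RoutingPreference {
    name: String,
    description: String,
}

struct ModelUsagePreference {
    model: String,
    routing_preferences: Vec<RoutingPreference>,
}

// Mirrors what convert_to_router_preferences does: collect every routing preference
// from the request and serialize the list for the {routes} slot of the prompt.
fn to_routes_json(prefs: &[ModelUsagePreference]) -> String {
    let routes: Vec<RoutingPreference> = prefs
        .iter()
        .flat_map(|p| p.routing_preferences.iter().cloned())
        .collect();
    serde_json::to_string(&routes).unwrap_or_default()
}

fn main() {
    let prefs = vec![ModelUsagePreference {
        model: "claude/claude-3-7-sonnet".to_string(),
        routing_preferences: vec![RoutingPreference {
            name: "code-generation".to_string(),
            description: "generating new code snippets or boilerplate".to_string(),
        }],
    }];
    // Prints the routes JSON that would replace {routes} in the prompt for this model.
    println!("routes for {}: {}", prefs[0].model, to_routes_json(&prefs));
}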
@@ -299,7 +322,8 @@ Based on your analysis, provide your response in the following JSON formats if y
        ]
        }
        "#;
        let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let llm_routes =
            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let routing_model = "test-model".to_string();
        let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@@ -356,7 +380,8 @@ Based on your analysis, provide your response in the following JSON formats if y
        ]
        }
        "#;
        let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let llm_routes =
            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let routing_model = "test-model".to_string();
        let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@@ -379,9 +404,11 @@ Based on your analysis, provide your response in the following JSON formats if y
        let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();

        let usage_preferences = Some(vec![ModelUsagePreference {
            name: "code-generation".to_string(),
            model: "claude/claude-3-7-sonnet".to_string(),
            usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()),
            routing_preferences: vec![RoutingPreference {
                name: "code-generation".to_string(),
                description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(),
            }],
        }]);
        let req = router.generate_request(&conversation, &usage_preferences);
@@ -419,7 +446,8 @@ Based on your analysis, provide your response in the following JSON formats if y
        ]
        }
        "#;
        let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let llm_routes =
            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let routing_model = "test-model".to_string();
        let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235);
@@ -478,7 +506,8 @@ Based on your analysis, provide your response in the following JSON formats if y
        ]
        }
        "#;
        let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let llm_routes =
            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();

        let routing_model = "test-model".to_string();
        let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200);
@@ -538,7 +567,8 @@ Based on your analysis, provide your response in the following JSON formats if y
        ]
        }
        "#;
        let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let llm_routes =
            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let routing_model = "test-model".to_string();
        let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230);
@@ -604,7 +634,8 @@ Based on your analysis, provide your response in the following JSON formats if y
        ]
        }
        "#;
        let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let llm_routes =
            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let routing_model = "test-model".to_string();
        let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@@ -672,7 +703,8 @@ Based on your analysis, provide your response in the following JSON formats if y
        ]
        }
        "#;
        let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let llm_routes =
            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let routing_model = "test-model".to_string();
        let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@@ -747,14 +779,18 @@ Based on your analysis, provide your response in the following JSON formats if y
        ]
        }
        "#;
        let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
        let llm_routes =
            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();

        let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);

        // Case 1: Valid JSON with non-empty route
        let input = r#"{"route": "Image generation"}"#;
        let result = router.parse_response(input, &None).unwrap();
        assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
        assert_eq!(
            result,
            Some(("Image generation".to_string(), "gpt-4o".to_string()))
        );

        // Case 2: Valid JSON with empty route
        let input = r#"{"route": ""}"#;
@@ -784,11 +820,17 @@ Based on your analysis, provide your response in the following JSON formats if y
        // Case 6: Single quotes and \n in JSON
        let input = "{'route': 'Image generation'}\\n";
        let result = router.parse_response(input, &None).unwrap();
        assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
        assert_eq!(
            result,
            Some(("Image generation".to_string(), "gpt-4o".to_string()))
        );

        // Case 7: Code block marker
        let input = "```json\n{\"route\": \"Image generation\"}\n```";
        let result = router.parse_response(input, &None).unwrap();
        assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
        assert_eq!(
            result,
            Some(("Image generation".to_string(), "gpt-4o".to_string()))
        );
    }
}
@@ -1,6 +1,5 @@
use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use std::fmt::Display;
@@ -178,12 +177,10 @@ impl Display for LlmProviderType {
    }
}

#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug)]
pub struct ModelUsagePreference {
    pub name: String,
    pub model: String,
    pub usage: Option<String>,
    pub routing_preferences: Vec<RoutingPreference>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]