mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
more updates
This commit is contained in:
parent
e7eb77383f
commit
7f90124bd1
29 changed files with 375 additions and 133 deletions
21
crates/.vscode/launch.json
vendored
Normal file
21
crates/.vscode/launch.json
vendored
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Debug Brightstaff",
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/target/debug/brightstaff",
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"stopOnEntry": false,
|
||||
"sourceLanguages": ["rust"],
|
||||
"env": {
|
||||
"RUST_LOG": "debug",
|
||||
"RUST_BACKTRACE": "1",
|
||||
"ARCH_CONFIG_PATH_RENDERED": "../demos/use_cases/preference_based_routing/arch_config_rendered.yaml"
|
||||
},
|
||||
"preLaunchTask": "rust: cargo build"
|
||||
}
|
||||
]
|
||||
}
|
||||
21
crates/.vscode/tasks.json
vendored
Normal file
21
crates/.vscode/tasks.json
vendored
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
{
|
||||
"version": "2.0.0",
|
||||
"tasks": [
|
||||
{
|
||||
"type": "cargo",
|
||||
"command": "build",
|
||||
"args": [
|
||||
"--bin",
|
||||
"brightstaff"
|
||||
],
|
||||
"problemMatcher": [
|
||||
"$rustc"
|
||||
],
|
||||
"group": {
|
||||
"kind": "build",
|
||||
"isDefault": true
|
||||
},
|
||||
"label": "rust: cargo build"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -12,7 +12,7 @@ use hyper::{Request, Response, StatusCode};
|
|||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{debug, info, trace, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::router::llm_router::RouterService;
|
||||
|
||||
|
|
@ -81,8 +81,8 @@ pub async fn chat_completions(
|
|||
}
|
||||
}
|
||||
|
||||
trace!(
|
||||
"arch-router request body: {}",
|
||||
debug!(
|
||||
"arch-router request received: {}",
|
||||
&serde_json::to_string(&chat_completion_request).unwrap()
|
||||
);
|
||||
|
||||
|
|
@ -102,7 +102,7 @@ pub async fn chat_completions(
|
|||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok());
|
||||
|
||||
debug!("usage preferences: {:?}", usage_preferences);
|
||||
debug!("usage preferences from request: {:?}", usage_preferences);
|
||||
|
||||
let mut determined_route = match router_service
|
||||
.determine_route(
|
||||
|
|
|
|||
|
|
@ -44,6 +44,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let _tracer_provider = init_tracer();
|
||||
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
|
||||
|
||||
info!(
|
||||
"current working directory: {}",
|
||||
env::current_dir().unwrap().display()
|
||||
);
|
||||
// loading arch_config.yaml file
|
||||
let arch_config_path = env::var("ARCH_CONFIG_PATH_RENDERED")
|
||||
.unwrap_or_else(|_| "./arch_config_rendered.yaml".to_string());
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use common::{
|
||||
configuration::{LlmProvider, LlmRoute, ModelUsagePreference},
|
||||
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
|
||||
consts::ARCH_PROVIDER_HINT_HEADER,
|
||||
};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
|
||||
|
|
@ -44,11 +44,14 @@ impl RouterService {
|
|||
) -> Self {
|
||||
let providers_with_usage = providers
|
||||
.iter()
|
||||
.filter(|provider| provider.usage.is_some())
|
||||
.filter(|provider| provider.routing_preferences.is_some())
|
||||
.cloned()
|
||||
.collect::<Vec<LlmProvider>>();
|
||||
|
||||
let llm_routes: Vec<LlmRoute> = providers_with_usage.iter().map(LlmRoute::from).collect();
|
||||
let llm_routes: Vec<RoutingPreference> = providers_with_usage
|
||||
.iter()
|
||||
.flat_map(|provider| provider.routing_preferences.clone().unwrap_or_default())
|
||||
.collect();
|
||||
|
||||
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
|
||||
llm_routes,
|
||||
|
|
@ -156,6 +159,12 @@ impl RouterService {
|
|||
router_response_time.as_millis()
|
||||
);
|
||||
|
||||
if let Some(ref route) = route_name {
|
||||
if route == "other" {
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(route_name)
|
||||
} else {
|
||||
Ok(None)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use common::{
|
||||
configuration::{LlmRoute, ModelUsagePreference},
|
||||
configuration::{ModelUsagePreference, RoutingPreference},
|
||||
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
|
||||
};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
|
||||
|
|
@ -36,7 +36,11 @@ pub struct RouterModelV1 {
|
|||
max_token_length: usize,
|
||||
}
|
||||
impl RouterModelV1 {
|
||||
pub fn new(llm_routes: Vec<LlmRoute>, routing_model: String, max_token_length: usize) -> Self {
|
||||
pub fn new(
|
||||
llm_routes: Vec<RoutingPreference>,
|
||||
routing_model: String,
|
||||
max_token_length: usize,
|
||||
) -> Self {
|
||||
let llm_route_json_str =
|
||||
serde_json::to_string(&llm_routes).unwrap_or_else(|_| "[]".to_string());
|
||||
RouterModelV1 {
|
||||
|
|
@ -138,9 +142,9 @@ impl RouterModel for RouterModelV1 {
|
|||
let llm_route_json = usage_preferences
|
||||
.as_ref()
|
||||
.map(|prefs| {
|
||||
let llm_route: Vec<LlmRoute> = prefs
|
||||
let llm_route: Vec<RoutingPreference> = prefs
|
||||
.iter()
|
||||
.map(|pref| LlmRoute {
|
||||
.map(|pref| RoutingPreference {
|
||||
name: pref.name.clone(),
|
||||
description: pref.usage.clone().unwrap_or_default(),
|
||||
})
|
||||
|
|
@ -255,7 +259,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||
|
||||
|
|
@ -314,7 +318,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||
|
||||
|
|
@ -379,7 +383,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235);
|
||||
|
||||
|
|
@ -440,7 +444,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200);
|
||||
|
||||
|
|
@ -501,7 +505,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230);
|
||||
|
||||
|
|
@ -569,7 +573,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||
|
||||
|
|
@ -639,7 +643,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||
|
||||
|
|
@ -716,7 +720,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
||||
|
||||
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
|
||||
|
||||
|
|
|
|||
|
|
@ -187,24 +187,11 @@ pub struct ModelUsagePreference {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LlmRoute {
|
||||
pub struct RoutingPreference {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
impl From<&LlmProvider> for LlmRoute {
|
||||
fn from(provider: &LlmProvider) -> Self {
|
||||
Self {
|
||||
name: provider.name.to_string(),
|
||||
description: provider
|
||||
.usage
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "No description available".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
//TODO: use enum for model, but if there is a new model, we need to update the code
|
||||
pub struct LlmProvider {
|
||||
|
|
@ -218,6 +205,7 @@ pub struct LlmProvider {
|
|||
pub port: Option<u16>,
|
||||
pub rate_limits: Option<LlmRatelimit>,
|
||||
pub usage: Option<String>,
|
||||
pub routing_preferences: Option<Vec<RoutingPreference>>,
|
||||
}
|
||||
|
||||
pub trait IntoModels {
|
||||
|
|
@ -256,6 +244,7 @@ impl Default for LlmProvider {
|
|||
port: None,
|
||||
rate_limits: None,
|
||||
usage: None,
|
||||
routing_preferences: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -368,7 +357,7 @@ mod test {
|
|||
#[test]
|
||||
fn test_deserialize_configuration() {
|
||||
let ref_config = fs::read_to_string(
|
||||
"../../docs/source/resources/includes/arch_config_full_reference.yaml",
|
||||
"../../docs/source/resources/includes/arch_config_full_reference_rendered.yaml",
|
||||
)
|
||||
.expect("reference config file not found");
|
||||
|
||||
|
|
@ -429,7 +418,7 @@ mod test {
|
|||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let ref_config = fs::read_to_string(
|
||||
"../../docs/source/resources/includes/arch_config_full_reference.yaml",
|
||||
"../../docs/source/resources/includes/arch_config_full_reference_rendered.yaml",
|
||||
)
|
||||
.expect("reference config file not found");
|
||||
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
|
||||
|
|
|
|||
|
|
@ -58,7 +58,16 @@ impl TryFrom<Vec<LlmProvider>> for LlmProviders {
|
|||
let name = llm_provider.name.clone();
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(name.clone(), llm_provider)
|
||||
.insert(name.clone(), llm_provider.clone())
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
}
|
||||
|
||||
// also add model_id as key for provider lookup
|
||||
if llm_providers
|
||||
.providers
|
||||
.insert(llm_provider.model.clone().unwrap(), llm_provider)
|
||||
.is_some()
|
||||
{
|
||||
return Err(LlmProvidersNewError::DuplicateName(name));
|
||||
|
|
|
|||
|
|
@ -113,16 +113,10 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
debug!(
|
||||
"request received: llm provider hint: {}, selected llm: {}, model: {}",
|
||||
"request received: llm provider hint: {}, selected provider: {}",
|
||||
self.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
|
||||
.unwrap_or_default(),
|
||||
self.llm_provider.as_ref().unwrap().name,
|
||||
self.llm_provider
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.model
|
||||
.as_ref()
|
||||
.unwrap_or(&String::new())
|
||||
self.llm_provider.as_ref().unwrap().name
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -313,6 +307,11 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"on_http_request_body: deserialized body: {}",
|
||||
serde_json::to_string(&deserialized_body).unwrap_or_default()
|
||||
);
|
||||
|
||||
self.user_message = deserialized_body
|
||||
.messages
|
||||
.iter()
|
||||
|
|
@ -349,8 +348,8 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
info!(
|
||||
"on_http_request_body: provider: {}, model requested: {}, model selected: {}",
|
||||
self.llm_provider().name,
|
||||
"on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}",
|
||||
self.llm_provider().provider_interface,
|
||||
model_requested,
|
||||
model_name.unwrap_or(&"None".to_string()),
|
||||
);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue