diff --git a/crates/brightstaff/src/handlers/models.rs b/crates/brightstaff/src/handlers/models.rs index 5e4b55b2..3a4662a6 100644 --- a/crates/brightstaff/src/handlers/models.rs +++ b/crates/brightstaff/src/handlers/models.rs @@ -7,10 +7,10 @@ use serde_json; use std::sync::Arc; pub async fn list_models( - llm_providers: Arc>, + llm_providers: Arc>>, ) -> Response> { - let prov = llm_providers.clone(); - let providers = (*prov).clone(); + let prov = llm_providers.read().await; + let providers = prov.clone(); let openai_models: Models = providers.into_models(); match serde_json::to_string(&openai_models) { diff --git a/crates/brightstaff/src/handlers/preferences.rs b/crates/brightstaff/src/handlers/preferences.rs new file mode 100644 index 00000000..4a478ad9 --- /dev/null +++ b/crates/brightstaff/src/handlers/preferences.rs @@ -0,0 +1,136 @@ +use bytes::Bytes; +use common::configuration::LlmProvider; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use hyper::{Request, Response, StatusCode}; +use serde::{Deserialize, Serialize}; +use serde_json; +use tracing::{info, warn}; +use std::{collections::HashMap, sync::Arc}; + +#[derive(Serialize, Deserialize)] +struct UsageBasedProvider { + model: String, + usage: String, +} + +pub async fn list_preferences( + llm_providers: Arc>>, +) -> Response> { + let prov = llm_providers.read().await; + let providers_with_usage = prov + .iter() + .filter(|provider| provider.usage.is_some()) + .map(|provider| UsageBasedProvider { + model: provider.name.clone(), + usage: provider.usage.as_ref().unwrap().clone(), + }) + .collect::>(); + + match serde_json::to_string(&providers_with_usage) { + Ok(json) => { + let body = Full::new(Bytes::from(json)) + .map_err(|never| match never {}) + .boxed(); + Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(body) + .unwrap() + } + Err(_) => { + let body = Full::new(Bytes::from_static( + b"{\"error\":\"Failed to serialize models\"}", + )) + .map_err(|never| match never {}) + .boxed(); + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header("Content-Type", "application/json") + .body(body) + .unwrap() + } + } +} + +pub async fn update_preferences( + request: Request, + llm_providers: Arc>>, +) -> Result>, hyper::Error> { + info!("Updating preferences..."); + let request_body = request.collect().await?.to_bytes(); + + let usage: Vec = match serde_json::from_slice(&request_body) { + Ok(usage) => usage, + Err(_) => { + let response_body = Full::new(Bytes::from_static(b"Invalid request body: ")) + .map_err(|never| match never {}) + .boxed(); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header("Content-Type", "text/plain") + .body(response_body) + .unwrap()); + } + }; + + let usage_model_map: HashMap = + usage.into_iter().map(|u| (u.model.clone(), u)).collect(); + + let mut llm_providers = llm_providers.write().await; + + // ensure that models coming in the request are valid + let llm_provider_names: Vec = llm_providers + .iter() + .map(|provider| provider.name.clone()) + .collect(); + + for model in usage_model_map.keys() { + if !llm_provider_names.contains(model) { + let model_not_found = format!("model not found: {}", model); + warn!("updating preferences: {}", model_not_found); + let response_body = Full::new(model_not_found.into()) + .map_err(|never| match never {}) + .boxed(); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header("Content-Type", "text/plain") + .body(response_body) + .unwrap()); + } + } + + let mut updated_models_list = Vec::new(); + for provider in llm_providers.iter_mut() { + if let Some(usage_provider) = usage_model_map.get(&provider.name) { + provider.usage = Some(usage_provider.usage.clone()); + updated_models_list.push(UsageBasedProvider { + model: provider.name.clone(), + usage: provider.usage.clone().unwrap_or_default(), + }); + } + } + + if !updated_models_list.is_empty() { + // return list of updated models + let response_body = Full::new(Bytes::from(format!( + "{{\"updated_models\": {}}}", + serde_json::to_string(&updated_models_list).unwrap() + ))) + .map_err(|never| match never {}) + .boxed(); + return Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(response_body) + .unwrap()); + } else { + let response_body = Full::new(Bytes::from_static(b"Provider not found")) + .map_err(|never| match never {}) + .boxed(); + Ok(Response::builder() + .status(StatusCode::NOT_FOUND) + .header("Content-Type", "text/plain") + .body(response_body) + .unwrap()) + } +} diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 30bfa11a..65b47cd5 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -17,6 +17,7 @@ use opentelemetry_http::HeaderExtractor; use std::sync::Arc; use std::{env, fs}; use tokio::net::TcpListener; +use tokio::sync::RwLock; use tracing::{debug, info, warn}; pub mod router; @@ -54,7 +55,7 @@ async fn main() -> Result<(), Box> { let arch_config = Arc::new(config); - let llm_providers = Arc::new(arch_config.llm_providers.clone()); + let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone())); debug!( "arch_config: {:?}",