mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add more changes
This commit is contained in:
parent
3f21e29703
commit
60b1bdca06
3 changed files with 141 additions and 4 deletions
|
|
@ -7,10 +7,10 @@ use serde_json;
|
|||
use std::sync::Arc;
|
||||
|
||||
pub async fn list_models(
|
||||
llm_providers: Arc<Vec<LlmProvider>>,
|
||||
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let prov = llm_providers.clone();
|
||||
let providers = (*prov).clone();
|
||||
let prov = llm_providers.read().await;
|
||||
let providers = prov.clone();
|
||||
let openai_models: Models = providers.into_models();
|
||||
|
||||
match serde_json::to_string(&openai_models) {
|
||||
|
|
|
|||
136
crates/brightstaff/src/handlers/preferences.rs
Normal file
136
crates/brightstaff/src/handlers/preferences.rs
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::LlmProvider;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use tracing::{info, warn};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct UsageBasedProvider {
|
||||
model: String,
|
||||
usage: String,
|
||||
}
|
||||
|
||||
pub async fn list_preferences(
|
||||
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
|
||||
) -> Response<BoxBody<Bytes, hyper::Error>> {
|
||||
let prov = llm_providers.read().await;
|
||||
let providers_with_usage = prov
|
||||
.iter()
|
||||
.filter(|provider| provider.usage.is_some())
|
||||
.map(|provider| UsageBasedProvider {
|
||||
model: provider.name.clone(),
|
||||
usage: provider.usage.as_ref().unwrap().clone(),
|
||||
})
|
||||
.collect::<Vec<UsageBasedProvider>>();
|
||||
|
||||
match serde_json::to_string(&providers_with_usage) {
|
||||
Ok(json) => {
|
||||
let body = Full::new(Bytes::from(json))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
Err(_) => {
|
||||
let body = Full::new(Bytes::from_static(
|
||||
b"{\"error\":\"Failed to serialize models\"}",
|
||||
))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(body)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn update_preferences(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
info!("Updating preferences...");
|
||||
let request_body = request.collect().await?.to_bytes();
|
||||
|
||||
let usage: Vec<UsageBasedProvider> = match serde_json::from_slice(&request_body) {
|
||||
Ok(usage) => usage,
|
||||
Err(_) => {
|
||||
let response_body = Full::new(Bytes::from_static(b"Invalid request body: "))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.header("Content-Type", "text/plain")
|
||||
.body(response_body)
|
||||
.unwrap());
|
||||
}
|
||||
};
|
||||
|
||||
let usage_model_map: HashMap<String, UsageBasedProvider> =
|
||||
usage.into_iter().map(|u| (u.model.clone(), u)).collect();
|
||||
|
||||
let mut llm_providers = llm_providers.write().await;
|
||||
|
||||
// ensure that models coming in the request are valid
|
||||
let llm_provider_names: Vec<String> = llm_providers
|
||||
.iter()
|
||||
.map(|provider| provider.name.clone())
|
||||
.collect();
|
||||
|
||||
for model in usage_model_map.keys() {
|
||||
if !llm_provider_names.contains(model) {
|
||||
let model_not_found = format!("model not found: {}", model);
|
||||
warn!("updating preferences: {}", model_not_found);
|
||||
let response_body = Full::new(model_not_found.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.header("Content-Type", "text/plain")
|
||||
.body(response_body)
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
let mut updated_models_list = Vec::new();
|
||||
for provider in llm_providers.iter_mut() {
|
||||
if let Some(usage_provider) = usage_model_map.get(&provider.name) {
|
||||
provider.usage = Some(usage_provider.usage.clone());
|
||||
updated_models_list.push(UsageBasedProvider {
|
||||
model: provider.name.clone(),
|
||||
usage: provider.usage.clone().unwrap_or_default(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if !updated_models_list.is_empty() {
|
||||
// return list of updated models
|
||||
let response_body = Full::new(Bytes::from(format!(
|
||||
"{{\"updated_models\": {}}}",
|
||||
serde_json::to_string(&updated_models_list).unwrap()
|
||||
)))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
return Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(response_body)
|
||||
.unwrap());
|
||||
} else {
|
||||
let response_body = Full::new(Bytes::from_static(b"Provider not found"))
|
||||
.map_err(|never| match never {})
|
||||
.boxed();
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.header("Content-Type", "text/plain")
|
||||
.body(response_body)
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
|
@ -17,6 +17,7 @@ use opentelemetry_http::HeaderExtractor;
|
|||
use std::sync::Arc;
|
||||
use std::{env, fs};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
pub mod router;
|
||||
|
|
@ -54,7 +55,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let arch_config = Arc::new(config);
|
||||
|
||||
let llm_providers = Arc::new(arch_config.llm_providers.clone());
|
||||
let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone()));
|
||||
|
||||
debug!(
|
||||
"arch_config: {:?}",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue