add more changes

This commit is contained in:
Adil Hafeez 2025-06-24 13:30:12 -07:00
parent 3f21e29703
commit 60b1bdca06
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
3 changed files with 141 additions and 4 deletions

View file

@ -7,10 +7,10 @@ use serde_json;
use std::sync::Arc;
pub async fn list_models(
llm_providers: Arc<Vec<LlmProvider>>,
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let prov = llm_providers.clone();
let providers = (*prov).clone();
let prov = llm_providers.read().await;
let providers = prov.clone();
let openai_models: Models = providers.into_models();
match serde_json::to_string(&openai_models) {

View file

@ -0,0 +1,136 @@
use bytes::Bytes;
use common::configuration::LlmProvider;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json;
use tracing::{info, warn};
use std::{collections::HashMap, sync::Arc};
#[derive(Serialize, Deserialize)]
struct UsageBasedProvider {
model: String,
usage: String,
}
pub async fn list_preferences(
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let prov = llm_providers.read().await;
let providers_with_usage = prov
.iter()
.filter(|provider| provider.usage.is_some())
.map(|provider| UsageBasedProvider {
model: provider.name.clone(),
usage: provider.usage.as_ref().unwrap().clone(),
})
.collect::<Vec<UsageBasedProvider>>();
match serde_json::to_string(&providers_with_usage) {
Ok(json) => {
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap()
}
Err(_) => {
let body = Full::new(Bytes::from_static(
b"{\"error\":\"Failed to serialize models\"}",
))
.map_err(|never| match never {})
.boxed();
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("Content-Type", "application/json")
.body(body)
.unwrap()
}
}
}
pub async fn update_preferences(
request: Request<hyper::body::Incoming>,
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
info!("Updating preferences...");
let request_body = request.collect().await?.to_bytes();
let usage: Vec<UsageBasedProvider> = match serde_json::from_slice(&request_body) {
Ok(usage) => usage,
Err(_) => {
let response_body = Full::new(Bytes::from_static(b"Invalid request body: "))
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap());
}
};
let usage_model_map: HashMap<String, UsageBasedProvider> =
usage.into_iter().map(|u| (u.model.clone(), u)).collect();
let mut llm_providers = llm_providers.write().await;
// ensure that models coming in the request are valid
let llm_provider_names: Vec<String> = llm_providers
.iter()
.map(|provider| provider.name.clone())
.collect();
for model in usage_model_map.keys() {
if !llm_provider_names.contains(model) {
let model_not_found = format!("model not found: {}", model);
warn!("updating preferences: {}", model_not_found);
let response_body = Full::new(model_not_found.into())
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap());
}
}
let mut updated_models_list = Vec::new();
for provider in llm_providers.iter_mut() {
if let Some(usage_provider) = usage_model_map.get(&provider.name) {
provider.usage = Some(usage_provider.usage.clone());
updated_models_list.push(UsageBasedProvider {
model: provider.name.clone(),
usage: provider.usage.clone().unwrap_or_default(),
});
}
}
if !updated_models_list.is_empty() {
// return list of updated models
let response_body = Full::new(Bytes::from(format!(
"{{\"updated_models\": {}}}",
serde_json::to_string(&updated_models_list).unwrap()
)))
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(response_body)
.unwrap());
} else {
let response_body = Full::new(Bytes::from_static(b"Provider not found"))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap())
}
}

View file

@ -17,6 +17,7 @@ use opentelemetry_http::HeaderExtractor;
use std::sync::Arc;
use std::{env, fs};
use tokio::net::TcpListener;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
pub mod router;
@ -54,7 +55,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let arch_config = Arc::new(config);
let llm_providers = Arc::new(arch_config.llm_providers.clone());
let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone()));
debug!(
"arch_config: {:?}",