fix rust tests

This commit is contained in:
Adil Hafeez 2025-05-14 17:15:42 -07:00
parent f60cac27f4
commit 0e2f53426a
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
7 changed files with 165 additions and 26 deletions

View file

@ -72,7 +72,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
//loading arch_config.yaml file
let arch_config_path =
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "arch_config.yaml".to_string());
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "./arch_config.yaml".to_string());
info!("Loading arch_config.yaml from {}", arch_config_path);
let config_contents =
@ -88,14 +88,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
shorten_string(&serde_json::to_string(arch_config.as_ref()).unwrap())
);
let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT")
.unwrap_or_else(|_| "http://localhost:12000/v1/chat/completions".to_string());
info!("llm provider endpoint: {}", llm_provider_endpoint);
info!("Listening on http://{}", bind_address);
let listener = TcpListener::bind(bind_address).await?;
let llm_provider_endpoint = "http://localhost:12000/v1/chat/completions";
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
arch_config.llm_providers.clone(),
llm_provider_endpoint.to_string(),
llm_provider_endpoint.clone(),
arch_config.routing.as_ref().unwrap().model.clone(),
));
@ -105,6 +107,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let io = TokioIo::new(stream);
let router_service = Arc::clone(&router_service);
let llm_provider_endpoint = llm_provider_endpoint.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
@ -115,11 +118,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.span_builder("router_service")
.with_kind(SpanKind::Server)
.start_with_context(tracer, &parent_cx);
let llm_provider_endpoint = llm_provider_endpoint.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, "/v1/chat/completions") => {
chat_completion(req, router_service, llm_provider_endpoint.to_string())
chat_completion(req, router_service, llm_provider_endpoint)
.with_context(parent_cx)
.await
}

View file

@ -327,10 +327,10 @@ impl HttpContext for StreamContext {
let model_requested = deserialized_body.model.clone();
info!(
"on_http_request_body: provider: {}, model requested: {}, model selected: {:?}",
"on_http_request_body: provider: {}, model requested: {}, model selected: {}",
self.llm_provider().name,
model_requested,
self.llm_provider().model,
self.llm_provider().model.as_ref().unwrap_or(&String::new())
);
deserialized_body.model = self.llm_provider().model.clone().unwrap();

View file

@ -489,7 +489,6 @@ fn llm_gateway_override_model_name() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_record("input_sequence_length", 29)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)