add tests

This commit is contained in:
Adil Hafeez 2025-05-15 14:20:43 -07:00
parent 37a7eab3c0
commit 966588bfef
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 105 additions and 6 deletions

View file

@ -24,7 +24,7 @@ jobs:
- name: build arch docker image
run: |
cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.2.8
cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.2.8 -t katanemo/archgw:latest
- name: start archgw
env:

1
crates/Cargo.lock generated
View file

@ -205,6 +205,7 @@ dependencies = [
"opentelemetry-otlp",
"opentelemetry-stdout",
"opentelemetry_sdk",
"pretty_assertions",
"reqwest",
"serde",
"serde_json",

View file

@ -19,6 +19,7 @@ opentelemetry-http = "0.29.0"
opentelemetry-otlp = "0.29.0"
opentelemetry-stdout = "0.29.0"
opentelemetry_sdk = "0.29.0"
pretty_assertions = "1.4.1"
reqwest = { version = "0.12.15", features = ["stream"] }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"

View file

@ -1,6 +1,6 @@
use common::{
api::open_ai::{ChatCompletionsRequest, Message},
consts::USER_ROLE,
consts::{SYSTEM_ROLE, USER_ROLE},
};
use serde::{Deserialize, Serialize};
use tracing::info;
@ -63,12 +63,19 @@ struct LlmRouterResponse {
impl RouterModel for RouterModelV1 {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
let messages_str = messages
.iter()
.filter(|m| m.role != SYSTEM_ROLE)
.map(|m| {
let content_json_str = serde_json::to_string(&m.content).unwrap_or_default();
format!("{}: {}", m.role, content_json_str)
})
.collect::<Vec<String>>()
.join("\n");
let message = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &self.llm_providers_with_usage_yaml)
.replace(
"{conversation}",
&serde_json::to_string_pretty(messages).unwrap(),
);
.replace("{conversation}", messages_str.as_str());
ChatCompletionsRequest {
model: self.routing_model.clone(),
@ -138,3 +145,81 @@ impl std::fmt::Debug for dyn RouterModel {
write!(f, "RouterModel")
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn test_system_prompt_format() {
let expected_prompt = r#"
You are an advanced Routing Assistant designed to select the optimal route based on user requests.
Your task is to analyze conversations and match them to the most appropriate predefined route.
Review the available routes config:
# ROUTES CONFIG START
route1: description1
route2: description2
# ROUTES CONFIG END
Examine the following conversation between a user and an assistant:
# CONVERSATION START
user: "Hello, I want to book a flight."
assistant: "Sure, where would you like to go?"
user: "seattle"
# CONVERSATION END
Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps:
1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario.
2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description).
3. Find the route that best matches.
4. Use context clues from the entire conversation to determine the best fit.
5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config.
6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''.
# OUTPUT FORMAT
Your final output must follow this JSON format:
{
"route": "route_name" # The matched route name, or empty string '' if no match
}
Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace.
"#;
let routes_yaml = "route1: description1\nroute2: description2";
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone());
let messages = vec![
Message {
role: "system".to_string(),
content: Some("You are a helpful assistant.".to_string()),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some("Hello, I want to book a flight.".to_string()),
..Default::default()
},
Message {
role: "assistant".to_string(),
content: Some("Sure, where would you like to go?".to_string()),
..Default::default()
},
Message {
role: "user".to_string(),
content: Some("seattle".to_string()),
..Default::default()
},
];
let req = router.generate_request(&messages);
let prompt = req.messages[0].content.as_ref().unwrap();
assert_eq!(expected_prompt, prompt);
}
}

View file

@ -171,6 +171,18 @@ pub struct Message {
pub tool_call_id: Option<String>,
}
impl Default for Message {
fn default() -> Self {
Message {
role: ASSISTANT_ROLE.to_string(),
content: None,
model: None,
tool_calls: None,
tool_call_id: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Choice {
pub finish_reason: Option<String>,