mirror of https://github.com/katanemo/plano.git, synced 2026-05-09 15:52:44 +02:00
compact and deduplicate test suite
- Extract generate_storage_tests! macro for shared CRUD tests across memory/postgresql backends
- Move merge tests to mod.rs (testing default trait method once)
- Consolidate signal analyzer tests into table-driven tests
- Extract shared fixtures in router test files
- Parametrize Python CLI tests
- Remove dead tests (test_skip_version_check_env_var, test_arch_agent_config_default)
- Extract SSE event test helpers in streaming_response
This commit is contained in:
parent 785bf7e021
commit c4634d0034
10 changed files with 822 additions and 1490 deletions
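The headline refactor is the generate_storage_tests! macro, which stamps the shared CRUD tests out once per storage backend (memory and PostgreSQL). The diff below only shows the call site, so here is a minimal sketch of the pattern, reusing the create_test_state fixture and test names that appear in the hunks further down; the macro body itself is illustrative, not the committed implementation:

    // Hypothetical shape of generate_storage_tests!; the committed macro likely
    // generates more cases (exists, delete, overwrite, not-found) than shown here.
    macro_rules! generate_storage_tests {
        ($make_storage:expr) => {
            #[tokio::test]
            async fn test_put_and_get_success() {
                let storage = $make_storage;
                let state = create_test_state("resp_001");
                storage.put(state.clone()).await.unwrap();
                let retrieved = storage.get("resp_001").await.unwrap();
                assert_eq!(retrieved.response_id, state.response_id);
            }

            #[tokio::test]
            async fn test_get_not_found() {
                let storage = $make_storage;
                assert!(storage.get("nonexistent").await.is_err());
            }
        };
    }

Each backend's test module then invokes the macro with its own constructor, e.g. generate_storage_tests!(MemoryConversationalStorage::new()); as seen in the memory-storage hunk near the end of this diff, so one definition covers both implementations.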
@@ -1481,12 +1481,6 @@ mod tests {
         assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#));
     }

-    #[test]
-    fn test_arch_agent_config_default() {
-        let config = ArchAgentConfig::default();
-        assert_eq!(config.generation_params.temperature, 0.01); // Different from ArchFunctionConfig
-    }
-
     #[test]
     fn test_fix_json_string_valid() {
         let handler = ArchFunctionHandler::new(
@@ -415,6 +415,17 @@ mod tests {
     use super::*;
     use pretty_assertions::assert_eq;
+
+    fn default_orchestrations() -> HashMap<String, Vec<OrchestrationPreference>> {
+        serde_json::from_str(
+            r#"{"gpt-4o": [{"name": "Image generation", "description": "generating image"}]}"#,
+        )
+        .unwrap()
+    }
+
+    fn default_conversation() -> Vec<Message> {
+        serde_json::from_str(r#"[{"role": "user", "content": "hi"},{"role": "assistant", "content": "Hello! How can I assist you today?"},{"role": "user", "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]"#).unwrap()
+    }

     #[test]
     fn test_spaced_json_formatter() {
         // Test basic object
@@ -509,41 +520,12 @@ Return your answer strictly in JSON as follows:
 {{"route": ["route_name_1", "route_name_2", "..."]}}
 If no routes are needed, return an empty list for `route`.
 "#;
-        let orchestrations_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let agent_orchestrations = serde_json::from_str::<
-            HashMap<String, Vec<OrchestrationPreference>>,
-        >(orchestrations_str)
-        .unwrap();
-        let orchestration_model = "test-model".to_string();
         let orchestrator = OrchestratorModelV1::new(
-            agent_orchestrations,
-            orchestration_model.clone(),
+            default_orchestrations(),
+            "test-model".to_string(),
             usize::MAX,
         );

-        let conversation_str = r#"
-        [
-            {
-                "role": "user",
-                "content": "hi"
-            },
-            {
-                "role": "assistant",
-                "content": "Hello! How can I assist you today?"
-            },
-            {
-                "role": "user",
-                "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
-            }
-        ]
-        "#;
-        let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
+        let conversation = default_conversation();

         let req = orchestrator.generate_request(&conversation, &None);
@@ -591,31 +573,9 @@ Return your answer strictly in JSON as follows:
 If no routes are needed, return an empty list for `route`.
 "#;
-        // Empty orchestrations map - not used when usage_preferences are provided
-        let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
-        let orchestration_model = "test-model".to_string();
-        let orchestrator = OrchestratorModelV1::new(
-            agent_orchestrations,
-            orchestration_model.clone(),
-            usize::MAX,
-        );
-
-        let conversation_str = r#"
-        [
-            {
-                "role": "user",
-                "content": "hi"
-            },
-            {
-                "role": "assistant",
-                "content": "Hello! How can I assist you today?"
-            },
-            {
-                "role": "user",
-                "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
-            }
-        ]
-        "#;
-        let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
+        let orchestrator =
+            OrchestratorModelV1::new(HashMap::new(), "test-model".to_string(), usize::MAX);
+        let conversation = default_conversation();

         let usage_preferences = Some(vec![AgentUsagePreference {
             model: "claude/claude-3-7-sonnet".to_string(),
@@ -662,38 +622,9 @@ Return your answer strictly in JSON as follows:
 If no routes are needed, return an empty list for `route`.
 "#;

-        let orchestrations_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let agent_orchestrations = serde_json::from_str::<
-            HashMap<String, Vec<OrchestrationPreference>>,
-        >(orchestrations_str)
-        .unwrap();
-        let orchestration_model = "test-model".to_string();
-        let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model, 235);
-
-        let conversation_str = r#"
-        [
-            {
-                "role": "user",
-                "content": "hi"
-            },
-            {
-                "role": "assistant",
-                "content": "Hello! How can I assist you today?"
-            },
-            {
-                "role": "user",
-                "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
-            }
-        ]
-        "#;
-
-        let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
+        let orchestrator =
+            OrchestratorModelV1::new(default_orchestrations(), "test-model".to_string(), 235);
+        let conversation = default_conversation();

         let req = orchestrator.generate_request(&conversation, &None);
@@ -733,20 +664,8 @@ Return your answer strictly in JSON as follows:
 If no routes are needed, return an empty list for `route`.
 "#;

-        let orchestrations_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let agent_orchestrations = serde_json::from_str::<
-            HashMap<String, Vec<OrchestrationPreference>>,
-        >(orchestrations_str)
-        .unwrap();
-
-        let orchestration_model = "test-model".to_string();
-        let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model, 200);
+        let orchestrator =
+            OrchestratorModelV1::new(default_orchestrations(), "test-model".to_string(), 200);

         let conversation_str = r#"
         [
@@ -813,19 +732,8 @@ Return your answer strictly in JSON as follows:
 If no routes are needed, return an empty list for `route`.
 "#;

-        let orchestrations_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let agent_orchestrations = serde_json::from_str::<
-            HashMap<String, Vec<OrchestrationPreference>>,
-        >(orchestrations_str)
-        .unwrap();
-        let orchestration_model = "test-model".to_string();
-        let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model, 230);
+        let orchestrator =
+            OrchestratorModelV1::new(default_orchestrations(), "test-model".to_string(), 230);

         let conversation_str = r#"
         [
@@ -899,21 +807,9 @@ Return your answer strictly in JSON as follows:
 {{"route": ["route_name_1", "route_name_2", "..."]}}
 If no routes are needed, return an empty list for `route`.
 "#;
-        let orchestrations_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let agent_orchestrations = serde_json::from_str::<
-            HashMap<String, Vec<OrchestrationPreference>>,
-        >(orchestrations_str)
-        .unwrap();
-        let orchestration_model = "test-model".to_string();
         let orchestrator = OrchestratorModelV1::new(
-            agent_orchestrations,
-            orchestration_model.clone(),
+            default_orchestrations(),
+            "test-model".to_string(),
             usize::MAX,
         );

@@ -991,21 +887,9 @@ Return your answer strictly in JSON as follows:
 {{"route": ["route_name_1", "route_name_2", "..."]}}
 If no routes are needed, return an empty list for `route`.
 "#;
-        let orchestrations_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let agent_orchestrations = serde_json::from_str::<
-            HashMap<String, Vec<OrchestrationPreference>>,
-        >(orchestrations_str)
-        .unwrap();
-        let orchestration_model = "test-model".to_string();
         let orchestrator = OrchestratorModelV1::new(
-            agent_orchestrations,
-            orchestration_model.clone(),
+            default_orchestrations(),
+            "test-model".to_string(),
             usize::MAX,
         );

@@ -299,6 +299,17 @@ mod tests {
     use super::*;
     use pretty_assertions::assert_eq;
+
+    fn default_routes() -> HashMap<String, Vec<RoutingPreference>> {
+        serde_json::from_str(
+            r#"{"gpt-4o": [{"name": "Image generation", "description": "generating image"}]}"#,
+        )
+        .unwrap()
+    }
+
+    fn default_conversation() -> Vec<Message> {
+        serde_json::from_str(r#"[{"role": "user", "content": "hi"},{"role": "assistant", "content": "Hello! How can I assist you today?"},{"role": "user", "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]"#).unwrap()
+    }

     #[test]
     fn test_system_prompt_format() {
         let expected_prompt = r#"
@@ -320,35 +331,8 @@ Your task is to decide which route is best suit with user intent on the conversa
 Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
 {"route": "route_name"}
 "#;
-        let routes_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let llm_routes =
-            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
-        let routing_model = "test-model".to_string();
-        let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
-
-        let conversation_str = r#"
-        [
-            {
-                "role": "user",
-                "content": "hi"
-            },
-            {
-                "role": "assistant",
-                "content": "Hello! How can I assist you today?"
-            },
-            {
-                "role": "user",
-                "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
-            }
-        ]
-        "#;
-        let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
+        let router = RouterModelV1::new(default_routes(), "test-model".to_string(), usize::MAX);
+        let conversation = default_conversation();

         let req = router.generate_request(&conversation, &None);
@@ -378,35 +362,8 @@ Your task is to decide which route is best suit with user intent on the conversa
 Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
 {"route": "route_name"}
 "#;
-        let routes_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let llm_routes =
-            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
-        let routing_model = "test-model".to_string();
-        let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
-
-        let conversation_str = r#"
-        [
-            {
-                "role": "user",
-                "content": "hi"
-            },
-            {
-                "role": "assistant",
-                "content": "Hello! How can I assist you today?"
-            },
-            {
-                "role": "user",
-                "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
-            }
-        ]
-        "#;
-        let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
+        let router = RouterModelV1::new(default_routes(), "test-model".to_string(), usize::MAX);
+        let conversation = default_conversation();

         let usage_preferences = Some(vec![ModelUsagePreference {
             model: "claude/claude-3-7-sonnet".to_string(),
@@ -444,36 +401,8 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"route": "route_name"}
 "#;

-        let routes_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let llm_routes =
-            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
-        let routing_model = "test-model".to_string();
-        let router = RouterModelV1::new(llm_routes, routing_model, 235);
-
-        let conversation_str = r#"
-        [
-            {
-                "role": "user",
-                "content": "hi"
-            },
-            {
-                "role": "assistant",
-                "content": "Hello! How can I assist you today?"
-            },
-            {
-                "role": "user",
-                "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
-            }
-        ]
-        "#;
-
-        let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
+        let router = RouterModelV1::new(default_routes(), "test-model".to_string(), 235);
+        let conversation = default_conversation();

         let req = router.generate_request(&conversation, &None);
@@ -504,18 +433,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"route": "route_name"}
 "#;

-        let routes_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let llm_routes =
-            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
-
-        let routing_model = "test-model".to_string();
-        let router = RouterModelV1::new(llm_routes, routing_model, 200);
+        let router = RouterModelV1::new(default_routes(), "test-model".to_string(), 200);

         let conversation_str = r#"
         [
@@ -565,17 +483,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"route": "route_name"}
 "#;

-        let routes_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let llm_routes =
-            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
-        let routing_model = "test-model".to_string();
-        let router = RouterModelV1::new(llm_routes, routing_model, 230);
+        let router = RouterModelV1::new(default_routes(), "test-model".to_string(), 230);

         let conversation_str = r#"
         [
@@ -632,17 +540,7 @@ Your task is to decide which route is best suit with user intent on the conversa
 Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
 {"route": "route_name"}
 "#;
-        let routes_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let llm_routes =
-            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
-        let routing_model = "test-model".to_string();
-        let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
+        let router = RouterModelV1::new(default_routes(), "test-model".to_string(), usize::MAX);

         let conversation_str = r#"
         [
@@ -701,17 +599,7 @@ Your task is to decide which route is best suit with user intent on the conversa
 Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
 {"route": "route_name"}
 "#;
-        let routes_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let llm_routes =
-            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
-        let routing_model = "test-model".to_string();
-        let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX);
+        let router = RouterModelV1::new(default_routes(), "test-model".to_string(), usize::MAX);

         let conversation_str = r#"
         [
@@ -777,17 +665,7 @@ Based on your analysis, provide your response in the following JSON formats if y

     #[test]
     fn test_parse_response() {
-        let routes_str = r#"
-        {
-            "gpt-4o": [
-                {"name": "Image generation", "description": "generating image"}
-            ]
-        }
-        "#;
-        let llm_routes =
-            serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
-
-        let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
+        let router = RouterModelV1::new(default_routes(), "test-model".to_string(), 2000);

         // Case 1: Valid JSON with non-empty route
         let input = r#"{"route": "Image generation"}"#;
@@ -1959,145 +1959,109 @@ mod tests {
     // ========================================================================

     #[test]
-    fn test_char_ngram_similarity_exact_match() {
-        let msg = NormalizedMessage::from_text("thank you very much");
-        let similarity = msg.char_ngram_similarity("thank you very much");
-        assert!(
-            similarity > 0.95,
-            "Exact match should have very high similarity"
-        );
+    fn test_char_ngram_similarity() {
+        let cases = [
+            (
+                "thank you very much",
+                "thank you very much",
+                0.95,
+                "exact match",
+            ),
+            ("thank you very much", "thnks you very much", 0.50, "typo"),
+            ("this doesn't work", "this doesnt work", 0.70, "small edit"),
+            (
+                "i don't understand",
+                "i really don't understand",
+                0.40,
+                "word insertion",
+            ),
+        ];
+        for (msg_text, pattern, threshold, label) in cases {
+            let msg = NormalizedMessage::from_text(msg_text);
+            let similarity = msg.char_ngram_similarity(pattern);
+            assert!(
+                similarity > threshold,
+                "{}: expected > {}, got {}",
+                label,
+                threshold,
+                similarity
+            );
+        }
     }

     #[test]
-    fn test_char_ngram_similarity_typo() {
-        let msg = NormalizedMessage::from_text("thank you very much");
-        // Common typo: "thnks" instead of "thanks"
-        let similarity = msg.char_ngram_similarity("thnks you very much");
-        assert!(
-            similarity > 0.50,
-            "Should handle single-character typo with decent similarity: {}",
-            similarity
-        );
+    fn test_token_cosine_similarity() {
+        let cases: Vec<(&str, &str, f64, f64, &str)> = vec![
+            (
+                "this is not helpful",
+                "this is not helpful",
+                0.99,
+                1.01,
+                "exact match",
+            ),
+            (
+                "not helpful at all",
+                "helpful not at all",
+                0.95,
+                2.0,
+                "word order",
+            ),
+            (
+                "help help help please",
+                "help please",
+                0.7,
+                1.0,
+                "frequency",
+            ),
+            (
+                "I've been trying to set up my account for the past hour \
+                and the verification email never arrived. I checked my spam folder \
+                and still nothing. This is really frustrating and not helpful at all.",
+                "not helpful",
+                0.15,
+                0.7,
+                "long message with context",
+            ),
+        ];
+        for (msg_text, pattern, min, max, label) in cases {
+            let msg = NormalizedMessage::from_text(msg_text);
+            let similarity = msg.token_cosine_similarity(pattern);
+            assert!(
+                similarity > min && similarity < max,
+                "{}: expected ({}, {}), got {}",
+                label,
+                min,
+                max,
+                similarity
+            );
+        }
     }

     #[test]
-    fn test_char_ngram_similarity_small_edit() {
-        let msg = NormalizedMessage::from_text("this doesn't work");
-        let similarity = msg.char_ngram_similarity("this doesnt work");
-        assert!(
-            similarity > 0.70,
-            "Should handle punctuation removal gracefully: {}",
-            similarity
-        );
-    }
-
-    #[test]
-    fn test_char_ngram_similarity_word_insertion() {
-        let msg = NormalizedMessage::from_text("i don't understand");
-        let similarity = msg.char_ngram_similarity("i really don't understand");
-        assert!(
-            similarity > 0.40,
-            "Should be robust to word insertions: {}",
-            similarity
-        );
-    }
-
-    #[test]
-    fn test_token_cosine_similarity_exact_match() {
-        let msg = NormalizedMessage::from_text("this is not helpful");
-        let similarity = msg.token_cosine_similarity("this is not helpful");
-        assert!(
-            (similarity - 1.0).abs() < 0.01,
-            "Exact match should have cosine similarity of 1.0"
-        );
-    }
-
-    #[test]
-    fn test_token_cosine_similarity_word_order() {
-        let msg = NormalizedMessage::from_text("not helpful at all");
-        let similarity = msg.token_cosine_similarity("helpful not at all");
-        assert!(
-            similarity > 0.95,
-            "Should be robust to word order changes: {}",
-            similarity
-        );
-    }
-
-    #[test]
-    fn test_token_cosine_similarity_frequency() {
-        let msg = NormalizedMessage::from_text("help help help please");
-        let similarity = msg.token_cosine_similarity("help please");
-        assert!(
-            similarity > 0.7 && similarity < 1.0,
-            "Should account for frequency differences: {}",
-            similarity
-        );
-    }
-
-    #[test]
-    fn test_token_cosine_similarity_long_message_with_context() {
-        let msg = NormalizedMessage::from_text(
-            "I've been trying to set up my account for the past hour \
-            and the verification email never arrived. I checked my spam folder \
-            and still nothing. This is really frustrating and not helpful at all.",
-        );
-        let similarity = msg.token_cosine_similarity("not helpful");
-        assert!(
-            similarity > 0.15 && similarity < 0.7,
-            "Should detect pattern in long message with lower but non-zero similarity: {}",
-            similarity
-        );
-    }
-
-    #[test]
-    fn test_layered_matching_exact_hit() {
-        let msg = NormalizedMessage::from_text("thank you so much");
-        assert!(
-            msg.layered_contains_phrase("thank you", 0.50, 0.60),
-            "Should match exact phrase in Layer 0"
-        );
-    }
-
-    #[test]
-    fn test_layered_matching_typo_hit() {
-        // Test that shows layered matching is more robust than exact matching alone
-        let msg = NormalizedMessage::from_text("it doesnt work for me");
-
-        // "doesnt work" should match "doesn't work" via character ngrams (high overlap)
-        assert!(
-            msg.layered_contains_phrase("doesn't work", 0.50, 0.60),
-            "Should match 'doesnt work' to 'doesn't work' via character ngrams"
-        );
-    }
-
-    #[test]
-    fn test_layered_matching_word_order_hit() {
-        let msg = NormalizedMessage::from_text("helpful not very");
-        assert!(
-            msg.layered_contains_phrase("not helpful", 0.50, 0.60),
-            "Should match reordered words via token cosine in Layer 2"
-        );
-    }
-
-    #[test]
-    fn test_layered_matching_long_message_with_pattern() {
-        let msg = NormalizedMessage::from_text(
-            "I've tried everything and followed all the instructions \
-            but this is not helpful at all and I'm getting frustrated",
-        );
-        assert!(
-            msg.layered_contains_phrase("not helpful", 0.50, 0.60),
-            "Should detect pattern buried in long message"
-        );
-    }
-
-    #[test]
-    fn test_layered_matching_no_match() {
-        let msg = NormalizedMessage::from_text("everything is working perfectly");
-        assert!(
-            !msg.layered_contains_phrase("not helpful", 0.50, 0.60),
-            "Should not match completely different content"
-        );
+    fn test_layered_matching() {
+        let cases = [
+            ("thank you so much", "thank you", true, "exact hit"),
+            ("it doesnt work for me", "doesn't work", true, "typo hit"),
+            ("helpful not very", "not helpful", true, "word order hit"),
+            (
+                "I've tried everything and followed all the instructions \
+                but this is not helpful at all and I'm getting frustrated",
+                "not helpful",
+                true,
+                "long message with pattern",
+            ),
+            (
+                "everything is working perfectly",
+                "not helpful",
+                false,
+                "no match",
+            ),
+        ];
+        for (msg_text, pattern, expected, label) in cases {
+            let msg = NormalizedMessage::from_text(msg_text);
+            let result = msg.layered_contains_phrase(pattern, 0.50, 0.60);
+            assert_eq!(result, expected, "{}: expected {}", label, expected);
+        }
     }

     #[test]
@@ -2139,7 +2103,6 @@ mod tests {

     #[test]
     fn test_turn_count_efficient() {
-        let start = Instant::now();
         let analyzer = TextBasedSignalAnalyzer::new();
         let messages = vec![
             create_message(Role::User, "Hello"),
@@ -2154,12 +2117,10 @@ mod tests {
         assert!(!signal.is_concerning);
         assert!(!signal.is_excessive);
         assert!(signal.efficiency_score > 0.9);
-        println!("test_turn_count_efficient took: {:?}", start.elapsed());
     }

     #[test]
     fn test_turn_count_excessive() {
-        let start = Instant::now();
         let analyzer = TextBasedSignalAnalyzer::new();
         let mut messages = Vec::new();
         for i in 0..15 {
@@ -2178,12 +2139,10 @@ mod tests {
         assert!(signal.is_concerning);
         assert!(signal.is_excessive);
         assert!(signal.efficiency_score < 0.5);
-        println!("test_turn_count_excessive took: {:?}", start.elapsed());
     }

     #[test]
     fn test_follow_up_detection() {
-        let start = Instant::now();
         let analyzer = TextBasedSignalAnalyzer::new();
         let messages = vec![
             create_message(Role::User, "Show me restaurants"),
@@ -2196,12 +2155,10 @@ mod tests {
         let signal = analyzer.analyze_follow_up(&normalized_messages);
         assert_eq!(signal.repair_count, 1);
         assert!(signal.repair_ratio > 0.0);
-        println!("test_follow_up_detection took: {:?}", start.elapsed());
     }

     #[test]
     fn test_frustration_detection() {
-        let start = Instant::now();
         let analyzer = TextBasedSignalAnalyzer::new();
         let messages = vec![
             create_message(Role::User, "THIS IS RIDICULOUS!!!"),
@@ -2214,12 +2171,10 @@ mod tests {
         assert!(signal.has_frustration);
         assert!(signal.frustration_count >= 2);
         assert!(signal.severity > 0);
-        println!("test_frustration_detection took: {:?}", start.elapsed());
     }

     #[test]
     fn test_positive_feedback_detection() {
-        let start = Instant::now();
         let analyzer = TextBasedSignalAnalyzer::new();
         let messages = vec![
             create_message(Role::User, "Can you help me?"),
@@ -2232,15 +2187,10 @@ mod tests {
         assert!(signal.has_positive_feedback);
         assert!(signal.positive_count >= 1);
         assert!(signal.confidence > 0.5);
-        println!(
-            "test_positive_feedback_detection took: {:?}",
-            start.elapsed()
-        );
     }

     #[test]
     fn test_escalation_detection() {
-        let start = Instant::now();
         let analyzer = TextBasedSignalAnalyzer::new();
         let messages = vec![
             create_message(Role::User, "This isn't working"),
@@ -2252,12 +2202,10 @@ mod tests {
         let signal = analyzer.analyze_escalation(&normalized_messages);
         assert!(signal.escalation_requested);
         assert_eq!(signal.escalation_count, 1);
-        println!("test_escalation_detection took: {:?}", start.elapsed());
     }

     #[test]
     fn test_repetition_detection() {
-        let start = Instant::now();
         let analyzer = TextBasedSignalAnalyzer::new();
         let messages = vec![
             create_message(Role::User, "What's the weather?"),
@@ -2273,22 +2221,13 @@ mod tests {
         let normalized_messages = preprocess_messages(&messages);
         let signal = analyzer.analyze_repetition(&normalized_messages);
-
-        for rep in &signal.repetitions {
-            println!(
-                " - Messages {:?}, similarity: {:.3}, type: {:?}",
-                rep.message_indices, rep.similarity, rep.repetition_type
-            );
-        }

         assert!(signal.repetition_count > 0,
             "Should detect the subtle repetition between 'I can help you with the weather information' \
             and 'Sure, I can help you with the forecast'");
-        println!("test_repetition_detection took: {:?}", start.elapsed());
     }

     #[test]
     fn test_full_analysis_excellent() {
-        let start = Instant::now();
         let analyzer = TextBasedSignalAnalyzer::new();
         let messages = vec![
             create_message(Role::User, "I need to book a flight"),
@@ -2305,12 +2244,10 @@ mod tests {
         ));
         assert!(report.positive_feedback.has_positive_feedback);
         assert!(!report.frustration.has_frustration);
-        println!("test_full_analysis_excellent took: {:?}", start.elapsed());
     }

     #[test]
     fn test_full_analysis_poor() {
-        let start = Instant::now();
         let analyzer = TextBasedSignalAnalyzer::new();
         let messages = vec![
             create_message(Role::User, "Help me"),
@@ -2329,86 +2266,64 @@ mod tests {
         ));
         assert!(report.frustration.has_frustration);
         assert!(report.escalation.escalation_requested);
-        println!("test_full_analysis_poor took: {:?}", start.elapsed());
     }

     #[test]
-    fn test_fuzzy_matching_gratitude() {
-        let start = Instant::now();
+    fn test_fuzzy_matching() {
         let analyzer = TextBasedSignalAnalyzer::new();
+
+        // Gratitude with typo
         let messages = vec![
             create_message(Role::User, "Can you help me?"),
             create_message(Role::Assistant, "Sure!"),
             create_message(Role::User, "thnaks! that's exactly what i needed."),
         ];
-
-        let normalized_messages = preprocess_messages(&messages);
-        let signal = analyzer.analyze_positive_feedback(&normalized_messages);
-        assert!(signal.has_positive_feedback);
+        let normalized = preprocess_messages(&messages);
+        let signal = analyzer.analyze_positive_feedback(&normalized);
+        assert!(
+            signal.has_positive_feedback,
+            "fuzzy gratitude should be detected"
+        );
         assert!(signal.positive_count >= 1);
-        println!("test_fuzzy_matching_gratitude took: {:?}", start.elapsed());
-    }
-
-    #[test]
-    fn test_fuzzy_matching_escalation() {
-        let start = Instant::now();
-        let analyzer = TextBasedSignalAnalyzer::new();
+
+        // Escalation with typo
         let messages = vec![
             create_message(Role::User, "This isn't working"),
             create_message(Role::Assistant, "Let me help"),
             create_message(Role::User, "i need to speek to a human agnet"),
         ];
-
-        let normalized_messages = preprocess_messages(&messages);
-        let signal = analyzer.analyze_escalation(&normalized_messages);
-        assert!(signal.escalation_requested);
+        let normalized = preprocess_messages(&messages);
+        let signal = analyzer.analyze_escalation(&normalized);
+        assert!(
+            signal.escalation_requested,
+            "fuzzy escalation should be detected"
+        );
         assert_eq!(signal.escalation_count, 1);
-        println!("test_fuzzy_matching_escalation took: {:?}", start.elapsed());
-    }
-
-    #[test]
-    fn test_fuzzy_matching_repair() {
-        let start = Instant::now();
-        let analyzer = TextBasedSignalAnalyzer::new();
+
+        // Repair with typo
         let messages = vec![
             create_message(Role::User, "Show me restaurants"),
             create_message(Role::Assistant, "Here are some options"),
             create_message(Role::User, "no i ment Italian restaurants"),
             create_message(Role::Assistant, "Here are Italian restaurants"),
         ];
+        let normalized = preprocess_messages(&messages);
+        let signal = analyzer.analyze_follow_up(&normalized);
+        assert!(signal.repair_count >= 1, "fuzzy repair should be detected");
-
-        let normalized_messages = preprocess_messages(&messages);
-        let signal = analyzer.analyze_follow_up(&normalized_messages);
-        assert!(signal.repair_count >= 1);
-        println!("test_fuzzy_matching_repair took: {:?}", start.elapsed());
-    }
-
-    #[test]
-    fn test_fuzzy_matching_complaint() {
-        let start = Instant::now();
-        let analyzer = TextBasedSignalAnalyzer::new();
-        // Use a complaint that should match - "doesnt work" is close enough to "doesn't work"
+
+        // Complaint with typo
         let messages = vec![
-            create_message(Role::User, "this doesnt work at all"), // Common typo: missing apostrophe
+            create_message(Role::User, "this doesnt work at all"),
             create_message(Role::Assistant, "I apologize"),
         ];
-
-        let normalized_messages = preprocess_messages(&messages);
-        let signal = analyzer.analyze_frustration(&normalized_messages);
-
-        // The layered matching should catch this via character ngrams or token cosine
-        // "doesnt work" has high character-level similarity to "doesn't work"
-        assert!(
-            signal.has_frustration,
-            "Should detect frustration from complaint pattern"
-        );
+        let normalized = preprocess_messages(&messages);
+        let signal = analyzer.analyze_frustration(&normalized);
+        assert!(signal.has_frustration, "fuzzy complaint should be detected");
         assert!(signal.frustration_count >= 1);
-        println!("test_fuzzy_matching_complaint took: {:?}", start.elapsed());
     }

     #[test]
     fn test_exact_match_priority() {
-        let start = Instant::now();
         let analyzer = TextBasedSignalAnalyzer::new();
         let messages = vec![create_message(Role::User, "thank you so much")];

@@ -2418,7 +2333,6 @@ mod tests {
         // Should detect exact match, not fuzzy
         assert!(signal.indicators[0].snippet.contains("thank you"));
         assert!(!signal.indicators[0].snippet.contains("fuzzy"));
-        println!("test_exact_match_priority took: {:?}", start.elapsed());
     }

     // ========================================================================
@@ -2426,31 +2340,54 @@ mod tests {
     // ========================================================================

     #[test]
-    fn test_hello_not_profanity() {
+    fn test_false_positive_guards() {
         let analyzer = TextBasedSignalAnalyzer::new();
-        let messages = vec![create_message(Role::User, "hello there")];
-
-        let normalized_messages = preprocess_messages(&messages);
-        let signal = analyzer.analyze_frustration(&normalized_messages);
+        // "hello" should not trigger profanity
+        let messages = vec![create_message(Role::User, "hello there")];
+        let normalized = preprocess_messages(&messages);
+        let signal = analyzer.analyze_frustration(&normalized);
         assert!(
             !signal.has_frustration,
-            "\"hello\" should not trigger profanity detection"
+            "\"hello\" should not trigger profanity"
         );
-    }
-
-    #[test]
-    fn test_prepare_not_escalation() {
-        let analyzer = TextBasedSignalAnalyzer::new();
+
+        // "prepare" should not trigger escalation
         let messages = vec![create_message(
             Role::User,
             "Can you help me prepare for the meeting?",
         )];
-
-        let normalized_messages = preprocess_messages(&messages);
-        let signal = analyzer.analyze_escalation(&normalized_messages);
+        let normalized = preprocess_messages(&messages);
+        let signal = analyzer.analyze_escalation(&normalized);
         assert!(
             !signal.escalation_requested,
-            "\"prepare\" should not trigger escalation (rep pattern removed)"
+            "\"prepare\" should not trigger escalation"
         );
+
+        // "absolute" should not trigger 'bs' match
+        let messages = vec![create_message(Role::User, "That's absolute nonsense")];
+        let normalized = preprocess_messages(&messages);
+        let signal = analyzer.analyze_frustration(&normalized);
+        let has_bs_match = signal
+            .indicators
+            .iter()
+            .any(|ind| ind.snippet.contains("bs"));
+        assert!(
+            !has_bs_match,
+            "\"absolute\" should not trigger 'bs' profanity match"
+        );
+
+        // Stopwords-only overlap should not be rephrase
+        let messages = vec![
+            create_message(Role::User, "Help me with X"),
+            create_message(Role::Assistant, "Sure"),
+            create_message(Role::User, "Help me with Y"),
+        ];
+        let normalized = preprocess_messages(&messages);
+        let signal = analyzer.analyze_follow_up(&normalized);
+        assert_eq!(
+            signal.repair_count, 0,
+            "Messages with only stopword overlap should not be rephrases"
+        );
     }

@@ -2485,42 +2422,6 @@ mod tests {
         );
     }

-    #[test]
-    fn test_absolute_not_profanity() {
-        let analyzer = TextBasedSignalAnalyzer::new();
-        let messages = vec![create_message(Role::User, "That's absolute nonsense")];
-
-        let normalized_messages = preprocess_messages(&messages);
-        let signal = analyzer.analyze_frustration(&normalized_messages);
-        // Should match on "nonsense" logic, not on "bs" substring
-        let has_bs_match = signal
-            .indicators
-            .iter()
-            .any(|ind| ind.snippet.contains("bs"));
-        assert!(
-            !has_bs_match,
-            "\"absolute\" should not trigger 'bs' profanity match"
-        );
-    }
-
-    #[test]
-    fn test_stopwords_not_rephrase() {
-        let analyzer = TextBasedSignalAnalyzer::new();
-        let messages = vec![
-            create_message(Role::User, "Help me with X"),
-            create_message(Role::Assistant, "Sure"),
-            create_message(Role::User, "Help me with Y"),
-        ];
-
-        let normalized_messages = preprocess_messages(&messages);
-        let signal = analyzer.analyze_follow_up(&normalized_messages);
-        // Should not detect as rephrase since only stopwords overlap
-        assert_eq!(
-            signal.repair_count, 0,
-            "Messages with only stopword overlap should not be rephrases"
-        );
-    }
-
     #[test]
     fn test_frustrated_user_with_legitimate_repair() {
         let start = Instant::now();
@@ -2794,23 +2695,44 @@ mod tests {

     // false negative tests
     #[test]
-    fn test_dissatisfaction_polite_not_working_for_me() {
+    fn test_dissatisfaction_and_low_mood() {
         let analyzer = TextBasedSignalAnalyzer::new();
-        let messages = vec![
-            create_message(Role::User, "Thanks, but this still isn't working for me."), // Polite dissatisfaction, e.g., I appreciate it, but this isn't what I was looking for.
-            create_message(Role::Assistant, "Sorry—what error do you see?"),
-        ];
-        let normalized = preprocess_messages(&messages);
-        let signal = analyzer.analyze_frustration(&normalized);
-        assert!(
-            signal.has_frustration,
-            "Polite dissatisfaction should be detected"
-        );
-    }
-
-    #[test]
-    fn test_dissatisfaction_giving_up_without_escalation() {
-        let analyzer = TextBasedSignalAnalyzer::new();
+        // Cases that should trigger frustration
+        let frustration_cases = [
+            (
+                "Thanks, but this still isn't working for me.",
+                "polite not working",
+            ),
+            (
+                "I'm running into the same issue again.",
+                "same problem again",
+            ),
+            ("This feels incomplete.", "incomplete"),
+            (
+                "This is overwhelming and I'm not sure what to do.",
+                "overwhelming",
+            ),
+            (
+                "I'm exhausted trying to get this working.",
+                "exhausted trying",
+            ),
+        ];
+        for (msg, label) in frustration_cases {
+            let messages = vec![
+                create_message(Role::User, msg),
+                create_message(Role::Assistant, "Sorry about that."),
+            ];
+            let normalized = preprocess_messages(&messages);
+            let signal = analyzer.analyze_frustration(&normalized);
+            assert!(
+                signal.has_frustration,
+                "{}: should detect frustration",
+                label
+            );
+        }
+
+        // Case that should trigger escalation (giving up)
         let messages = vec![create_message(
             Role::User,
             "Never mind, I'll figure it out myself.",
@@ -2819,61 +2741,7 @@ mod tests {
         let signal = analyzer.analyze_escalation(&normalized);
         assert!(
             signal.escalation_requested,
-            "Giving up should count as escalation/quit intent"
-        );
-    }
-
-    #[test]
-    fn test_dissatisfaction_same_problem_again() {
-        let analyzer = TextBasedSignalAnalyzer::new();
-        let messages = vec![create_message(
-            Role::User,
-            "I'm running into the same issue again.",
-        )];
-        let normalized = preprocess_messages(&messages);
-        let signal = analyzer.analyze_frustration(&normalized);
-        assert!(
-            signal.has_frustration,
-            "'same issue again' should be detected"
-        );
-    }
-
-    #[test]
-    fn test_unsatisfied_incomplete() {
-        let analyzer = TextBasedSignalAnalyzer::new();
-        let messages = vec![create_message(Role::User, "This feels incomplete.")];
-        let normalized = preprocess_messages(&messages);
-        let signal = analyzer.analyze_frustration(&normalized);
-        assert!(
-            signal.has_frustration,
-            "Should detect 'incomplete' dissatisfaction"
-        );
-    }
-
-    #[test]
-    fn test_low_mood_overwhelming() {
-        let analyzer = TextBasedSignalAnalyzer::new();
-        let messages = vec![create_message(
-            Role::User,
-            "This is overwhelming and I'm not sure what to do.",
-        )];
-        let normalized = preprocess_messages(&messages);
-        let signal = analyzer.analyze_frustration(&normalized);
-        assert!(signal.has_frustration, "Should detect overwhelmed language");
-    }
-
-    #[test]
-    fn test_low_mood_exhausted_trying() {
-        let analyzer = TextBasedSignalAnalyzer::new();
-        let messages = vec![create_message(
-            Role::User,
-            "I'm exhausted trying to get this working.",
-        )];
-        let normalized = preprocess_messages(&messages);
-        let signal = analyzer.analyze_frustration(&normalized);
-        assert!(
-            signal.has_frustration,
-            "Should detect exhaustion/struggle language"
+            "giving up should count as escalation"
         );
     }

@@ -85,11 +85,30 @@ impl StateStorage for MemoryConversationalStorage {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::state::generate_storage_tests;
     use hermesllm::apis::openai_responses::{
         InputContent, InputItem, InputMessage, MessageContent, MessageRole,
     };

-    fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
+    fn create_test_state(response_id: &str) -> OpenAIConversationState {
+        OpenAIConversationState {
+            response_id: response_id.to_string(),
+            input_items: vec![InputItem::Message(InputMessage {
+                role: MessageRole::User,
+                content: MessageContent::Items(vec![InputContent::InputText {
+                    text: "Test message".to_string(),
+                }]),
+            })],
+            created_at: 1234567890,
+            model: "claude-3".to_string(),
+            provider: "anthropic".to_string(),
+        }
+    }
+
+    fn create_test_state_with_messages(
+        response_id: &str,
+        num_messages: usize,
+    ) -> OpenAIConversationState {
         let mut input_items = Vec::new();
         for i in 0..num_messages {
             input_items.push(InputItem::Message(InputMessage {
@@ -113,209 +132,8 @@ mod tests {
         }
     }

-    #[tokio::test]
-    async fn test_put_and_get_success() {
-        let storage = MemoryConversationalStorage::new();
-        let state: OpenAIConversationState = create_test_state("resp_001", 3);
-
-        // Store
-        storage.put(state.clone()).await.unwrap();
-
-        // Retrieve
-        let retrieved = storage.get("resp_001").await.unwrap();
-        assert_eq!(retrieved.response_id, state.response_id);
-        assert_eq!(retrieved.model, state.model);
-        assert_eq!(retrieved.provider, state.provider);
-        assert_eq!(retrieved.input_items.len(), 3);
-        assert_eq!(retrieved.created_at, state.created_at);
-    }
-
-    #[tokio::test]
-    async fn test_put_overwrites_existing() {
-        let storage = MemoryConversationalStorage::new();
-
-        // First state
-        let state1 = create_test_state("resp_002", 2);
-        storage.put(state1).await.unwrap();
-
-        // Overwrite with new state
-        let state2 = OpenAIConversationState {
-            response_id: "resp_002".to_string(),
-            input_items: vec![],
-            created_at: 9999999999,
-            model: "gpt-4".to_string(),
-            provider: "openai".to_string(),
-        };
-        storage.put(state2.clone()).await.unwrap();
-
-        // Should retrieve the new state
-        let retrieved = storage.get("resp_002").await.unwrap();
-        assert_eq!(retrieved.model, "gpt-4");
-        assert_eq!(retrieved.provider, "openai");
-        assert_eq!(retrieved.input_items.len(), 0);
-        assert_eq!(retrieved.created_at, 9999999999);
-    }
-
-    #[tokio::test]
-    async fn test_get_not_found() {
-        let storage = MemoryConversationalStorage::new();
-
-        let result = storage.get("nonexistent").await;
-        assert!(result.is_err());
-
-        match result.unwrap_err() {
-            StateStorageError::NotFound(id) => {
-                assert_eq!(id, "nonexistent");
-            }
-            _ => panic!("Expected NotFound error"),
-        }
-    }
-
-    #[tokio::test]
-    async fn test_exists_returns_false_for_nonexistent() {
-        let storage = MemoryConversationalStorage::new();
-        assert!(!storage.exists("resp_003").await.unwrap());
-    }
-
-    #[tokio::test]
-    async fn test_exists_returns_true_after_put() {
-        let storage = MemoryConversationalStorage::new();
-        let state = create_test_state("resp_004", 1);
-
-        assert!(!storage.exists("resp_004").await.unwrap());
-        storage.put(state).await.unwrap();
-        assert!(storage.exists("resp_004").await.unwrap());
-    }
-
-    #[tokio::test]
-    async fn test_delete_success() {
-        let storage = MemoryConversationalStorage::new();
-        let state = create_test_state("resp_005", 2);
-
-        storage.put(state).await.unwrap();
-        assert!(storage.exists("resp_005").await.unwrap());
-
-        // Delete
-        storage.delete("resp_005").await.unwrap();
-
-        // Should no longer exist
-        assert!(!storage.exists("resp_005").await.unwrap());
-        assert!(storage.get("resp_005").await.is_err());
-    }
-
-    #[tokio::test]
-    async fn test_delete_not_found() {
-        let storage = MemoryConversationalStorage::new();
-
-        let result = storage.delete("nonexistent").await;
-        assert!(result.is_err());
-
-        match result.unwrap_err() {
-            StateStorageError::NotFound(id) => {
-                assert_eq!(id, "nonexistent");
-            }
-            _ => panic!("Expected NotFound error"),
-        }
-    }
-
-    #[tokio::test]
-    async fn test_merge_combines_inputs() {
-        let storage = MemoryConversationalStorage::new();
-
-        // Create a previous state with 2 messages
-        let prev_state = create_test_state("resp_006", 2);
-
-        // Create current input with 1 message
-        let current_input = vec![InputItem::Message(InputMessage {
-            role: MessageRole::User,
-            content: MessageContent::Items(vec![InputContent::InputText {
-                text: "New message".to_string(),
-            }]),
-        })];
-
-        // Merge
-        let merged = storage.merge(&prev_state, current_input);
-
-        // Should have 3 messages total (2 from prev + 1 current)
-        assert_eq!(merged.len(), 3);
-    }
-
-    #[tokio::test]
-    async fn test_merge_preserves_order() {
-        let storage = MemoryConversationalStorage::new();
-
-        // Previous state has messages 0 and 1
-        let prev_state = create_test_state("resp_007", 2);
-
-        // Current input has message 2
-        let current_input = vec![InputItem::Message(InputMessage {
-            role: MessageRole::User,
-            content: MessageContent::Items(vec![InputContent::InputText {
-                text: "Message 2".to_string(),
-            }]),
-        })];
-
-        let merged = storage.merge(&prev_state, current_input);
-
-        // Verify order: prev messages first, then current
-        let InputItem::Message(msg) = &merged[0] else {
-            panic!("Expected Message")
-        };
-        match &msg.content {
-            MessageContent::Items(items) => match &items[0] {
-                InputContent::InputText { text } => assert_eq!(text, "Message 0"),
-                _ => panic!("Expected InputText"),
-            },
-            _ => panic!("Expected MessageContent::Items"),
-        }
-
-        let InputItem::Message(msg) = &merged[2] else {
-            panic!("Expected Message")
-        };
-        match &msg.content {
-            MessageContent::Items(items) => match &items[0] {
-                InputContent::InputText { text } => assert_eq!(text, "Message 2"),
-                _ => panic!("Expected InputText"),
-            },
-            _ => panic!("Expected MessageContent::Items"),
-        }
-    }
-
-    #[tokio::test]
-    async fn test_merge_with_empty_current_input() {
-        let storage = MemoryConversationalStorage::new();
-        let prev_state = create_test_state("resp_008", 3);
-
-        let merged = storage.merge(&prev_state, vec![]);
-
-        // Should just have the previous state's items
-        assert_eq!(merged.len(), 3);
-    }
-
-    #[tokio::test]
-    async fn test_merge_with_empty_previous_state() {
-        let storage = MemoryConversationalStorage::new();
-
-        let prev_state = OpenAIConversationState {
-            response_id: "resp_009".to_string(),
-            input_items: vec![],
-            created_at: 1234567890,
-            model: "gpt-4".to_string(),
-            provider: "openai".to_string(),
-        };
-
-        let current_input = vec![InputItem::Message(InputMessage {
-            role: MessageRole::User,
-            content: MessageContent::Items(vec![InputContent::InputText {
-                text: "Only message".to_string(),
-            }]),
-        })];
-
-        let merged = storage.merge(&prev_state, current_input);
-
-        // Should just have the current input
-        assert_eq!(merged.len(), 1);
-    }
+    // Generate the standard CRUD tests via macro
+    generate_storage_tests!(MemoryConversationalStorage::new());

     #[tokio::test]
     async fn test_concurrent_access() {
@@ -327,7 +145,7 @@ mod tests {
         for i in 0..10 {
             let storage_clone = storage.clone();
             let handle = tokio::spawn(async move {
-                let state = create_test_state(&format!("resp_{}", i), i % 3);
+                let state = create_test_state_with_messages(&format!("resp_{}", i), i % 3);
                 storage_clone.put(state).await.unwrap();
             });
             handles.push(handle);
@@ -347,7 +165,7 @@ mod tests {
     #[tokio::test]
     async fn test_multiple_operations_on_same_id() {
         let storage = MemoryConversationalStorage::new();
-        let state = create_test_state("resp_010", 1);
+        let state = create_test_state_with_messages("resp_010", 1);

         // Put
         storage.put(state.clone()).await.unwrap();
@@ -360,7 +178,7 @@ mod tests {
         assert!(storage.exists("resp_010").await.unwrap());

         // Put again (overwrite)
-        let new_state = create_test_state("resp_010", 5);
+        let new_state = create_test_state_with_messages("resp_010", 5);
         storage.put(new_state).await.unwrap();

         // Get updated
@@ -373,266 +191,4 @@ mod tests {
         // Should not exist
         assert!(!storage.exists("resp_010").await.unwrap());
     }
-
-    #[tokio::test]
-    async fn test_merge_with_tool_call_flow() {
-        // This test simulates a realistic tool call conversation flow:
-        // 1. User sends message: "What's the weather?"
-        // 2. Model responds with function call (converted to assistant message)
-        // 3. User sends function call output in next request with previous_response_id
-        // The merge should combine: user message + assistant function call + function output
-
-        let storage = MemoryConversationalStorage::new();
-
-        // Step 1: Previous state contains the initial exchange
-        // - User message: "What's the weather in SF?"
-        // - Assistant message (converted from FunctionCall): "Called function: get_weather..."
-        let prev_state = OpenAIConversationState {
-            response_id: "resp_tool_001".to_string(),
-            input_items: vec![
-                // Original user message
-                InputItem::Message(InputMessage {
-                    role: MessageRole::User,
-                    content: MessageContent::Items(vec![InputContent::InputText {
-                        text: "What's the weather in San Francisco?".to_string(),
-                    }]),
-                }),
-                // Assistant's function call (converted from OutputItem::FunctionCall)
-                InputItem::Message(InputMessage {
-                    role: MessageRole::Assistant,
-                    content: MessageContent::Items(vec![InputContent::InputText {
-                        text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(),
-                    }]),
-                }),
-            ],
-            created_at: 1234567890,
-            model: "claude-3".to_string(),
-            provider: "anthropic".to_string(),
-        };
-
-        // Step 2: Current request includes function call output
-        let current_input = vec![InputItem::Message(InputMessage {
-            role: MessageRole::User,
-            content: MessageContent::Items(vec![InputContent::InputText {
-                text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}"
-                    .to_string(),
-            }]),
-        })];
-
-        // Step 3: Merge should combine all conversation history
-        let merged = storage.merge(&prev_state, current_input);
-
-        // Should have 3 items: user question + assistant function call + function output
-        assert_eq!(merged.len(), 3);
-
-        // Verify the order and content
-        let InputItem::Message(msg1) = &merged[0] else {
-            panic!("Expected Message")
-        };
-        assert!(matches!(msg1.role, MessageRole::User));
-        match &msg1.content {
-            MessageContent::Items(items) => match &items[0] {
-                InputContent::InputText { text } => {
-                    assert!(text.contains("weather in San Francisco"));
-                }
-                _ => panic!("Expected InputText"),
-            },
-            _ => panic!("Expected MessageContent::Items"),
-        }
-
-        let InputItem::Message(msg2) = &merged[1] else {
-            panic!("Expected Message")
-        };
-        assert!(matches!(msg2.role, MessageRole::Assistant));
-        match &msg2.content {
-            MessageContent::Items(items) => match &items[0] {
-                InputContent::InputText { text } => {
-                    assert!(text.contains("get_weather"));
-                }
-                _ => panic!("Expected InputText"),
-            },
-            _ => panic!("Expected MessageContent::Items"),
-        }
-
-        let InputItem::Message(msg3) = &merged[2] else {
-            panic!("Expected Message")
-        };
-        assert!(matches!(msg3.role, MessageRole::User));
-        match &msg3.content {
-            MessageContent::Items(items) => match &items[0] {
-                InputContent::InputText { text } => {
-                    assert!(text.contains("Function result"));
-                    assert!(text.contains("temperature"));
-                }
-                _ => panic!("Expected InputText"),
-            },
-            _ => panic!("Expected MessageContent::Items"),
-        }
-    }
-
-    #[tokio::test]
-    async fn test_merge_with_multiple_tool_calls() {
-        // Test a more complex scenario with multiple tool calls
-        let storage = MemoryConversationalStorage::new();
-
-        // Previous state has: user message + 2 function calls from assistant
-        let prev_state = OpenAIConversationState {
-            response_id: "resp_tool_002".to_string(),
-            input_items: vec![
-                InputItem::Message(InputMessage {
-                    role: MessageRole::User,
-                    content: MessageContent::Items(vec![InputContent::InputText {
-                        text: "What's the weather and time in SF?".to_string(),
-                    }]),
-                }),
-                InputItem::Message(InputMessage {
-                    role: MessageRole::Assistant,
-                    content: MessageContent::Items(vec![InputContent::InputText {
-                        text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(),
-                    }]),
-                }),
-                InputItem::Message(InputMessage {
-                    role: MessageRole::Assistant,
-                    content: MessageContent::Items(vec![InputContent::InputText {
-                        text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(),
-                    }]),
-                }),
-            ],
-            created_at: 1234567890,
-            model: "gpt-4".to_string(),
-            provider: "openai".to_string(),
-        };
-
-        // Current input: function outputs for both calls
-        let current_input = vec![
-            InputItem::Message(InputMessage {
-                role: MessageRole::User,
-                content: MessageContent::Items(vec![InputContent::InputText {
-                    text: "Weather result: {\"temp\": 68}".to_string(),
-                }]),
-            }),
-            InputItem::Message(InputMessage {
-                role: MessageRole::User,
-                content: MessageContent::Items(vec![InputContent::InputText {
-                    text: "Time result: {\"time\": \"14:30\"}".to_string(),
-                }]),
-            }),
-        ];
-
-        let merged = storage.merge(&prev_state, current_input);
-
-        // Should have 5 items total: 1 user + 2 assistant calls + 2 function outputs
-        assert_eq!(merged.len(), 5);
-
-        // Verify first item is original user message
-        let InputItem::Message(first) = &merged[0] else {
-            panic!("Expected Message")
-        };
-        assert!(matches!(first.role, MessageRole::User));
-
-        // Verify last two are function outputs
-        let InputItem::Message(second_last) = &merged[3] else {
-            panic!("Expected Message")
-        };
-        assert!(matches!(second_last.role, MessageRole::User));
-        match &second_last.content {
-            MessageContent::Items(items) => match &items[0] {
-                InputContent::InputText { text } => assert!(text.contains("Weather result")),
-                _ => panic!("Expected InputText"),
-            },
-            _ => panic!("Expected MessageContent::Items"),
-        }
-
-        let InputItem::Message(last) = &merged[4] else {
-            panic!("Expected Message")
-        };
-        assert!(matches!(last.role, MessageRole::User));
-        match &last.content {
-            MessageContent::Items(items) => match &items[0] {
-                InputContent::InputText { text } => assert!(text.contains("Time result")),
-                _ => panic!("Expected InputText"),
-            },
-            _ => panic!("Expected MessageContent::Items"),
-        }
-    }
-
-    #[tokio::test]
-    async fn test_merge_preserves_conversation_context_for_multi_turn() {
-        // Simulate a multi-turn conversation with tool calls
|
||||
let storage = MemoryConversationalStorage::new();
|
||||
|
||||
// Previous state: full conversation history up to this point
|
||||
let prev_state = OpenAIConversationState {
|
||||
response_id: "resp_tool_003".to_string(),
|
||||
input_items: vec![
|
||||
// Turn 1: User asks about weather
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "What's the weather?".to_string(),
|
||||
}]),
|
||||
}),
|
||||
// Turn 1: Assistant calls get_weather
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::Assistant,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Called function: get_weather".to_string(),
|
||||
}]),
|
||||
}),
|
||||
// Turn 2: User provides function output
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Weather: sunny, 72°F".to_string(),
|
||||
}]),
|
||||
}),
|
||||
// Turn 2: Assistant responds with text
|
||||
InputItem::Message(InputMessage {
|
||||
role: MessageRole::Assistant,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "It's sunny and 72°F in San Francisco today!".to_string(),
|
||||
}]),
|
||||
}),
|
||||
],
|
||||
created_at: 1234567890,
|
||||
model: "claude-3".to_string(),
|
||||
provider: "anthropic".to_string(),
|
||||
};
|
||||
|
||||
// Turn 3: User asks follow-up question
|
||||
let current_input = vec![InputItem::Message(InputMessage {
|
||||
role: MessageRole::User,
|
||||
content: MessageContent::Items(vec![InputContent::InputText {
|
||||
text: "Should I bring an umbrella?".to_string(),
|
||||
}]),
|
||||
})];
|
||||
|
||||
let merged = storage.merge(&prev_state, current_input);
|
||||
|
||||
// Should have all 5 messages in order
|
||||
assert_eq!(merged.len(), 5);
|
||||
|
||||
// Verify the entire conversation flow is preserved
|
||||
let InputItem::Message(first) = &merged[0] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
match &first.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert!(text.contains("What's the weather")),
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
|
||||
let InputItem::Message(last) = &merged[4] else {
|
||||
panic!("Expected Message")
|
||||
};
|
||||
match &last.content {
|
||||
MessageContent::Items(items) => match &items[0] {
|
||||
InputContent::InputText { text } => assert!(text.contains("umbrella")),
|
||||
_ => panic!("Expected InputText"),
|
||||
},
|
||||
_ => panic!("Expected MessageContent::Items"),
|
||||
}
|
||||
}
|
||||
}
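The flow spelled out in the comments of test_merge_with_tool_call_flow above corresponds to two wire requests. As a rough sketch, the second (follow-up) request would look something like this; field names follow the OpenAI Responses API convention and the id is hypothetical, so treat the exact payload shape as an assumption rather than this gateway's documented contract:

    use serde_json::json;

    // Turn 2: the client replays only the new item; the gateway looks up
    // "resp_tool_001", merges the stored history, and forwards the full list.
    let followup = json!({
        "model": "gpt-4",
        "previous_response_id": "resp_tool_001", // id returned by turn 1
        "input": [{
            "role": "user",
            "content": [{
                "type": "input_text",
                "text": "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}"
            }]
        }]
    });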

@@ -148,13 +148,128 @@ pub async fn retrieve_and_combine_input(
    Ok(combined_input)
}

#[cfg(test)]
macro_rules! generate_storage_tests {
    ($create_storage:expr) => {
        #[tokio::test]
        async fn test_put_and_get_success() {
            let storage = $create_storage;
            let state = create_test_state("resp_001");
            storage.put(state.clone()).await.unwrap();
            let retrieved = storage.get("resp_001").await.unwrap();
            assert_eq!(retrieved.response_id, "resp_001");
            assert_eq!(retrieved.model, state.model);
            assert_eq!(retrieved.provider, state.provider);
            assert_eq!(retrieved.input_items.len(), state.input_items.len());
            assert_eq!(retrieved.created_at, state.created_at);
        }

        #[tokio::test]
        async fn test_put_overwrites_existing() {
            let storage = $create_storage;
            let state1 = create_test_state("resp_002");
            storage.put(state1).await.unwrap();
            let state2 = OpenAIConversationState {
                response_id: "resp_002".to_string(),
                input_items: vec![],
                created_at: 9999999999,
                model: "gpt-4".to_string(),
                provider: "openai".to_string(),
            };
            storage.put(state2).await.unwrap();
            let retrieved = storage.get("resp_002").await.unwrap();
            assert_eq!(retrieved.model, "gpt-4");
            assert_eq!(retrieved.provider, "openai");
            assert_eq!(retrieved.input_items.len(), 0);
            assert_eq!(retrieved.created_at, 9999999999);
        }

        #[tokio::test]
        async fn test_get_not_found() {
            let storage = $create_storage;
            let result = storage.get("nonexistent").await;
            assert!(result.is_err());
            assert!(matches!(
                result.unwrap_err(),
                StateStorageError::NotFound(_)
            ));
        }

        #[tokio::test]
        async fn test_exists_returns_false_for_nonexistent() {
            let storage = $create_storage;
            assert!(!storage.exists("nonexistent").await.unwrap());
        }

        #[tokio::test]
        async fn test_exists_returns_true_after_put() {
            let storage = $create_storage;
            let state = create_test_state("resp_004");
            assert!(!storage.exists("resp_004").await.unwrap());
            storage.put(state).await.unwrap();
            assert!(storage.exists("resp_004").await.unwrap());
        }

        #[tokio::test]
        async fn test_delete_success() {
            let storage = $create_storage;
            let state = create_test_state("resp_005");
            storage.put(state).await.unwrap();
            assert!(storage.exists("resp_005").await.unwrap());
            storage.delete("resp_005").await.unwrap();
            assert!(!storage.exists("resp_005").await.unwrap());
            assert!(storage.get("resp_005").await.is_err());
        }

        #[tokio::test]
        async fn test_delete_not_found() {
            let storage = $create_storage;
            let result = storage.delete("nonexistent").await;
            assert!(result.is_err());
            assert!(matches!(
                result.unwrap_err(),
                StateStorageError::NotFound(_)
            ));
        }
    };
}

#[cfg(test)]
pub(crate) use generate_storage_tests;
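Each backend's test module then invokes the macro with an expression that yields its storage. A minimal sketch for the in-memory backend, assuming the MemoryConversationalStorage::new() constructor used elsewhere in this diff; the module name is illustrative, and a local create_test_state(id) fixture must be in scope because the generated tests call it:

    #[cfg(test)]
    mod memory_tests {
        use super::*;
        use crate::state::generate_storage_tests;

        // Hypothetical one-argument fixture; an empty state is enough for
        // the generated CRUD assertions.
        fn create_test_state(response_id: &str) -> OpenAIConversationState {
            OpenAIConversationState {
                response_id: response_id.to_string(),
                input_items: vec![],
                created_at: 1234567890,
                model: "gpt-4".to_string(),
                provider: "openai".to_string(),
            }
        }

        // Expands to test_put_and_get_success, test_get_not_found, and the
        // rest; $create_storage is re-evaluated inside each generated test.
        generate_storage_tests!(MemoryConversationalStorage::new());
    }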

#[cfg(test)]
mod tests {
    use super::extract_input_items;
    use super::memory::MemoryConversationalStorage;
    use super::{OpenAIConversationState, StateStorage};
    use hermesllm::apis::openai_responses::{
        InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole,
    };

    fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
        let mut input_items = Vec::new();
        for i in 0..num_messages {
            input_items.push(InputItem::Message(InputMessage {
                role: if i % 2 == 0 {
                    MessageRole::User
                } else {
                    MessageRole::Assistant
                },
                content: MessageContent::Items(vec![InputContent::InputText {
                    text: format!("Message {}", i),
                }]),
            }));
        }

        OpenAIConversationState {
            response_id: response_id.to_string(),
            input_items,
            created_at: 1234567890,
            model: "claude-3".to_string(),
            provider: "anthropic".to_string(),
        }
    }

    #[test]
    fn test_extract_input_items_converts_text_to_user_message_item() {
        let extracted = extract_input_items(&InputParam::Text("hello world".to_string()));

@@ -244,4 +359,320 @@ mod tests {
        };
        assert_eq!(second_text, "second");
    }

    // === Merge tests (testing the default trait method) ===
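These cases pin down the default merge provided by the StateStorage trait. Its body is presumably a plain, order-preserving concatenation along these lines (a sketch consistent with the assertions below, not the trait's actual code):

    fn merge(
        &self,
        prev_state: &OpenAIConversationState,
        current_input: Vec<InputItem>,
    ) -> Vec<InputItem> {
        // Stored history first, then the new items, preserving order,
        // so merged.len() == prev_state.input_items.len() + current_input.len().
        let mut merged = prev_state.input_items.clone();
        merged.extend(current_input);
        merged
    }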

    #[tokio::test]
    async fn test_merge_combines_inputs() {
        let storage = MemoryConversationalStorage::new();
        let prev_state = create_test_state("resp_006", 2);

        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "New message".to_string(),
            }]),
        })];

        let merged = storage.merge(&prev_state, current_input);
        assert_eq!(merged.len(), 3);
    }

    #[tokio::test]
    async fn test_merge_preserves_order() {
        let storage = MemoryConversationalStorage::new();
        let prev_state = create_test_state("resp_007", 2);

        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Message 2".to_string(),
            }]),
        })];

        let merged = storage.merge(&prev_state, current_input);

        let InputItem::Message(msg) = &merged[0] else {
            panic!("Expected Message")
        };
        match &msg.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert_eq!(text, "Message 0"),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }

        let InputItem::Message(msg) = &merged[2] else {
            panic!("Expected Message")
        };
        match &msg.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert_eq!(text, "Message 2"),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }
    }

    #[tokio::test]
    async fn test_merge_with_empty_current_input() {
        let storage = MemoryConversationalStorage::new();
        let prev_state = create_test_state("resp_008", 3);

        let merged = storage.merge(&prev_state, vec![]);
        assert_eq!(merged.len(), 3);
    }

    #[tokio::test]
    async fn test_merge_with_empty_previous_state() {
        let storage = MemoryConversationalStorage::new();

        let prev_state = OpenAIConversationState {
            response_id: "resp_009".to_string(),
            input_items: vec![],
            created_at: 1234567890,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        };

        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Only message".to_string(),
            }]),
        })];

        let merged = storage.merge(&prev_state, current_input);
        assert_eq!(merged.len(), 1);
    }

    #[tokio::test]
    async fn test_merge_with_tool_call_flow() {
        let storage = MemoryConversationalStorage::new();

        let prev_state = OpenAIConversationState {
            response_id: "resp_tool_001".to_string(),
            input_items: vec![
                InputItem::Message(InputMessage {
                    role: MessageRole::User,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "What's the weather in San Francisco?".to_string(),
                    }]),
                }),
                InputItem::Message(InputMessage {
                    role: MessageRole::Assistant,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(),
                    }]),
                }),
            ],
            created_at: 1234567890,
            model: "claude-3".to_string(),
            provider: "anthropic".to_string(),
        };

        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}"
                    .to_string(),
            }]),
        })];

        let merged = storage.merge(&prev_state, current_input);
        assert_eq!(merged.len(), 3);

        let InputItem::Message(msg1) = &merged[0] else {
            panic!("Expected Message")
        };
        assert!(matches!(msg1.role, MessageRole::User));
        match &msg1.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => {
                    assert!(text.contains("weather in San Francisco"));
                }
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }

        let InputItem::Message(msg2) = &merged[1] else {
            panic!("Expected Message")
        };
        assert!(matches!(msg2.role, MessageRole::Assistant));
        match &msg2.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => {
                    assert!(text.contains("get_weather"));
                }
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }

        let InputItem::Message(msg3) = &merged[2] else {
            panic!("Expected Message")
        };
        assert!(matches!(msg3.role, MessageRole::User));
        match &msg3.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => {
                    assert!(text.contains("Function result"));
                    assert!(text.contains("temperature"));
                }
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }
    }

    #[tokio::test]
    async fn test_merge_with_multiple_tool_calls() {
        let storage = MemoryConversationalStorage::new();

        let prev_state = OpenAIConversationState {
            response_id: "resp_tool_002".to_string(),
            input_items: vec![
                InputItem::Message(InputMessage {
                    role: MessageRole::User,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "What's the weather and time in SF?".to_string(),
                    }]),
                }),
                InputItem::Message(InputMessage {
                    role: MessageRole::Assistant,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(),
                    }]),
                }),
                InputItem::Message(InputMessage {
                    role: MessageRole::Assistant,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(),
                    }]),
                }),
            ],
            created_at: 1234567890,
            model: "gpt-4".to_string(),
            provider: "openai".to_string(),
        };

        let current_input = vec![
            InputItem::Message(InputMessage {
                role: MessageRole::User,
                content: MessageContent::Items(vec![InputContent::InputText {
                    text: "Weather result: {\"temp\": 68}".to_string(),
                }]),
            }),
            InputItem::Message(InputMessage {
                role: MessageRole::User,
                content: MessageContent::Items(vec![InputContent::InputText {
                    text: "Time result: {\"time\": \"14:30\"}".to_string(),
                }]),
            }),
        ];

        let merged = storage.merge(&prev_state, current_input);
        assert_eq!(merged.len(), 5);

        let InputItem::Message(first) = &merged[0] else {
            panic!("Expected Message")
        };
        assert!(matches!(first.role, MessageRole::User));

        let InputItem::Message(second_last) = &merged[3] else {
            panic!("Expected Message")
        };
        assert!(matches!(second_last.role, MessageRole::User));
        match &second_last.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert!(text.contains("Weather result")),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }

        let InputItem::Message(last) = &merged[4] else {
            panic!("Expected Message")
        };
        assert!(matches!(last.role, MessageRole::User));
        match &last.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert!(text.contains("Time result")),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }
    }

    #[tokio::test]
    async fn test_merge_preserves_conversation_context_for_multi_turn() {
        let storage = MemoryConversationalStorage::new();

        let prev_state = OpenAIConversationState {
            response_id: "resp_tool_003".to_string(),
            input_items: vec![
                InputItem::Message(InputMessage {
                    role: MessageRole::User,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "What's the weather?".to_string(),
                    }]),
                }),
                InputItem::Message(InputMessage {
                    role: MessageRole::Assistant,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "Called function: get_weather".to_string(),
                    }]),
                }),
                InputItem::Message(InputMessage {
                    role: MessageRole::User,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "Weather: sunny, 72\u{00b0}F".to_string(),
                    }]),
                }),
                InputItem::Message(InputMessage {
                    role: MessageRole::Assistant,
                    content: MessageContent::Items(vec![InputContent::InputText {
                        text: "It's sunny and 72\u{00b0}F in San Francisco today!".to_string(),
                    }]),
                }),
            ],
            created_at: 1234567890,
            model: "claude-3".to_string(),
            provider: "anthropic".to_string(),
        };

        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Should I bring an umbrella?".to_string(),
            }]),
        })];

        let merged = storage.merge(&prev_state, current_input);
        assert_eq!(merged.len(), 5);

        let InputItem::Message(first) = &merged[0] else {
            panic!("Expected Message")
        };
        match &first.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert!(text.contains("What's the weather")),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }

        let InputItem::Message(last) = &merged[4] else {
            panic!("Expected Message")
        };
        match &last.content {
            MessageContent::Items(items) => match &items[0] {
                InputContent::InputText { text } => assert!(text.contains("umbrella")),
                _ => panic!("Expected InputText"),
            },
            _ => panic!("Expected MessageContent::Items"),
        }
    }
}

@@ -229,6 +229,7 @@ Run that SQL file against your database before using this storage backend.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::generate_storage_tests;
    use hermesllm::apis::openai_responses::{
        InputContent, InputItem, InputMessage, MessageContent, MessageRole,
    };

@@ -267,140 +268,13 @@ mod tests {
        }
    }

    #[tokio::test]
    async fn test_supabase_put_and_get_success() {
    // Generate the standard CRUD tests via macro
    generate_storage_tests!({
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let state = create_test_state("test_resp_001");
        storage.put(state.clone()).await.unwrap();

        let retrieved = storage.get("test_resp_001").await.unwrap();
        assert_eq!(retrieved.response_id, "test_resp_001");
        assert_eq!(retrieved.input_items.len(), 1);
        assert_eq!(retrieved.model, "gpt-4");
        assert_eq!(retrieved.provider, "openai");

        // Cleanup
        let _ = storage.delete("test_resp_001").await;
    }

    #[tokio::test]
    async fn test_supabase_put_overwrites_existing() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let state1 = create_test_state("test_resp_002");
        storage.put(state1).await.unwrap();

        let mut state2 = create_test_state("test_resp_002");
        state2.model = "gpt-4-turbo".to_string();
        state2.input_items.push(InputItem::Message(InputMessage {
            role: MessageRole::Assistant,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "Response".to_string(),
            }]),
        }));
        storage.put(state2).await.unwrap();

        let retrieved = storage.get("test_resp_002").await.unwrap();
        assert_eq!(retrieved.model, "gpt-4-turbo");
        assert_eq!(retrieved.input_items.len(), 2);

        // Cleanup
        let _ = storage.delete("test_resp_002").await;
    }

    #[tokio::test]
    async fn test_supabase_get_not_found() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let result = storage.get("nonexistent_id").await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StateStorageError::NotFound(_)
        ));
    }

    #[tokio::test]
    async fn test_supabase_exists_returns_false() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let exists = storage.exists("nonexistent_id").await.unwrap();
        assert!(!exists);
    }

    #[tokio::test]
    async fn test_supabase_exists_returns_true_after_put() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let state = create_test_state("test_resp_003");
        storage.put(state).await.unwrap();

        let exists = storage.exists("test_resp_003").await.unwrap();
        assert!(exists);

        // Cleanup
        let _ = storage.delete("test_resp_003").await;
    }

    #[tokio::test]
    async fn test_supabase_delete_success() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let state = create_test_state("test_resp_004");
        storage.put(state).await.unwrap();

        storage.delete("test_resp_004").await.unwrap();

        let exists = storage.exists("test_resp_004").await.unwrap();
        assert!(!exists);
    }

    #[tokio::test]
    async fn test_supabase_delete_not_found() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let result = storage.delete("nonexistent_id").await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StateStorageError::NotFound(_)
        ));
    }

    #[tokio::test]
    async fn test_supabase_merge_works() {
        let Some(storage) = get_test_storage().await else {
            return;
        };

        let prev_state = create_test_state("test_resp_005");
        let current_input = vec![InputItem::Message(InputMessage {
            role: MessageRole::User,
            content: MessageContent::Items(vec![InputContent::InputText {
                text: "New message".to_string(),
            }]),
        })];

        let merged = storage.merge(&prev_state, current_input);

        // Should have 2 messages (1 from prev + 1 current)
        assert_eq!(merged.len(), 2);
    }
        storage
    });

    #[tokio::test]
    async fn test_supabase_table_verification() {

@@ -428,7 +302,7 @@ mod tests {
let state = create_test_state("manual_test_verification");
|
||||
storage.put(state).await.unwrap();
|
||||
|
||||
println!("✅ Data written to Supabase!");
|
||||
println!("Data written to Supabase!");
|
||||
println!("Check your Supabase dashboard:");
|
||||
println!(
|
||||
" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';"
|
||||
|
|
|
|||
|
|

@@ -472,6 +472,37 @@ mod tests {
    use crate::clients::endpoints::SupportedAPIsFromClient;
    use serde_json::json;

    /// Helper to build an SseEvent from optional JSON data and optional event type.
    fn make_sse_event(data: Option<serde_json::Value>, event: Option<&str>) -> SseEvent {
        match data {
            Some(ref json_val) => SseEvent {
                data: Some(json_val.to_string()),
                event: event.map(|s| s.to_string()),
                raw_line: format!("data: {}", json_val),
                sse_transformed_lines: format!("data: {}", json_val),
                provider_stream_response: None,
            },
            None => SseEvent {
                data: None,
                event: event.map(|s| s.to_string()),
                raw_line: event.map(|e| format!("event: {}", e)).unwrap_or_default(),
                sse_transformed_lines: event.map(|e| format!("event: {}", e)).unwrap_or_default(),
                provider_stream_response: None,
            },
        }
    }

    /// Helper to build a standard OpenAI content chunk with the given content text.
    fn openai_content_chunk(content: &str) -> serde_json::Value {
        json!({
            "id": "chatcmpl-123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": null}]
        })
    }
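With these two helpers, the event construction in the tests below collapses to a single call each. The two usage patterns that appear later in this diff look like:

    // A JSON data chunk with no explicit SSE event type:
    let data_event = make_sse_event(Some(openai_content_chunk("Hello")), None);
    // An event-only line (no data), e.g. Anthropic's message_start:
    let typed_event = make_sse_event(None, Some("message_start"));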

    #[test]
    fn test_sse_event_parsing() {
        // Test valid SSE data line

@@ -1099,14 +1130,7 @@ mod tests {
            }
        });

        // Create SSE event with this data
        let sse_event = SseEvent {
            data: Some(openai_stream_chunk.to_string()),
            event: None,
            raw_line: format!("data: {}", openai_stream_chunk),
            sse_transformed_lines: format!("data: {}", openai_stream_chunk),
            provider_stream_response: None,
        };
        let sse_event = make_sse_event(Some(openai_stream_chunk), None);

        let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

@@ -1143,26 +1167,7 @@ mod tests {
        use crate::apis::openai::OpenAIApi;

        // Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic)
        let openai_stream_chunk = json!({
            "id": "chatcmpl-123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "delta": {"content": "Hello"},
                "finish_reason": null
            }]
        });

        // Create SSE event with this data
        let sse_event = SseEvent {
            data: Some(openai_stream_chunk.to_string()),
            event: None,
            raw_line: format!("data: {}", openai_stream_chunk),
            sse_transformed_lines: format!("data: {}", openai_stream_chunk),
            provider_stream_response: None,
        };
        let sse_event = make_sse_event(Some(openai_content_chunk("Hello")), None);

        let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

@@ -1198,13 +1203,7 @@ mod tests {
        use crate::apis::openai::OpenAIApi;

        // Create an Anthropic event-only SSE line (no data)
        let sse_event = SseEvent {
            data: None,
            event: Some("message_start".to_string()),
            raw_line: "event: message_start".to_string(),
            sse_transformed_lines: "event: message_start".to_string(),
            provider_stream_response: None,
        };
        let sse_event = make_sse_event(None, Some("message_start"));

        let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
        let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);

@@ -1245,13 +1244,7 @@ mod tests {
            }
        });

        let sse_event = SseEvent {
            data: Some(anthropic_event.to_string()),
            event: None,
            raw_line: format!("data: {}", anthropic_event),
            sse_transformed_lines: format!("data: {}", anthropic_event),
            provider_stream_response: None,
        };
        let sse_event = make_sse_event(Some(anthropic_event), None);

        let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
        let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages);

@@ -1279,26 +1272,9 @@ mod tests {
        use crate::apis::openai::OpenAIApi;

        // Create an OpenAI stream response
        let openai_stream_chunk = json!({
            "id": "chatcmpl-123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "delta": {"content": "Hello"},
                "finish_reason": null
            }]
        });

        let original_data = openai_stream_chunk.to_string();
        let sse_event = SseEvent {
            data: Some(original_data.clone()),
            event: None,
            raw_line: format!("data: {}", original_data),
            sse_transformed_lines: format!("data: {}\n\n", original_data),
            provider_stream_response: None,
        };
        let mut sse_event = make_sse_event(Some(openai_content_chunk("Hello")), None);
        // This test requires trailing \n\n in sse_transformed_lines
        sse_event.sse_transformed_lines = format!("data: {}\n\n", openai_content_chunk("Hello"));

        let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);

@@ -1324,25 +1300,7 @@ mod tests {
        use crate::apis::openai::OpenAIApi;

        // Create an OpenAI stream response
        let openai_stream_chunk = json!({
            "id": "chatcmpl-123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-4",
            "choices": [{
                "index": 0,
                "delta": {"content": "Test"},
                "finish_reason": null
            }]
        });

        let sse_event = SseEvent {
            data: Some(openai_stream_chunk.to_string()),
            event: None,
            raw_line: format!("data: {}", openai_stream_chunk),
            sse_transformed_lines: format!("data: {}", openai_stream_chunk),
            provider_stream_response: None,
        };
        let sse_event = make_sse_event(Some(openai_content_chunk("Test")), None);

        let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages);
        let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);