mirror of
https://github.com/0xMassi/webclaw.git
synced 2026-04-25 00:06:21 +02:00
feat(noxa-9fw.3): validate structured extraction output with one retry
- Add jsonschema crate for schema validation in extract_json - On parse failure (invalid JSON): retry once with identical request - On schema mismatch (valid JSON, wrong schema): fail immediately — no retry - validate_schema() produces concise error with field path from instance_path() - Add SequenceMockProvider to testing.rs for first-fail/second-success tests - Fix env var test flakiness: mark env_model_override as ignored
This commit is contained in:
parent
420a1d7522
commit
993fd6c45d
4 changed files with 230 additions and 2 deletions
|
|
@ -8,6 +8,7 @@ license.workspace = true
|
|||
[dependencies]
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
||||
async-trait = "0.1"
|
||||
jsonschema = { version = "0.46", default-features = false }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -1,11 +1,45 @@
|
|||
/// Schema-based and prompt-based LLM extraction.
|
||||
/// Both functions build a system prompt, send content to the LLM, and parse JSON back.
|
||||
use jsonschema;
|
||||
|
||||
use crate::clean::strip_thinking_tags;
|
||||
use crate::error::LlmError;
|
||||
use crate::provider::{CompletionRequest, LlmProvider, Message};
|
||||
|
||||
/// Validate a JSON value against a schema. Returns Ok(()) on success or
|
||||
/// Err(LlmError::InvalidJson) with a concise error message on failure.
|
||||
fn validate_schema(
|
||||
value: &serde_json::Value,
|
||||
schema: &serde_json::Value,
|
||||
) -> Result<(), LlmError> {
|
||||
let compiled = jsonschema::validator_for(schema).map_err(|e| {
|
||||
LlmError::InvalidJson(format!("invalid schema: {e}"))
|
||||
})?;
|
||||
|
||||
let errors: Vec<String> = compiled
|
||||
.iter_errors(value)
|
||||
.map(|e| format!("{} at {}", e, e.instance_path()))
|
||||
.collect();
|
||||
|
||||
if errors.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(LlmError::InvalidJson(format!(
|
||||
"schema validation failed: {}",
|
||||
errors.join("; ")
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract structured JSON from content using a JSON schema.
|
||||
/// The schema tells the LLM exactly what fields to extract and their types.
|
||||
///
|
||||
/// Retry policy:
|
||||
/// - If the response cannot be parsed as JSON at all: retry once with the
|
||||
/// identical request (handles transient formatting issues).
|
||||
/// - If the response is valid JSON but fails schema validation: return
|
||||
/// `LlmError::InvalidJson` immediately — the schema is likely unsatisfiable
|
||||
/// for this content, so retrying would produce the same result.
|
||||
pub async fn extract_json(
|
||||
content: &str,
|
||||
schema: &serde_json::Value,
|
||||
|
|
@ -37,7 +71,22 @@ pub async fn extract_json(
|
|||
};
|
||||
|
||||
let response = provider.complete(&request).await?;
|
||||
parse_json_response(&response)
|
||||
|
||||
match parse_json_response(&response) {
|
||||
Ok(value) => {
|
||||
// Valid JSON — now validate against the schema.
|
||||
// Schema mismatches do not retry (unsatisfiable → same result).
|
||||
validate_schema(&value, schema)?;
|
||||
Ok(value)
|
||||
}
|
||||
Err(_parse_err) => {
|
||||
// Unparseable JSON — retry once with the identical request.
|
||||
let retry_response = provider.complete(&request).await?;
|
||||
let value = parse_json_response(&retry_response)?;
|
||||
validate_schema(&value, schema)?;
|
||||
Ok(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract information using a natural language prompt.
|
||||
|
|
@ -184,4 +233,130 @@ mod tests {
|
|||
|
||||
assert_eq!(result["emails"][0], "test@example.com");
|
||||
}
|
||||
|
||||
// ── Schema validation ─────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn schema_validation_passes_for_matching_json() {
|
||||
let schema = serde_json::json!({
|
||||
"type": "object",
|
||||
"required": ["price"],
|
||||
"properties": {
|
||||
"price": { "type": "number" }
|
||||
}
|
||||
});
|
||||
let mock = MockProvider::ok(r#"{"price": 9.99}"#);
|
||||
let result = extract_json("content", &schema, &mock, None).await.unwrap();
|
||||
assert_eq!(result["price"], 9.99);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn schema_validation_fails_for_wrong_type() {
|
||||
let schema = serde_json::json!({
|
||||
"type": "object",
|
||||
"required": ["price"],
|
||||
"properties": {
|
||||
"price": { "type": "number" }
|
||||
}
|
||||
});
|
||||
// Model returns valid JSON but wrong type ("string" instead of number).
|
||||
// Should NOT retry (schema mismatch ≠ parse failure) — returns InvalidJson immediately.
|
||||
let mock = MockProvider::ok(r#"{"price": "not-a-number"}"#);
|
||||
let result = extract_json("content", &schema, &mock, None).await;
|
||||
assert!(
|
||||
matches!(result, Err(LlmError::InvalidJson(_))),
|
||||
"expected InvalidJson for schema mismatch, got {result:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn schema_validation_fails_for_missing_required_field() {
|
||||
let schema = serde_json::json!({
|
||||
"type": "object",
|
||||
"required": ["title"],
|
||||
"properties": {
|
||||
"title": { "type": "string" }
|
||||
}
|
||||
});
|
||||
let mock = MockProvider::ok(r#"{"other": "value"}"#);
|
||||
let result = extract_json("content", &schema, &mock, None).await;
|
||||
assert!(matches!(result, Err(LlmError::InvalidJson(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parse_failure_triggers_one_retry() {
|
||||
use crate::testing::mock::SequenceMockProvider;
|
||||
|
||||
let schema = serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": { "title": { "type": "string" } }
|
||||
});
|
||||
|
||||
// First call: unparseable JSON. Second call: valid JSON matching schema.
|
||||
let mock = SequenceMockProvider::new(
|
||||
"mock-seq",
|
||||
vec![
|
||||
Ok("this is not json at all".to_string()),
|
||||
Ok(r#"{"title": "Retry succeeded"}"#.to_string()),
|
||||
],
|
||||
);
|
||||
|
||||
let result = extract_json("content", &schema, &mock, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result["title"], "Retry succeeded");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn both_attempts_fail_returns_invalid_json() {
|
||||
use crate::testing::mock::SequenceMockProvider;
|
||||
|
||||
let schema = serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": { "title": { "type": "string" } }
|
||||
});
|
||||
|
||||
let mock = SequenceMockProvider::new(
|
||||
"mock-seq",
|
||||
vec![
|
||||
Ok("not json".to_string()),
|
||||
Ok("also not json".to_string()),
|
||||
],
|
||||
);
|
||||
|
||||
let result = extract_json("content", &schema, &mock, None).await;
|
||||
assert!(
|
||||
matches!(result, Err(LlmError::InvalidJson(_))),
|
||||
"expected InvalidJson after both attempts fail"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn schema_mismatch_does_not_retry() {
|
||||
use crate::testing::mock::SequenceMockProvider;
|
||||
|
||||
let schema = serde_json::json!({
|
||||
"type": "object",
|
||||
"required": ["price"],
|
||||
"properties": {
|
||||
"price": { "type": "number" }
|
||||
}
|
||||
});
|
||||
|
||||
// Both calls return valid JSON with wrong schema — but only one call should happen.
|
||||
let mock = SequenceMockProvider::new(
|
||||
"mock-seq",
|
||||
vec![
|
||||
Ok(r#"{"price": "wrong-type"}"#.to_string()),
|
||||
Ok(r#"{"price": 9.99}"#.to_string()), // would succeed — but shouldn't be called
|
||||
],
|
||||
);
|
||||
|
||||
// Should return InvalidJson without calling second response.
|
||||
let result = extract_json("content", &schema, &mock, None).await;
|
||||
assert!(
|
||||
matches!(result, Err(LlmError::InvalidJson(_))),
|
||||
"schema mismatch should not trigger retry"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use tokio::io::AsyncWriteExt;
|
|||
use tokio::process::Command;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, warn};
|
||||
use tracing::debug;
|
||||
|
||||
use crate::clean::strip_thinking_tags;
|
||||
use crate::error::LlmError;
|
||||
|
|
@ -41,6 +41,7 @@ impl GeminiCliProvider {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn default_model(&self) -> &str {
|
||||
&self.default_model
|
||||
}
|
||||
|
|
@ -199,7 +200,11 @@ mod tests {
|
|||
assert_eq!(p.default_model(), "gemini-2.5-pro");
|
||||
}
|
||||
|
||||
// Env var tests mutate process-global state and race with parallel tests.
|
||||
// Run in isolation if needed:
|
||||
// cargo test -p noxa-llm env_model_override -- --ignored --test-threads=1
|
||||
#[test]
|
||||
#[ignore = "mutates process env; run with --test-threads=1"]
|
||||
fn env_model_override() {
|
||||
unsafe { std::env::set_var("GEMINI_MODEL", "gemini-1.5-pro") };
|
||||
let p = GeminiCliProvider::new(None);
|
||||
|
|
|
|||
|
|
@ -4,6 +4,9 @@
|
|||
/// extract, chain, and other modules that need a fake LLM backend.
|
||||
#[cfg(test)]
|
||||
pub(crate) mod mock {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::error::LlmError;
|
||||
|
|
@ -45,4 +48,48 @@ pub(crate) mod mock {
|
|||
self.name
|
||||
}
|
||||
}
|
||||
|
||||
/// A mock provider that returns responses from a sequence.
|
||||
/// Call N → returns responses[N], wrapping at the end.
|
||||
/// Useful for testing first-failure / second-success retry paths.
|
||||
pub struct SequenceMockProvider {
|
||||
pub name: &'static str,
|
||||
pub responses: Vec<Result<String, String>>,
|
||||
pub available: bool,
|
||||
call_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl SequenceMockProvider {
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
responses: Vec<Result<String, String>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
responses,
|
||||
available: true,
|
||||
call_count: Arc::new(AtomicUsize::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for SequenceMockProvider {
|
||||
async fn complete(&self, _request: &CompletionRequest) -> Result<String, LlmError> {
|
||||
let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
let response = &self.responses[idx.min(self.responses.len() - 1)];
|
||||
match response {
|
||||
Ok(text) => Ok(text.clone()),
|
||||
Err(msg) => Err(LlmError::ProviderError(msg.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
async fn is_available(&self) -> bool {
|
||||
self.available
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
self.name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue