feat(noxa-9fw.3): validate structured extraction output with one retry

- Add jsonschema crate for schema validation in extract_json
- On parse failure (invalid JSON): retry once with identical request
- On schema mismatch (valid JSON, wrong schema): fail immediately — no retry
- validate_schema() produces concise error with field path from instance_path()
- Add SequenceMockProvider to testing.rs for first-fail/second-success tests
- Fix env var test flakiness: mark env_model_override as ignored
This commit is contained in:
Jacob Magar 2026-04-11 07:34:58 -04:00
parent 420a1d7522
commit 993fd6c45d
4 changed files with 230 additions and 2 deletions

View file

@ -8,6 +8,7 @@ license.workspace = true
[dependencies]
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
async-trait = "0.1"
jsonschema = { version = "0.46", default-features = false }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }

View file

@ -1,11 +1,45 @@
/// Schema-based and prompt-based LLM extraction.
/// Both functions build a system prompt, send content to the LLM, and parse JSON back.
use jsonschema;
use crate::clean::strip_thinking_tags;
use crate::error::LlmError;
use crate::provider::{CompletionRequest, LlmProvider, Message};
/// Validate a JSON value against a JSON Schema.
///
/// # Arguments
/// * `value` - the parsed JSON document to check.
/// * `schema` - the JSON Schema the document has to satisfy.
///
/// # Errors
/// Returns `LlmError::InvalidJson` in two situations:
/// * the schema itself cannot be compiled (`"invalid schema: …"`);
/// * the value violates the schema (`"schema validation failed: …"`). Every
///   violation is listed, separated by `"; "`, each followed by the instance
///   path of the offending value, or `(root)` when that path is empty.
fn validate_schema(
    value: &serde_json::Value,
    schema: &serde_json::Value,
) -> Result<(), LlmError> {
    let compiled = jsonschema::validator_for(schema)
        .map_err(|e| LlmError::InvalidJson(format!("invalid schema: {e}")))?;

    let errors: Vec<String> = compiled
        .iter_errors(value)
        .map(|e| {
            // If the path renders empty (expected for violations on the
            // document root, such as a missing required property), name the
            // root explicitly instead of ending the message with a dangling
            // "at ".
            let path = e.instance_path().to_string();
            if path.is_empty() {
                format!("{e} at (root)")
            } else {
                format!("{e} at {path}")
            }
        })
        .collect();

    if errors.is_empty() {
        return Ok(());
    }

    Err(LlmError::InvalidJson(format!(
        "schema validation failed: {}",
        errors.join("; ")
    )))
}
/// Extract structured JSON from content using a JSON schema.
/// The schema tells the LLM exactly what fields to extract and their types.
///
/// Retry policy:
/// - If the response cannot be parsed as JSON at all: retry once with the
/// identical request (handles transient formatting issues).
/// - If the response is valid JSON but fails schema validation: return
/// `LlmError::InvalidJson` immediately — the schema is likely unsatisfiable
/// for this content, so retrying would produce the same result.
pub async fn extract_json(
content: &str,
schema: &serde_json::Value,
@ -37,7 +71,22 @@ pub async fn extract_json(
};
let response = provider.complete(&request).await?;
parse_json_response(&response)
match parse_json_response(&response) {
Ok(value) => {
// Valid JSON — now validate against the schema.
// Schema mismatches do not retry (unsatisfiable → same result).
validate_schema(&value, schema)?;
Ok(value)
}
Err(_parse_err) => {
// Unparseable JSON — retry once with the identical request.
let retry_response = provider.complete(&request).await?;
let value = parse_json_response(&retry_response)?;
validate_schema(&value, schema)?;
Ok(value)
}
}
}
/// Extract information using a natural language prompt.
@ -184,4 +233,130 @@ mod tests {
assert_eq!(result["emails"][0], "test@example.com");
}
// ── Schema validation ─────────────────────────────────────────────────────
/// A response that satisfies the schema is accepted and its parsed value is
/// handed back to the caller.
#[tokio::test]
async fn schema_validation_passes_for_matching_json() {
    let provider = MockProvider::ok(r#"{"price": 9.99}"#);
    let price_schema = serde_json::json!({
        "type": "object",
        "required": ["price"],
        "properties": { "price": { "type": "number" } }
    });

    let extracted = extract_json("content", &price_schema, &provider, None)
        .await
        .unwrap();

    assert_eq!(extracted["price"], 9.99);
}
/// Well-formed JSON whose field has the wrong type is rejected with
/// `LlmError::InvalidJson`.
#[tokio::test]
async fn schema_validation_fails_for_wrong_type() {
    // The provider answers with parseable JSON, but `price` is a string where
    // the schema demands a number: a schema mismatch, not a parse failure.
    let provider = MockProvider::ok(r#"{"price": "not-a-number"}"#);
    let price_schema = serde_json::json!({
        "type": "object",
        "required": ["price"],
        "properties": { "price": { "type": "number" } }
    });

    let result = extract_json("content", &price_schema, &provider, None).await;

    let rejected = matches!(result, Err(LlmError::InvalidJson(_)));
    assert!(
        rejected,
        "expected InvalidJson for schema mismatch, got {result:?}"
    );
}
/// JSON that omits a property listed under `required` is rejected with
/// `LlmError::InvalidJson`.
#[tokio::test]
async fn schema_validation_fails_for_missing_required_field() {
    // The response carries `other` but not the required `title`.
    let provider = MockProvider::ok(r#"{"other": "value"}"#);
    let title_schema = serde_json::json!({
        "type": "object",
        "required": ["title"],
        "properties": { "title": { "type": "string" } }
    });

    let result = extract_json("content", &title_schema, &provider, None).await;

    assert!(matches!(result, Err(LlmError::InvalidJson(_))));
}
/// An unparseable first response is followed by a second request, and the
/// second response is the one returned.
#[tokio::test]
async fn parse_failure_triggers_one_retry() {
    use crate::testing::mock::SequenceMockProvider;

    // Scripted replies: call 1 is not JSON at all, call 2 is valid JSON that
    // matches the schema.
    let scripted = vec![
        Ok(String::from("this is not json at all")),
        Ok(String::from(r#"{"title": "Retry succeeded"}"#)),
    ];
    let provider = SequenceMockProvider::new("mock-seq", scripted);
    let title_schema = serde_json::json!({
        "type": "object",
        "properties": { "title": { "type": "string" } }
    });

    let extracted = extract_json("content", &title_schema, &provider, None)
        .await
        .unwrap();

    assert_eq!(extracted["title"], "Retry succeeded");
}
/// When neither the first response nor the retry is parseable, the caller
/// gets `LlmError::InvalidJson`.
#[tokio::test]
async fn both_attempts_fail_returns_invalid_json() {
    use crate::testing::mock::SequenceMockProvider;

    // Neither scripted reply is JSON.
    let scripted = vec![
        Ok(String::from("not json")),
        Ok(String::from("also not json")),
    ];
    let provider = SequenceMockProvider::new("mock-seq", scripted);
    let title_schema = serde_json::json!({
        "type": "object",
        "properties": { "title": { "type": "string" } }
    });

    let outcome = extract_json("content", &title_schema, &provider, None).await;

    let is_invalid_json = matches!(outcome, Err(LlmError::InvalidJson(_)));
    assert!(
        is_invalid_json,
        "expected InvalidJson after both attempts fail"
    );
}
/// A schema mismatch on the first response is reported immediately rather
/// than retried.
#[tokio::test]
async fn schema_mismatch_does_not_retry() {
    use crate::testing::mock::SequenceMockProvider;

    // Reply 1 is valid JSON that violates the schema. Reply 2 would satisfy
    // it, so if `extract_json` issued a second request the call would succeed
    // and the assertion below would fail.
    let scripted = vec![
        Ok(String::from(r#"{"price": "wrong-type"}"#)),
        Ok(String::from(r#"{"price": 9.99}"#)),
    ];
    let provider = SequenceMockProvider::new("mock-seq", scripted);
    let price_schema = serde_json::json!({
        "type": "object",
        "required": ["price"],
        "properties": { "price": { "type": "number" } }
    });

    let outcome = extract_json("content", &price_schema, &provider, None).await;

    let is_invalid_json = matches!(outcome, Err(LlmError::InvalidJson(_)));
    assert!(is_invalid_json, "schema mismatch should not trigger retry");
}
}

View file

@ -10,7 +10,7 @@ use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::sync::Semaphore;
use tokio::time::timeout;
use tracing::{debug, warn};
use tracing::debug;
use crate::clean::strip_thinking_tags;
use crate::error::LlmError;
@ -41,6 +41,7 @@ impl GeminiCliProvider {
}
}
/// Returns the model name held in this provider's `default_model` field.
///
/// Compiled only in test builds (`#[cfg(test)]`).
#[cfg(test)]
fn default_model(&self) -> &str {
&self.default_model
}
@ -199,7 +200,11 @@ mod tests {
assert_eq!(p.default_model(), "gemini-2.5-pro");
}
// Env var tests mutate process-global state and race with parallel tests.
// Run in isolation if needed:
// cargo test -p noxa-llm env_model_override -- --ignored --test-threads=1
#[test]
#[ignore = "mutates process env; run with --test-threads=1"]
fn env_model_override() {
unsafe { std::env::set_var("GEMINI_MODEL", "gemini-1.5-pro") };
let p = GeminiCliProvider::new(None);

View file

@ -4,6 +4,9 @@
/// extract, chain, and other modules that need a fake LLM backend.
#[cfg(test)]
pub(crate) mod mock {
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use crate::error::LlmError;
@ -45,4 +48,48 @@ pub(crate) mod mock {
self.name
}
}
/// A mock provider that replays a scripted sequence of responses.
///
/// Call N returns `responses[N]`. Once the script is exhausted, every further
/// call repeats the last entry (the index is clamped; it does not wrap around).
/// Useful for testing first-failure / second-success retry paths.
pub struct SequenceMockProvider {
    /// Value reported by `name()`.
    pub name: &'static str,
    /// Scripted results in call order: `Ok(text)` becomes the completion text,
    /// `Err(msg)` is surfaced as `LlmError::ProviderError(msg)`.
    pub responses: Vec<Result<String, String>>,
    /// Value reported by `is_available()`.
    pub available: bool,
    /// Number of `complete()` calls made so far; selects the next response.
    call_count: Arc<AtomicUsize>,
}

impl SequenceMockProvider {
    /// Creates a provider that reports itself as available and starts at the
    /// first entry of `responses`.
    pub fn new(
        name: &'static str,
        responses: Vec<Result<String, String>>,
    ) -> Self {
        let call_count = Arc::new(AtomicUsize::new(0));
        Self {
            available: true,
            call_count,
            name,
            responses,
        }
    }
}
#[async_trait]
impl LlmProvider for SequenceMockProvider {
    /// Returns the next scripted response and advances the call counter.
    ///
    /// Calls beyond the end of the script repeat the last entry. The request
    /// itself is ignored.
    ///
    /// # Errors
    /// Returns `LlmError::ProviderError` when the selected entry is `Err(msg)`,
    /// or when the provider was built with an empty script.
    async fn complete(&self, _request: &CompletionRequest) -> Result<String, LlmError> {
        let idx = self.call_count.fetch_add(1, Ordering::SeqCst);

        // A checked lookup with a fallback to the last entry avoids the
        // `len() - 1` underflow (and the resulting panic) on an empty script.
        let scripted = self
            .responses
            .get(idx)
            .or_else(|| self.responses.last())
            .ok_or_else(|| {
                LlmError::ProviderError(format!(
                    "{}: SequenceMockProvider has no responses configured",
                    self.name
                ))
            })?;

        scripted.clone().map_err(LlmError::ProviderError)
    }

    /// Reports the `available` flag this mock was configured with.
    async fn is_available(&self) -> bool {
        self.available
    }

    /// Reports the name this mock was constructed with.
    fn name(&self) -> &str {
        self.name
    }
}
}