mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
Update to enable knowledge extraction using the agent framework (#439)
* Implement KG extraction agent (kg-extract-agent) * Using ReAct framework (agent-manager-react) * ReAct manager had an issue when emitting JSON, which conflicts which ReAct manager's own JSON messages, so refactored ReAct manager to use traditional ReAct messages, non-JSON structure. * Minor refactor to take the prompt template client out of prompt-template so it can be more readily used by other modules. kg-extract-agent uses this framework.
This commit is contained in:
parent
1fe4ed5226
commit
d83e4e3d59
30 changed files with 3192 additions and 799 deletions
345
tests/unit/test_prompt_manager.py
Normal file
345
tests/unit/test_prompt_manager.py
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
"""
|
||||
Unit tests for PromptManager
|
||||
|
||||
These tests verify the functionality of the PromptManager class,
|
||||
including template rendering, term merging, JSON validation, and error handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from trustgraph.template.prompt_manager import PromptManager, PromptConfiguration, Prompt
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestPromptManager:
|
||||
"""Unit tests for PromptManager template functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config(self):
|
||||
"""Sample configuration dict for PromptManager"""
|
||||
return {
|
||||
"system": json.dumps("You are a helpful assistant."),
|
||||
"template-index": json.dumps(["simple_text", "json_response", "complex_template"]),
|
||||
"template.simple_text": json.dumps({
|
||||
"prompt": "Hello {{ name }}, welcome to {{ system_name }}!",
|
||||
"response-type": "text"
|
||||
}),
|
||||
"template.json_response": json.dumps({
|
||||
"prompt": "Generate a user profile for {{ username }}",
|
||||
"response-type": "json",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "number"}
|
||||
},
|
||||
"required": ["name", "age"]
|
||||
}
|
||||
}),
|
||||
"template.complex_template": json.dumps({
|
||||
"prompt": """
|
||||
{% for item in items %}
|
||||
- {{ item.name }}: {{ item.value }}
|
||||
{% endfor %}
|
||||
""",
|
||||
"response-type": "text"
|
||||
})
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_manager(self, sample_config):
|
||||
"""Create a PromptManager with sample configuration"""
|
||||
pm = PromptManager()
|
||||
pm.load_config(sample_config)
|
||||
# Add global terms manually since load_config doesn't handle them
|
||||
pm.terms["system_name"] = "TrustGraph"
|
||||
pm.terms["version"] = "1.0"
|
||||
return pm
|
||||
|
||||
def test_prompt_manager_initialization(self, prompt_manager, sample_config):
|
||||
"""Test PromptManager initialization with configuration"""
|
||||
assert prompt_manager.config.system_template == "You are a helpful assistant."
|
||||
assert len(prompt_manager.prompts) == 3
|
||||
assert "simple_text" in prompt_manager.prompts
|
||||
|
||||
def test_simple_text_template_rendering(self, prompt_manager):
|
||||
"""Test basic template rendering with text response"""
|
||||
terms = {"name": "Alice"}
|
||||
|
||||
rendered = prompt_manager.render("simple_text", terms)
|
||||
|
||||
assert rendered == "Hello Alice, welcome to TrustGraph!"
|
||||
|
||||
def test_global_terms_merging(self, prompt_manager):
|
||||
"""Test that global terms are properly merged"""
|
||||
terms = {"name": "Bob"}
|
||||
|
||||
# Global terms should be available in template
|
||||
rendered = prompt_manager.render("simple_text", terms)
|
||||
|
||||
assert "TrustGraph" in rendered # From global terms
|
||||
assert "Bob" in rendered # From input terms
|
||||
|
||||
def test_term_override_priority(self):
|
||||
"""Test term override priority: input > prompt > global"""
|
||||
# Create a fresh PromptManager for this test
|
||||
pm = PromptManager()
|
||||
config = {
|
||||
"system": json.dumps("Test"),
|
||||
"template-index": json.dumps(["test"]),
|
||||
"template.test": json.dumps({
|
||||
"prompt": "Value is: {{ value }}",
|
||||
"response-type": "text"
|
||||
})
|
||||
}
|
||||
pm.load_config(config)
|
||||
|
||||
# Set up terms at different levels
|
||||
pm.terms["value"] = "global" # Global term
|
||||
if "test" in pm.prompts:
|
||||
pm.prompts["test"].terms = {"value": "prompt"} # Prompt term
|
||||
|
||||
# Test with no input override - prompt terms should win
|
||||
rendered = pm.render("test", {})
|
||||
if "test" in pm.prompts and pm.prompts["test"].terms:
|
||||
assert rendered == "Value is: prompt" # Prompt terms override global
|
||||
else:
|
||||
assert rendered == "Value is: global" # No prompt terms, use global
|
||||
|
||||
# Test with input override - input terms should win
|
||||
rendered = pm.render("test", {"value": "input"})
|
||||
assert rendered == "Value is: input" # Input terms override all
|
||||
|
||||
def test_complex_template_rendering(self, prompt_manager):
|
||||
"""Test complex template with loops and filters"""
|
||||
terms = {
|
||||
"items": [
|
||||
{"name": "Item1", "value": 10},
|
||||
{"name": "Item2", "value": 20},
|
||||
{"name": "Item3", "value": 30}
|
||||
]
|
||||
}
|
||||
|
||||
rendered = prompt_manager.render("complex_template", terms)
|
||||
|
||||
assert "Item1: 10" in rendered
|
||||
assert "Item2: 20" in rendered
|
||||
assert "Item3: 30" in rendered
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_text_response(self, prompt_manager):
|
||||
"""Test invoking a prompt with text response"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = "Welcome Alice to TrustGraph!"
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"simple_text",
|
||||
{"name": "Alice"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert result == "Welcome Alice to TrustGraph!"
|
||||
|
||||
# Verify LLM was called with correct prompts
|
||||
mock_llm.assert_called_once()
|
||||
call_args = mock_llm.call_args[1]
|
||||
assert call_args["system"] == "You are a helpful assistant."
|
||||
assert "Hello Alice, welcome to TrustGraph!" in call_args["prompt"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_json_response_valid(self, prompt_manager):
|
||||
"""Test invoking a prompt with valid JSON response"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = '{"name": "John Doe", "age": 30}'
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"json_response",
|
||||
{"username": "johndoe"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["name"] == "John Doe"
|
||||
assert result["age"] == 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_json_response_with_markdown(self, prompt_manager):
|
||||
"""Test JSON extraction from markdown code blocks"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = """
|
||||
Here is the user profile:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "Jane Smith",
|
||||
"age": 25
|
||||
}
|
||||
```
|
||||
|
||||
This is a valid profile.
|
||||
"""
|
||||
|
||||
result = await prompt_manager.invoke(
|
||||
"json_response",
|
||||
{"username": "janesmith"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["name"] == "Jane Smith"
|
||||
assert result["age"] == 25
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_json_validation_failure(self, prompt_manager):
|
||||
"""Test JSON schema validation failure"""
|
||||
mock_llm = AsyncMock()
|
||||
# Missing required 'age' field
|
||||
mock_llm.return_value = '{"name": "Invalid User"}'
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await prompt_manager.invoke(
|
||||
"json_response",
|
||||
{"username": "invalid"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert "Schema validation fail" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_json_parse_failure(self, prompt_manager):
|
||||
"""Test invalid JSON parsing"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = "This is not JSON at all"
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await prompt_manager.invoke(
|
||||
"json_response",
|
||||
{"username": "test"},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
assert "JSON parse fail" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_unknown_prompt(self, prompt_manager):
|
||||
"""Test invoking an unknown prompt ID"""
|
||||
mock_llm = AsyncMock()
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
await prompt_manager.invoke(
|
||||
"nonexistent_prompt",
|
||||
{},
|
||||
mock_llm
|
||||
)
|
||||
|
||||
def test_template_rendering_with_undefined_variable(self, prompt_manager):
|
||||
"""Test template rendering with undefined variables"""
|
||||
terms = {} # Missing 'name' variable
|
||||
|
||||
# ibis might handle undefined variables differently than Jinja2
|
||||
# Let's test what actually happens
|
||||
try:
|
||||
result = prompt_manager.render("simple_text", terms)
|
||||
# If no exception, check that undefined variables are handled somehow
|
||||
assert isinstance(result, str)
|
||||
except Exception as e:
|
||||
# If exception is raised, that's also acceptable behavior
|
||||
assert "name" in str(e) or "undefined" in str(e).lower() or "variable" in str(e).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_response_without_schema(self):
|
||||
"""Test JSON response without schema validation"""
|
||||
pm = PromptManager()
|
||||
config = {
|
||||
"system": json.dumps("Test"),
|
||||
"template-index": json.dumps(["no_schema"]),
|
||||
"template.no_schema": json.dumps({
|
||||
"prompt": "Generate any JSON",
|
||||
"response-type": "json"
|
||||
# No schema defined
|
||||
})
|
||||
}
|
||||
pm.load_config(config)
|
||||
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.return_value = '{"any": "json", "is": "valid"}'
|
||||
|
||||
result = await pm.invoke("no_schema", {}, mock_llm)
|
||||
|
||||
assert result == {"any": "json", "is": "valid"}
|
||||
|
||||
def test_prompt_configuration_validation(self):
|
||||
"""Test PromptConfiguration validation"""
|
||||
# Valid configuration
|
||||
config = PromptConfiguration(
|
||||
system_template="Test system",
|
||||
prompts={
|
||||
"test": Prompt(
|
||||
template="Hello {{ name }}",
|
||||
response_type="text"
|
||||
)
|
||||
}
|
||||
)
|
||||
assert config.system_template == "Test system"
|
||||
assert len(config.prompts) == 1
|
||||
|
||||
def test_nested_template_includes(self):
|
||||
"""Test templates with nested variable references"""
|
||||
# Create a fresh PromptManager for this test
|
||||
pm = PromptManager()
|
||||
config = {
|
||||
"system": json.dumps("Test"),
|
||||
"template-index": json.dumps(["nested"]),
|
||||
"template.nested": json.dumps({
|
||||
"prompt": "{{ greeting }} from {{ company }} in {{ year }}!",
|
||||
"response-type": "text"
|
||||
})
|
||||
}
|
||||
pm.load_config(config)
|
||||
|
||||
# Set up global and prompt terms
|
||||
pm.terms["company"] = "TrustGraph"
|
||||
pm.terms["year"] = "2024"
|
||||
if "nested" in pm.prompts:
|
||||
pm.prompts["nested"].terms = {"greeting": "Welcome"}
|
||||
|
||||
rendered = pm.render("nested", {"user": "Alice", "greeting": "Welcome"})
|
||||
|
||||
# Should contain company and year from global terms
|
||||
assert "TrustGraph" in rendered
|
||||
assert "2024" in rendered
|
||||
assert "Welcome" in rendered
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_invocations(self, prompt_manager):
|
||||
"""Test concurrent prompt invocations"""
|
||||
mock_llm = AsyncMock()
|
||||
mock_llm.side_effect = [
|
||||
"Response for Alice",
|
||||
"Response for Bob",
|
||||
"Response for Charlie"
|
||||
]
|
||||
|
||||
# Simulate concurrent invocations
|
||||
import asyncio
|
||||
results = await asyncio.gather(
|
||||
prompt_manager.invoke("simple_text", {"name": "Alice"}, mock_llm),
|
||||
prompt_manager.invoke("simple_text", {"name": "Bob"}, mock_llm),
|
||||
prompt_manager.invoke("simple_text", {"name": "Charlie"}, mock_llm)
|
||||
)
|
||||
|
||||
assert len(results) == 3
|
||||
assert "Alice" in results[0]
|
||||
assert "Bob" in results[1]
|
||||
assert "Charlie" in results[2]
|
||||
|
||||
def test_empty_configuration(self):
|
||||
"""Test PromptManager with minimal configuration"""
|
||||
pm = PromptManager()
|
||||
pm.load_config({}) # Empty config
|
||||
|
||||
assert pm.config.system_template == "Be helpful." # Default system
|
||||
assert pm.terms == {} # Default empty terms
|
||||
assert len(pm.prompts) == 0
|
||||
Loading…
Add table
Add a link
Reference in a new issue