mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-27 17:36:23 +02:00
Update RAG and Agent clients for streaming message handling GraphRAG now sends multiple message types in a stream: - 'explain' messages with explain_id and explain_graph for provenance - 'chunk' messages with response text fragments - end_of_session marker for stream completion Updated all clients to handle this properly: CLI clients (trustgraph-base/trustgraph/clients/): - graph_rag_client.py: Added chunk_callback and explain_callback - document_rag_client.py: Added chunk_callback and explain_callback - agent_client.py: Added think, observe, answer_callback, error_callback Internal clients (trustgraph-base/trustgraph/base/): - graph_rag_client.py: Async callbacks for streaming - agent_client.py: Async callbacks for streaming All clients now: - Route messages by chunk_type/message_type - Stream via optional callbacks for incremental delivery - Wait for proper completion signals (end_of_dialog/end_of_session/end_of_stream) - Accumulate and return complete response for callers not using callbacks Updated callers: - extract/kg/agent/extract.py: Uses new invoke(question=...) API - tests/integration/test_agent_kg_extraction_integration.py: Updated mocks This fixes the agent infinite loop issue where knowledge_query was returning the first 'explain' message (empty response) instead of waiting for the actual answer chunks. Concurrency in triples query
357 lines
11 KiB
Python
357 lines
11 KiB
Python
import re
|
|
import json
|
|
import urllib.parse
|
|
import logging
|
|
|
|
from ....schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
|
|
from ....schema import EntityContext, EntityContexts
|
|
|
|
from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, SUBJECT_OF, DEFINITION
|
|
|
|
from ....base import FlowProcessor, ConsumerSpec, ProducerSpec
|
|
from ....base import AgentClientSpec
|
|
|
|
from ....template import PromptManager
|
|
|
|
# Module logger
logger = logging.getLogger(__name__)

# Default processor identity used when launching the service.
default_ident = "kg-extract-agent"
# Default number of concurrent chunk handlers.
default_concurrency = 1
# Prompt template used to render chunk text into the agent request.
default_template_id = "agent-kg-extract"
# Configuration dictionary key under which prompt templates live.
default_config_type = "prompt"
|
|
class Processor(FlowProcessor):
    """Agent-based knowledge-graph extraction processor.

    Consumes text chunks, renders each chunk into an agent prompt,
    invokes the agent service, parses the agent's JSONL response, and
    emits RDF triples plus entity contexts for downstream embedding.
    """

    def __init__(self, **params):

        id = params.get("id")
        # Use the module-level default rather than a hard-coded literal
        # so the constructor stays in agreement with add_args().
        concurrency = params.get("concurrency", default_concurrency)
        template_id = params.get("template_id", default_template_id)
        config_key = params.get("config_type", default_config_type)

        super().__init__(**params | {
            "id": id,
            "template_id": template_id,
            "config_type": config_key,
            "concurrency": concurrency,
        })

        self.concurrency = concurrency
        self.template_id = template_id
        self.config_key = config_key

        self.register_config_handler(self.on_prompt_config)

        # Input: raw text chunks to extract knowledge from.
        self.register_specification(
            ConsumerSpec(
                name = "input",
                schema = Chunk,
                handler = self.on_message,
                concurrency = self.concurrency,
            )
        )

        # Request/response client for the agent service.
        self.register_specification(
            AgentClientSpec(
                request_name = "agent-request",
                response_name = "agent-response",
            )
        )

        # Output: extracted triples.
        self.register_specification(
            ProducerSpec(
                name="triples",
                schema=Triples,
            )
        )

        # Output: entity definition contexts for embedding.
        self.register_specification(
            ProducerSpec(
                name="entity-contexts",
                schema=EntityContexts,
            )
        )

        # Null configuration, should reload quickly
        self.manager = PromptManager()

    async def on_prompt_config(self, config, version):
        """Reload prompt templates when the configuration changes.

        Missing config keys and reload failures are logged but do not
        raise, so a bad config push cannot take the processor down.
        """

        logger.info(f"Loading configuration version {version}")

        if self.config_key not in config:
            logger.warning(f"No key {self.config_key} in config")
            return

        config = config[self.config_key]

        try:

            self.manager.load_config(config)

            logger.info("Prompt configuration reloaded")

        except Exception as e:

            logger.error(f"Configuration reload exception: {e}", exc_info=True)
            logger.error("Configuration reload failed")

    def to_uri(self, text):
        """Map an entity/predicate label to a URI in the TrustGraph namespace."""
        return TRUSTGRAPH_ENTITIES + urllib.parse.quote(text)

    async def emit_triples(self, pub, metadata, triples):
        """Publish a Triples message carrying the source document metadata."""
        tpls = Triples(
            metadata = Metadata(
                id = metadata.id,
                root = metadata.root,
                user = metadata.user,
                collection = metadata.collection,
            ),
            triples = triples,
        )

        await pub.send(tpls)

    async def emit_entity_contexts(self, pub, metadata, entity_contexts):
        """Publish an EntityContexts message carrying the source document metadata."""
        ecs = EntityContexts(
            metadata = Metadata(
                id = metadata.id,
                root = metadata.root,
                user = metadata.user,
                collection = metadata.collection,
            ),
            entities = entity_contexts,
        )

        await pub.send(ecs)

    def parse_jsonl(self, text):
        """
        Parse JSONL response, returning list of valid objects.

        Invalid lines (malformed JSON, empty lines) are skipped with warnings.
        This provides truncation resilience - partial output yields partial results.
        """
        results = []

        # Strip markdown code fences if present
        text = text.strip()
        if text.startswith('```'):
            # Remove opening fence (possibly with language hint)
            text = re.sub(r'^```(?:json|jsonl)?\s*\n?', '', text)
        if text.endswith('```'):
            text = text[:-3]

        for line_num, line in enumerate(text.strip().split('\n'), 1):
            line = line.strip()

            # Skip empty lines
            if not line:
                continue

            # Skip any remaining fence markers
            if line.startswith('```'):
                continue

            try:
                obj = json.loads(line)
                results.append(obj)
            except json.JSONDecodeError as e:
                # Log warning but continue - this provides truncation resilience
                logger.warning(f"JSONL parse error on line {line_num}: {e}")

        return results

    async def on_message(self, msg, consumer, flow):
        """Handle one incoming chunk: prompt the agent, parse, and publish.

        Exceptions are logged and re-raised so the flow framework can apply
        its own retry/error policy.
        """

        try:

            v = msg.value()

            # Extract chunk text
            chunk_text = v.chunk.decode('utf-8')

            logger.debug("Processing chunk for agent extraction")

            prompt = self.manager.render(
                self.template_id,
                {
                    "text": chunk_text
                }
            )

            logger.debug(f"Agent prompt: {prompt}")

            # Send to agent API
            agent_response = await flow("agent-request").invoke(
                question = prompt
            )

            # Parse JSONL response
            extraction_data = self.parse_jsonl(agent_response)

            if not extraction_data:
                logger.warning("JSONL parse returned no valid objects")
                return

            # Process extraction data
            triples, entity_contexts = self.process_extraction_data(
                extraction_data, v.metadata
            )

            # Emit outputs
            if triples:
                await self.emit_triples(flow("triples"), v.metadata, triples)

            if entity_contexts:
                await self.emit_entity_contexts(
                    flow("entity-contexts"),
                    v.metadata,
                    entity_contexts
                )

        except Exception as e:
            logger.error(f"Error processing chunk: {e}", exc_info=True)
            raise

    def process_extraction_data(self, data, metadata):
        """Process JSONL extraction data to generate triples and entity contexts.

        Data is a flat list of objects with 'type' discriminator field:
        - {"type": "definition", "entity": "...", "definition": "..."}
        - {"type": "relationship", "subject": "...", "predicate": "...", "object": "...", "object-entity": bool}

        Returns a (triples, entity_contexts) pair. Items missing required
        keys raise KeyError, which is handled by the caller's error path.
        """
        triples = []
        entity_contexts = []

        # Categorize items by type
        definitions = [item for item in data if item.get("type") == "definition"]
        relationships = [item for item in data if item.get("type") == "relationship"]

        # Process definitions
        for defn in definitions:

            entity_uri = self.to_uri(defn["entity"])

            # Add entity label
            triples.append(Triple(
                s = Term(type=IRI, iri=entity_uri),
                p = Term(type=IRI, iri=RDF_LABEL),
                o = Term(type=LITERAL, value=defn["entity"]),
            ))

            # Add definition
            triples.append(Triple(
                s = Term(type=IRI, iri=entity_uri),
                p = Term(type=IRI, iri=DEFINITION),
                o = Term(type=LITERAL, value=defn["definition"]),
            ))

            # Add subject-of relationship to document
            if metadata.id:
                triples.append(Triple(
                    s = Term(type=IRI, iri=entity_uri),
                    p = Term(type=IRI, iri=SUBJECT_OF),
                    o = Term(type=IRI, iri=metadata.id),
                ))

            # Create entity context for embeddings
            entity_contexts.append(EntityContext(
                entity=Term(type=IRI, iri=entity_uri),
                context=defn["definition"]
            ))

        # Process relationships
        for rel in relationships:

            subject_uri = self.to_uri(rel["subject"])
            predicate_uri = self.to_uri(rel["predicate"])

            subject_value = Term(type=IRI, iri=subject_uri)
            predicate_value = Term(type=IRI, iri=predicate_uri)

            # Objects are entities unless explicitly flagged as literals;
            # evaluate the flag once rather than re-reading it per use.
            object_is_entity = rel.get("object-entity", True)

            if object_is_entity:
                object_uri = self.to_uri(rel["object"])
                object_value = Term(type=IRI, iri=object_uri)
            else:
                object_value = Term(type=LITERAL, value=rel["object"])

            # Add subject and predicate labels
            triples.append(Triple(
                s = subject_value,
                p = Term(type=IRI, iri=RDF_LABEL),
                o = Term(type=LITERAL, value=rel["subject"]),
            ))

            triples.append(Triple(
                s = predicate_value,
                p = Term(type=IRI, iri=RDF_LABEL),
                o = Term(type=LITERAL, value=rel["predicate"]),
            ))

            # Entity objects also get a label triple
            if object_is_entity:
                triples.append(Triple(
                    s = object_value,
                    p = Term(type=IRI, iri=RDF_LABEL),
                    o = Term(type=LITERAL, value=rel["object"]),
                ))

            # Add the main relationship triple
            triples.append(Triple(
                s = subject_value,
                p = predicate_value,
                o = object_value
            ))

            # Add subject-of relationships to document
            if metadata.id:
                triples.append(Triple(
                    s = subject_value,
                    p = Term(type=IRI, iri=SUBJECT_OF),
                    o = Term(type=IRI, iri=metadata.id),
                ))

                triples.append(Triple(
                    s = predicate_value,
                    p = Term(type=IRI, iri=SUBJECT_OF),
                    o = Term(type=IRI, iri=metadata.id),
                ))

                if object_is_entity:
                    triples.append(Triple(
                        s = object_value,
                        p = Term(type=IRI, iri=SUBJECT_OF),
                        o = Term(type=IRI, iri=metadata.id),
                    ))

        return triples, entity_contexts

    @staticmethod
    def add_args(parser):
        """Register command-line arguments for this processor."""

        parser.add_argument(
            '-c', '--concurrency',
            type=int,
            default=default_concurrency,
            help=f'Concurrent processing threads (default: {default_concurrency})'
        )

        parser.add_argument(
            "--template-id",
            type=str,
            default=default_template_id,
            help="Template ID to use for agent extraction"
        )

        parser.add_argument(
            '--config-type',
            # Use the module constant so the CLI default cannot drift from
            # the constructor's default.
            default=default_config_type,
            help=f'Configuration key for prompts (default: {default_config_type})',
        )

        FlowProcessor.add_args(parser)
|
|
def run():
    """Entry point: launch the processor under its default identity."""
    Processor.launch(default_ident, __doc__)