Merge branch 'release/v2.1'

This commit is contained in:
Cyber MacGeddon 2026-03-17 20:44:03 +00:00
commit 824f993985
266 changed files with 33195 additions and 5834 deletions

View file

@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"trustgraph-base>=2.0,<2.1",
"trustgraph-base>=2.1,<2.2",
"aiohttp",
"anthropic",
"scylla-driver",
@ -27,7 +27,7 @@ dependencies = [
"langchain-text-splitters",
"mcp",
"minio",
"mistralai",
"mistralai<2.0.0",
"neo4j",
"nltk",
"ollama",
@ -122,6 +122,7 @@ triples-write-falkordb = "trustgraph.storage.triples.falkordb:run"
triples-write-memgraph = "trustgraph.storage.triples.memgraph:run"
triples-write-neo4j = "trustgraph.storage.triples.neo4j:run"
wikipedia-lookup = "trustgraph.external.wikipedia:run"
joke-service = "trustgraph.tool_service.joke:run"
[tool.setuptools.packages.find]
include = ["trustgraph*"]

View file

@ -2,11 +2,15 @@
Simple agent infrastructure broadly implements the ReAct flow.
"""
import asyncio
import base64
import json
import re
import sys
import functools
import logging
import uuid
from datetime import datetime
# Module logger
logger = logging.getLogger(__name__)
@ -14,10 +18,30 @@ logger = logging.getLogger(__name__)
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
from ... base import ProducerSpec
from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl
# Provenance imports for agent explainability
from trustgraph.provenance import (
agent_session_uri,
agent_iteration_uri,
agent_thought_uri,
agent_observation_uri,
agent_final_uri,
agent_session_triples,
agent_iteration_triples,
agent_final_triples,
set_graph,
GRAPH_RETRIEVAL,
)
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl, ToolServiceImpl
from . agent_manager import AgentManager
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
@ -25,6 +49,8 @@ from . types import Final, Action, Tool, Argument
default_ident = "agent-manager"
default_max_iterations = 10
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(AgentService):
@ -51,6 +77,9 @@ class Processor(AgentService):
additional_context="",
)
# Track active tool service clients for cleanup
self.tool_service_clients = {}
self.config_handlers.append(self.on_tools_config)
self.register_specification(
@ -102,6 +131,123 @@ class Processor(AgentService):
)
)
# Explainability producer for agent provenance triples
self.register_specification(
ProducerSpec(
name = "explainability",
schema = Triples,
)
)
# Librarian client for storing answer content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_librarian_requests = {}
async def start(self):
# Bring up the base service first, then the librarian request producer
# and the librarian response consumer held on this processor.
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
    """
    Handle responses from the librarian service.

    The response is matched to its waiting request through the "id"
    message property and delivered by resolving the future stored in
    ``self.pending_librarian_requests``.

    Args:
        msg: Incoming message; ``value()`` is the response and
            ``properties()`` carries the request id
        consumer: Unused
        flow: Unused
    """
    response = msg.value()
    request_id = msg.properties().get("id")

    # A missing id can never match a pending request
    future = None
    if request_id:
        future = self.pending_librarian_requests.pop(request_id, None)

    if future is None:
        logger.warning(f"Received unexpected librarian response: {request_id}")
        return

    # asyncio.wait_for() cancels the future on timeout; calling set_result()
    # on a finished future raises InvalidStateError, so drop late responses.
    if future.done():
        logger.warning(f"Discarding late librarian response: {request_id}")
        return

    future.set_result(response)
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
    """
    Save answer content to the librarian.

    Sends an "add-document" request and waits for the response that
    carries the same "id" message property.

    Args:
        doc_id: ID for the answer document
        user: User ID
        content: Answer text content (str); sent base64-encoded as UTF-8
        title: Optional title, defaults to "Agent Answer"
        timeout: Request timeout in seconds

    Returns:
        The document ID on success

    Raises:
        RuntimeError: If the librarian reports an error or no response
            arrives within ``timeout`` seconds
    """
    request_id = str(uuid.uuid4())

    doc_metadata = DocumentMetadata(
        id=doc_id,
        user=user,
        kind="text/plain",
        title=title or "Agent Answer",
        document_type="answer",
    )

    request = LibrarianRequest(
        operation="add-document",
        document_id=doc_id,
        document_metadata=doc_metadata,
        content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
        user=user,
    )

    # Register the future before sending so the response handler can
    # always find it, however quickly the librarian answers.
    future = asyncio.get_running_loop().create_future()
    self.pending_librarian_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )
        response = await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout saving answer document {doc_id}")
    finally:
        # Remove the entry on every exit path (send failure, cancellation,
        # timeout) so abandoned futures do not accumulate.
        self.pending_librarian_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error saving answer: {response.error.type}: {response.error.message}"
        )

    return doc_id
async def on_tools_config(self, config, version):
logger.info(f"Loading configuration version {version}")
@ -110,6 +256,16 @@ class Processor(AgentService):
tools = {}
# Load tool-service configurations first
tool_services = {}
if "tool-service" in config:
for service_id, service_value in config["tool-service"].items():
service_data = json.loads(service_value)
tool_services[service_id] = service_data
logger.debug(f"Loaded tool-service config: {service_id}")
logger.info(f"Loaded {len(tool_services)} tool-service configurations")
# Load tool configurations from the new location
if "tool" in config:
for tool_id, tool_value in config["tool"].items():
@ -177,6 +333,59 @@ class Processor(AgentService):
limit=int(data.get("limit", 10)) # Max results
)
arguments = RowEmbeddingsQueryImpl.get_arguments()
elif impl_id == "tool-service":
# Dynamic tool service - look up the service config
service_ref = data.get("service")
if not service_ref:
raise RuntimeError(
f"Tool {name} has type 'tool-service' but no 'service' reference"
)
if service_ref not in tool_services:
raise RuntimeError(
f"Tool {name} references unknown tool-service '{service_ref}'"
)
service_config = tool_services[service_ref]
request_queue = service_config.get("request-queue")
response_queue = service_config.get("response-queue")
if not request_queue or not response_queue:
raise RuntimeError(
f"Tool-service '{service_ref}' must define 'request-queue' and 'response-queue'"
)
# Build config values from tool config
# Extract any config params defined by the service
config_params = service_config.get("config-params", [])
config_values = {}
for param in config_params:
param_name = param.get("name") if isinstance(param, dict) else param
if param_name in data:
config_values[param_name] = data[param_name]
elif isinstance(param, dict) and param.get("required", False):
raise RuntimeError(
f"Tool {name} missing required config param '{param_name}'"
)
# Arguments come from tool config
config_args = data.get("arguments", [])
arguments = [
Argument(
name=arg.get("name"),
type=arg.get("type"),
description=arg.get("description")
)
for arg in config_args
]
# Store queues for the implementation
impl = functools.partial(
ToolServiceImpl,
request_queue=request_queue,
response_queue=response_queue,
config_values=config_values,
arguments=arguments,
processor=self,
)
else:
raise RuntimeError(
f"Tool type {impl_id} not known"
@ -219,6 +428,10 @@ class Processor(AgentService):
# Check if streaming is enabled
streaming = getattr(request, 'streaming', False)
# Generate or retrieve session ID for provenance tracking
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
collection = getattr(request, 'collection', 'default')
if request.history:
history = [
Action(
@ -232,6 +445,36 @@ class Processor(AgentService):
else:
history = []
# Calculate iteration number (1-based)
iteration_num = len(history) + 1
session_uri = agent_session_uri(session_id)
# On first iteration, emit session triples
if iteration_num == 1:
timestamp = datetime.utcnow().isoformat() + "Z"
triples = set_graph(
agent_session_triples(session_uri, request.question, timestamp),
GRAPH_RETRIEVAL
)
await flow("explainability").send(Triples(
metadata=Metadata(
id=session_uri,
user=request.user,
collection=collection,
),
triples=triples,
))
logger.debug(f"Emitted session triples for {session_uri}")
# Send explain event for session
if streaming:
await respond(AgentResponse(
chunk_type="explain",
content="",
explain_id=session_uri,
explain_graph=GRAPH_RETRIEVAL,
))
logger.info(f"Question: {request.question}")
if len(history) >= self.max_iterations:
@ -381,6 +624,60 @@ class Processor(AgentService):
else:
f = json.dumps(act.final)
# Emit final answer provenance triples
final_uri = agent_final_uri(session_id)
# No iterations: link to question; otherwise: link to last iteration
if iteration_num > 1:
final_question_uri = None
final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
else:
final_question_uri = session_uri
final_previous_uri = None
# Save answer to librarian
answer_doc_id = None
if f:
answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
try:
await self.save_answer_content(
doc_id=answer_doc_id,
user=request.user,
content=f,
title=f"Agent Answer: {request.question[:50]}...",
)
logger.debug(f"Saved answer to librarian: {answer_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
answer_doc_id = None # Fall back to inline content
final_triples = set_graph(
agent_final_triples(
final_uri,
question_uri=final_question_uri,
previous_uri=final_previous_uri,
document_id=answer_doc_id,
),
GRAPH_RETRIEVAL
)
await flow("explainability").send(Triples(
metadata=Metadata(
id=final_uri,
user=request.user,
collection=collection,
),
triples=final_triples,
))
logger.debug(f"Emitted final triples for {final_uri}")
# Send explain event for conclusion
if streaming:
await respond(AgentResponse(
chunk_type="explain",
content="",
explain_id=final_uri,
explain_graph=GRAPH_RETRIEVAL,
))
if streaming:
# Streaming format - send end-of-dialog marker
# Answer chunks were already sent via answer() callback during parsing
@ -413,8 +710,86 @@ class Processor(AgentService):
logger.debug("Send next...")
# Emit iteration provenance triples
iteration_uri = agent_iteration_uri(session_id, iteration_num)
# First iteration links to question, subsequent to previous
if iteration_num > 1:
iter_question_uri = None
iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
else:
iter_question_uri = session_uri
iter_previous_uri = None
# Save thought to librarian
thought_doc_id = None
if act.thought:
thought_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
try:
await self.save_answer_content(
doc_id=thought_doc_id,
user=request.user,
content=act.thought,
title=f"Agent Thought: {act.name}",
)
logger.debug(f"Saved thought to librarian: {thought_doc_id}")
except Exception as e:
logger.warning(f"Failed to save thought to librarian: {e}")
thought_doc_id = None
# Save observation to librarian
observation_doc_id = None
if act.observation:
observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
try:
await self.save_answer_content(
doc_id=observation_doc_id,
user=request.user,
content=act.observation,
title=f"Agent Observation: {act.name}",
)
logger.debug(f"Saved observation to librarian: {observation_doc_id}")
except Exception as e:
logger.warning(f"Failed to save observation to librarian: {e}")
observation_doc_id = None
thought_entity_uri = agent_thought_uri(session_id, iteration_num)
observation_entity_uri = agent_observation_uri(session_id, iteration_num)
iter_triples = set_graph(
agent_iteration_triples(
iteration_uri,
question_uri=iter_question_uri,
previous_uri=iter_previous_uri,
action=act.name,
arguments=act.arguments,
thought_uri=thought_entity_uri if thought_doc_id else None,
thought_document_id=thought_doc_id,
observation_uri=observation_entity_uri if observation_doc_id else None,
observation_document_id=observation_doc_id,
),
GRAPH_RETRIEVAL
)
await flow("explainability").send(Triples(
metadata=Metadata(
id=iteration_uri,
user=request.user,
collection=collection,
),
triples=iter_triples,
))
logger.debug(f"Emitted iteration triples for {iteration_uri}")
# Send explain event for iteration
if streaming:
await respond(AgentResponse(
chunk_type="explain",
content="",
explain_id=iteration_uri,
explain_graph=GRAPH_RETRIEVAL,
))
history.append(act)
# Handle state transitions if tool execution was successful
next_state = request.state
if act.name in filtered_tools:
@ -435,7 +810,9 @@ class Processor(AgentService):
for h in history
],
user=request.user,
collection=collection,
streaming=streaming,
session_id=session_id, # Pass session_id for provenance continuity
)
await next(r)

View file

@ -154,7 +154,8 @@ class RowEmbeddingsQueryImpl:
logger.debug("Getting embeddings for row query...")
query_text = arguments.get("query")
vectors = await embeddings_client.embed(query_text)
all_vectors = await embeddings_client.embed([query_text])
vector = all_vectors[0] if all_vectors else []
# Now query row embeddings
client = self.context("row-embeddings-query-request")
@ -164,7 +165,7 @@ class RowEmbeddingsQueryImpl:
user = getattr(client, '_current_user', self.user or "trustgraph")
matches = await client.row_embeddings_query(
vectors=vectors,
vector=vector,
schema_name=self.schema_name,
user=user,
collection=self.collection or "default",
@ -202,3 +203,116 @@ class PromptImpl:
id=self.template_id,
variables=arguments
)
# This tool implementation invokes a dynamically configured tool service
class ToolServiceImpl:
    """
    Implementation for dynamically pluggable tool services.

    Tool services are external Pulsar services that can be invoked as agent
    tools. The service is configured via a tool-service descriptor that
    defines the queues, and a tool descriptor that provides config values
    and argument definitions.

    Clients are shared per (request queue, response queue) pair through the
    processor's ``tool_service_clients`` registry.
    """

    def __init__(self, context, request_queue, response_queue, config_values=None, arguments=None, processor=None):
        """
        Initialize a tool service implementation.

        Args:
            context: The context function (provides user info)
            request_queue: Full Pulsar topic for requests
            response_queue: Full Pulsar topic for responses
            config_values: Dict of config values (e.g., {"collection": "customers"})
            arguments: List of Argument objects defining the tool's parameters
            processor: The Processor instance (for pubsub access)
        """
        self.context = context
        self.request_queue = request_queue
        self.response_queue = response_queue
        self.config_values = config_values or {}
        self.arguments = arguments or []
        self.processor = processor
        # Set on first successful lookup/creation in _get_or_create_client()
        self._client = None

    def get_arguments(self):
        """Return the Argument definitions supplied at construction."""
        return self.arguments

    async def _get_or_create_client(self):
        """
        Get or create the tool service client.

        Returns the client cached on this instance, else the one registered
        on the processor for the same queue pair, else creates, starts and
        registers a new one.
        """
        if self._client is not None:
            return self._client

        # Check if processor already has a client for this queue pair
        client_key = f"{self.request_queue}|{self.response_queue}"
        if client_key in self.processor.tool_service_clients:
            self._client = self.processor.tool_service_clients[client_key]
            return self._client

        # Import here to avoid circular imports
        from trustgraph.base.tool_service_client import ToolServiceClient
        from trustgraph.base.metrics import ProducerMetrics, SubscriberMetrics
        from trustgraph.schema import ToolServiceRequest, ToolServiceResponse
        import uuid

        request_metrics = ProducerMetrics(
            processor=self.processor.id,
            flow="tool-service",
            name=self.request_queue
        )
        response_metrics = SubscriberMetrics(
            processor=self.processor.id,
            flow="tool-service",
            name=self.response_queue
        )

        # Create unique subscription for responses
        subscription = f"{self.processor.id}--tool-service--{uuid.uuid4()}"

        client = ToolServiceClient(
            backend=self.processor.pubsub,
            subscription=subscription,
            consumer_name=self.processor.id,
            request_topic=self.request_queue,
            request_schema=ToolServiceRequest,
            request_metrics=request_metrics,
            response_topic=self.response_queue,
            response_schema=ToolServiceResponse,
            response_metrics=response_metrics,
        )

        # Cache and register the client only once start() has succeeded;
        # otherwise a failed start would leave an unstarted client that
        # every later call returns straight from the cache.
        await client.start()

        self._client = client

        # Register for cleanup
        self.processor.tool_service_clients[client_key] = client

        logger.debug(f"Created tool service client for {self.request_queue}")
        return self._client

    async def invoke(self, **arguments):
        """
        Call the tool service with the given arguments.

        Returns:
            The response unchanged when it is a string, otherwise the
            response serialised with ``json.dumps``.
        """
        logger.debug(f"Tool service invocation: {self.request_queue}...")
        logger.debug(f"Config: {self.config_values}")
        logger.debug(f"Arguments: {arguments}")

        # Get user from context if available
        user = getattr(self.context, '_user', "trustgraph")

        # Get or create the client
        client = await self._get_or_create_client()

        # Call the tool service
        response = await client.call(
            user=user,
            config=self.config_values,
            arguments=arguments,
        )

        logger.debug(f"Tool service response: {response}")

        if isinstance(response, str):
            return response
        return json.dumps(response)

View file

@ -8,14 +8,24 @@ import logging
from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk
from ... schema import TextDocument, Chunk, Metadata, Triples
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
from ... provenance import (
derived_entity_triples,
set_graph, GRAPH_SOURCE,
)
# Component identification for provenance
COMPONENT_NAME = "chunker"
COMPONENT_VERSION = "1.0.0"
# Module logger
logger = logging.getLogger(__name__)
default_ident = "chunker"
class Processor(ChunkingService):
def __init__(self, **params):
@ -23,7 +33,7 @@ class Processor(ChunkingService):
id = params.get("id", default_ident)
chunk_size = params.get("chunk_size", 2000)
chunk_overlap = params.get("chunk_overlap", 100)
super(Processor, self).__init__(
**params | { "id": id }
)
@ -62,6 +72,13 @@ class Processor(ChunkingService):
)
)
self.register_specification(
ProducerSpec(
name = "triples",
schema = Triples,
)
)
logger.info("Recursive chunker initialized")
async def on_message(self, msg, consumer, flow):
@ -69,6 +86,9 @@ class Processor(ChunkingService):
v = msg.value()
logger.info(f"Chunking document {v.metadata.id}...")
# Get text content (fetches from librarian if needed)
text = await self.get_document_text(v)
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
msg, consumer, flow,
@ -90,25 +110,84 @@ class Processor(ChunkingService):
is_separator_regex=False,
)
texts = text_splitter.create_documents(
[v.text.decode("utf-8")]
)
texts = text_splitter.create_documents([text])
# Get parent document ID for provenance linking
# This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it
parent_doc_id = v.document_id or v.metadata.id
# Track character offset for provenance
char_offset = 0
for ix, chunk in enumerate(texts):
chunk_index = ix + 1 # 1-indexed
logger.debug(f"Created chunk of size {len(chunk.page_content)}")
# Generate chunk document ID by appending /c{index} to parent
# Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
chunk_uri = chunk_doc_id # URI is same as document ID
parent_uri = parent_doc_id
chunk_content = chunk.page_content.encode("utf-8")
chunk_length = len(chunk.page_content)
# Save chunk to librarian as child document
await self.save_child_document(
doc_id=chunk_doc_id,
parent_id=parent_doc_id,
user=v.metadata.user,
content=chunk_content,
document_type="chunk",
title=f"Chunk {chunk_index}",
)
# Emit provenance triples (stored in source graph for separation from core knowledge)
prov_triples = derived_entity_triples(
entity_uri=chunk_uri,
parent_uri=parent_uri,
component_name=COMPONENT_NAME,
component_version=COMPONENT_VERSION,
label=f"Chunk {chunk_index}",
chunk_index=chunk_index,
char_offset=char_offset,
char_length=chunk_length,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
await flow("triples").send(Triples(
metadata=Metadata(
id=chunk_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
))
# Forward chunk ID + content (post-chunker optimization)
r = Chunk(
metadata=v.metadata,
chunk=chunk.page_content.encode("utf-8"),
metadata=Metadata(
id=chunk_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
chunk=chunk_content,
document_id=chunk_doc_id,
)
__class__.chunk_metric.labels(
id=consumer.id, flow=consumer.flow
).observe(len(chunk.page_content))
).observe(chunk_length)
await flow("output").send(r)
# Update character offset (approximate: assumes each chunk overlaps the previous one by exactly chunk_overlap characters)
char_offset += chunk_length - chunk_overlap
logger.debug("Document chunking complete")
@staticmethod
@ -133,4 +212,3 @@ class Processor(ChunkingService):
def run():
Processor.launch(default_ident, __doc__)

View file

@ -8,14 +8,24 @@ import logging
from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk
from ... schema import TextDocument, Chunk, Metadata, Triples
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
from ... provenance import (
derived_entity_triples,
set_graph, GRAPH_SOURCE,
)
# Component identification for provenance
COMPONENT_NAME = "token-chunker"
COMPONENT_VERSION = "1.0.0"
# Module logger
logger = logging.getLogger(__name__)
default_ident = "chunker"
class Processor(ChunkingService):
def __init__(self, **params):
@ -23,7 +33,7 @@ class Processor(ChunkingService):
id = params.get("id", default_ident)
chunk_size = params.get("chunk_size", 250)
chunk_overlap = params.get("chunk_overlap", 15)
super(Processor, self).__init__(
**params | { "id": id }
)
@ -61,6 +71,13 @@ class Processor(ChunkingService):
)
)
self.register_specification(
ProducerSpec(
name = "triples",
schema = Triples,
)
)
logger.info("Token chunker initialized")
async def on_message(self, msg, consumer, flow):
@ -68,6 +85,9 @@ class Processor(ChunkingService):
v = msg.value()
logger.info(f"Chunking document {v.metadata.id}...")
# Get text content (fetches from librarian if needed)
text = await self.get_document_text(v)
# Extract chunk parameters from flow (allows runtime override)
chunk_size, chunk_overlap = await self.chunk_document(
msg, consumer, flow,
@ -88,25 +108,84 @@ class Processor(ChunkingService):
chunk_overlap=chunk_overlap,
)
texts = text_splitter.create_documents(
[v.text.decode("utf-8")]
)
texts = text_splitter.create_documents([text])
# Get parent document ID for provenance linking
# This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it
parent_doc_id = v.document_id or v.metadata.id
# Track token offset for provenance (approximate)
token_offset = 0
for ix, chunk in enumerate(texts):
chunk_index = ix + 1 # 1-indexed
logger.debug(f"Created chunk of size {len(chunk.page_content)}")
# Generate chunk document ID by appending /c{index} to parent
# Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
chunk_uri = chunk_doc_id # URI is same as document ID
parent_uri = parent_doc_id
chunk_content = chunk.page_content.encode("utf-8")
chunk_length = len(chunk.page_content)
# Save chunk to librarian as child document
await self.save_child_document(
doc_id=chunk_doc_id,
parent_id=parent_doc_id,
user=v.metadata.user,
content=chunk_content,
document_type="chunk",
title=f"Chunk {chunk_index}",
)
# Emit provenance triples (stored in source graph for separation from core knowledge)
prov_triples = derived_entity_triples(
entity_uri=chunk_uri,
parent_uri=parent_uri,
component_name=COMPONENT_NAME,
component_version=COMPONENT_VERSION,
label=f"Chunk {chunk_index}",
chunk_index=chunk_index,
char_offset=token_offset, # Note: this is token offset, not char offset
char_length=chunk_length,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
await flow("triples").send(Triples(
metadata=Metadata(
id=chunk_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
))
# Forward chunk ID + content (post-chunker optimization)
r = Chunk(
metadata=v.metadata,
chunk=chunk.page_content.encode("utf-8"),
metadata=Metadata(
id=chunk_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
chunk=chunk_content,
document_id=chunk_doc_id,
)
__class__.chunk_metric.labels(
id=consumer.id, flow=consumer.flow
).observe(len(chunk.page_content))
).observe(chunk_length)
await flow("output").send(r)
# Update token offset (approximate: assumes each chunk spans exactly chunk_size tokens and overlaps the previous one by chunk_overlap tokens)
token_offset += chunk_size - chunk_overlap
logger.debug("Document chunking complete")
@staticmethod
@ -118,17 +197,16 @@ class Processor(ChunkingService):
'-z', '--chunk-size',
type=int,
default=250,
help=f'Chunk size (default: 250)'
help=f'Chunk size in tokens (default: 250)'
)
parser.add_argument(
'-v', '--chunk-overlap',
type=int,
default=15,
help=f'Chunk overlap (default: 15)'
help=f'Chunk overlap in tokens (default: 15)'
)
def run():
Processor.launch(default_ident, __doc__)

View file

@ -2,21 +2,44 @@
"""
Simple decoder, accepts PDF documents on input, outputs pages from the
PDF document as text as separate output objects.
Supports both inline document data and fetching from librarian via Pulsar
for large documents.
"""
import asyncio
import os
import tempfile
import base64
import logging
import uuid
from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
from ... provenance import (
document_uri, page_uri, derived_entity_triples,
set_graph, GRAPH_SOURCE,
)
# Component identification for provenance
COMPONENT_NAME = "pdf-decoder"
COMPONENT_VERSION = "1.0.0"
# Module logger
logger = logging.getLogger(__name__)
default_ident = "pdf-decoder"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor):
def __init__(self, **params):
@ -44,8 +67,164 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
ProducerSpec(
name = "triples",
schema = Triples,
)
)
# Librarian client for fetching document content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor = id, flow = None, name = "librarian-request"
)
self.librarian_request_producer = Producer(
backend = self.pubsub,
topic = librarian_request_q,
schema = LibrarianRequest,
metrics = librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor = id, flow = None, name = "librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup = self.taskgroup,
backend = self.pubsub,
flow = None,
topic = librarian_response_q,
subscriber = f"{id}-librarian",
schema = LibrarianResponse,
handler = self.on_librarian_response,
metrics = librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_requests = {}
logger.info("PDF decoder initialized")
async def start(self):
# Bring up the base processor first, then the librarian request producer
# and the librarian response consumer held on this processor.
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
    """
    Handle responses from the librarian service.

    The response is matched to its waiting request through the "id"
    message property and delivered by resolving the future stored in
    ``self.pending_requests``.

    Args:
        msg: Incoming message; ``value()`` is the response and
            ``properties()`` carries the request id
        consumer: Unused
        flow: Unused
    """
    response = msg.value()
    request_id = msg.properties().get("id")

    # A missing id can never match a pending request
    future = None
    if request_id:
        future = self.pending_requests.pop(request_id, None)

    if future is None:
        logger.warning(f"Received unexpected librarian response: {request_id}")
        return

    # asyncio.wait_for() cancels the future on timeout; calling set_result()
    # on a finished future raises InvalidStateError, so drop late responses.
    if future.done():
        logger.warning(f"Discarding late librarian response: {request_id}")
        return

    future.set_result(response)
async def fetch_document_content(self, document_id, user, timeout=120):
    """
    Fetch document content from librarian via Pulsar.

    Sends a "get-document-content" request and waits for the response
    that carries the same "id" message property.

    Args:
        document_id: ID of the document to fetch
        user: User ID
        timeout: Request timeout in seconds

    Returns:
        The ``content`` field of the librarian response

    Raises:
        RuntimeError: If the librarian reports an error or no response
            arrives within ``timeout`` seconds
    """
    request_id = str(uuid.uuid4())

    request = LibrarianRequest(
        operation="get-document-content",
        document_id=document_id,
        user=user,
    )

    # Register the future before sending so the response handler can
    # always find it, however quickly the librarian answers.
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )
        response = await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout fetching document {document_id}")
    finally:
        # Remove the entry on every exit path (send failure, cancellation,
        # timeout) so abandoned futures do not accumulate.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error: {response.error.type}: {response.error.message}"
        )

    return response.content
async def save_child_document(self, doc_id, parent_id, user, content,
                              document_type="page", title=None, timeout=120):
    """
    Save a child document to the librarian.

    Sends an "add-child-document" request and waits for the response
    that carries the same "id" message property.

    Args:
        doc_id: ID for the new child document
        parent_id: ID of the parent document
        user: User ID
        content: Document content (bytes); sent base64-encoded
        document_type: Type of document ("page", "chunk", etc.)
        title: Optional title, defaults to ``doc_id``
        timeout: Request timeout in seconds

    Returns:
        The document ID on success

    Raises:
        RuntimeError: If the librarian reports an error or no response
            arrives within ``timeout`` seconds
    """
    request_id = str(uuid.uuid4())

    doc_metadata = DocumentMetadata(
        id=doc_id,
        user=user,
        kind="text/plain",
        title=title or doc_id,
        parent_id=parent_id,
        document_type=document_type,
    )

    request = LibrarianRequest(
        operation="add-child-document",
        document_metadata=doc_metadata,
        content=base64.b64encode(content).decode("utf-8"),
    )

    # Register the future before sending so the response handler can
    # always find it, however quickly the librarian answers.
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )
        response = await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout saving child document {doc_id}")
    finally:
        # Remove the entry on every exit path (send failure, cancellation,
        # timeout) so abandoned futures do not accumulate.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error saving child document: {response.error.type}: {response.error.message}"
        )

    return doc_id
async def on_message(self, msg, consumer, flow):
logger.debug("PDF message received")
@ -54,26 +233,102 @@ class Processor(FlowProcessor):
logger.info(f"Decoding PDF {v.metadata.id}...")
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
temp_path = fp.name
fp.write(base64.b64decode(v.data))
fp.close()
# Check if we should fetch from librarian or use inline data
if v.document_id:
# Fetch from librarian via Pulsar
logger.info(f"Fetching document {v.document_id} from librarian...")
fp.close()
with open(fp.name, mode='rb') as f:
content = await self.fetch_document_content(
document_id=v.document_id,
user=v.metadata.user,
)
loader = PyPDFLoader(fp.name)
pages = loader.load()
# Content is base64 encoded
if isinstance(content, str):
content = content.encode('utf-8')
decoded_content = base64.b64decode(content)
for ix, page in enumerate(pages):
with open(temp_path, 'wb') as f:
f.write(decoded_content)
logger.debug(f"Processing page {ix}")
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
else:
# Use inline data (backward compatibility)
fp.write(base64.b64decode(v.data))
fp.close()
r = TextDocument(
metadata=v.metadata,
text=page.page_content.encode("utf-8"),
)
loader = PyPDFLoader(temp_path)
pages = loader.load()
await flow("output").send(r)
# Get the source document ID
source_doc_id = v.document_id or v.metadata.id
for ix, page in enumerate(pages):
page_num = ix + 1 # 1-indexed page numbers
logger.debug(f"Processing page {page_num}")
# Generate page document ID
page_doc_id = f"{source_doc_id}/p{page_num}"
page_content = page.page_content.encode("utf-8")
# Save page as child document in librarian
await self.save_child_document(
doc_id=page_doc_id,
parent_id=source_doc_id,
user=v.metadata.user,
content=page_content,
document_type="page",
title=f"Page {page_num}",
)
# Emit provenance triples (stored in source graph for separation from core knowledge)
doc_uri = document_uri(source_doc_id)
pg_uri = page_uri(source_doc_id, page_num)
prov_triples = derived_entity_triples(
entity_uri=pg_uri,
parent_uri=doc_uri,
component_name=COMPONENT_NAME,
component_version=COMPONENT_VERSION,
label=f"Page {page_num}",
page_number=page_num,
)
await flow("triples").send(Triples(
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
triples=set_graph(prov_triples, GRAPH_SOURCE),
))
# Forward page document ID to chunker
# Chunker will fetch content from librarian
r = TextDocument(
metadata=Metadata(
id=pg_uri,
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
document_id=page_doc_id,
text=b"", # Empty, chunker will fetch from librarian
)
await flow("output").send(r)
# Clean up temp file
try:
os.unlink(temp_path)
except OSError:
pass
logger.debug("PDF decoding complete")
@ -81,7 +336,18 @@ class Processor(FlowProcessor):
def add_args(parser):
    """
    Register command-line options on ``parser``.

    Adds the FlowProcessor options first, then the two librarian queue
    options (request and response), each defaulting to its module-level
    default queue name.
    """
    FlowProcessor.add_args(parser)

    # (flag, help label, default value) for each librarian queue option
    librarian_queue_options = (
        (
            '--librarian-request-queue',
            'Librarian request queue',
            default_librarian_request_queue,
        ),
        (
            '--librarian-response-queue',
            'Librarian response queue',
            default_librarian_response_queue,
        ),
    )

    for flag, label, queue_default in librarian_queue_options:
        parser.add_argument(
            flag,
            default=queue_default,
            help=f'{label} (default: {queue_default})',
        )
def run():
    """Launch the Processor, passing its default identity and the module docstring."""
    Processor.launch(default_ident, __doc__)

View file

@ -589,6 +589,8 @@ class EntityCentricKnowledgeGraph:
# quads_by_entity: primary data table
# Every entity has a partition containing all quads it participates in
# Clustering key includes dtype/lang to distinguish literals with same value
# but different datatype or language tag (e.g., "thing" vs "thing"@en)
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.entity_table} (
collection text,
@ -601,11 +603,13 @@ class EntityCentricKnowledgeGraph:
d text,
dtype text,
lang text,
PRIMARY KEY ((collection, entity), role, p, otype, s, o, d)
PRIMARY KEY ((collection, entity), role, p, otype, s, o, d, dtype, lang)
);
""")
# quads_by_collection: manifest for collection-level queries and deletion
# Clustering key includes otype/dtype/lang to distinguish literals with same
# value but different metadata (e.g., "thing" vs "thing"@en vs "thing"^^xsd:string)
self.session.execute(f"""
CREATE TABLE IF NOT EXISTS {self.collection_table} (
collection text,
@ -616,7 +620,7 @@ class EntityCentricKnowledgeGraph:
otype text,
dtype text,
lang text,
PRIMARY KEY (collection, d, s, p, o)
PRIMARY KEY (collection, d, s, p, o, otype, dtype, lang)
);
""")
@ -718,7 +722,7 @@ class EntityCentricKnowledgeGraph:
)
self.delete_collection_row_stmt = self.session.prepare(
f"DELETE FROM {self.collection_table} WHERE collection = ? AND d = ? AND s = ? AND p = ? AND o = ?"
f"DELETE FROM {self.collection_table} WHERE collection = ? AND d = ? AND s = ? AND p = ? AND o = ? AND otype = ? AND dtype = ? AND lang = ?"
)
logger.info("Prepared statements initialized for entity-centric schema")
@ -797,7 +801,7 @@ class EntityCentricKnowledgeGraph:
def get_s(self, collection, s, g=None, limit=10):
"""
Query by subject. Returns quads where s is the subject.
g=None: default graph, g='*': all graphs
g=None: all graphs, g='': default graph only, g='uri': specific graph
"""
rows = self.session.execute(self.get_entity_as_s_stmt, (collection, s, limit))
@ -805,10 +809,7 @@ class EntityCentricKnowledgeGraph:
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
# Filter by graph if specified
if g is None or g == DEFAULT_GRAPH:
if d != DEFAULT_GRAPH:
continue
elif g != GRAPH_WILDCARD and d != g:
if g is not None and d != g:
continue
results.append(QuadResult(
@ -819,16 +820,13 @@ class EntityCentricKnowledgeGraph:
return results
def get_p(self, collection, p, g=None, limit=10):
"""Query by predicate"""
"""Query by predicate. g=None: all graphs, g='': default graph only"""
rows = self.session.execute(self.get_entity_as_p_stmt, (collection, p, limit))
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is None or g == DEFAULT_GRAPH:
if d != DEFAULT_GRAPH:
continue
elif g != GRAPH_WILDCARD and d != g:
if g is not None and d != g:
continue
results.append(QuadResult(
@ -839,16 +837,13 @@ class EntityCentricKnowledgeGraph:
return results
def get_o(self, collection, o, g=None, limit=10):
"""Query by object"""
"""Query by object. g=None: all graphs, g='': default graph only"""
rows = self.session.execute(self.get_entity_as_o_stmt, (collection, o, limit))
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is None or g == DEFAULT_GRAPH:
if d != DEFAULT_GRAPH:
continue
elif g != GRAPH_WILDCARD and d != g:
if g is not None and d != g:
continue
results.append(QuadResult(
@ -859,16 +854,13 @@ class EntityCentricKnowledgeGraph:
return results
def get_sp(self, collection, s, p, g=None, limit=10):
"""Query by subject and predicate"""
"""Query by subject and predicate. g=None: all graphs, g='': default graph only"""
rows = self.session.execute(self.get_entity_as_s_p_stmt, (collection, s, p, limit))
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is None or g == DEFAULT_GRAPH:
if d != DEFAULT_GRAPH:
continue
elif g != GRAPH_WILDCARD and d != g:
if g is not None and d != g:
continue
results.append(QuadResult(
@ -879,16 +871,13 @@ class EntityCentricKnowledgeGraph:
return results
def get_po(self, collection, p, o, g=None, limit=10):
"""Query by predicate and object"""
"""Query by predicate and object. g=None: all graphs, g='': default graph only"""
rows = self.session.execute(self.get_entity_as_o_p_stmt, (collection, o, p, limit))
results = []
for row in rows:
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is None or g == DEFAULT_GRAPH:
if d != DEFAULT_GRAPH:
continue
elif g != GRAPH_WILDCARD and d != g:
if g is not None and d != g:
continue
results.append(QuadResult(
@ -899,7 +888,7 @@ class EntityCentricKnowledgeGraph:
return results
def get_os(self, collection, o, s, g=None, limit=10):
"""Query by object and subject"""
"""Query by object and subject. g=None: all graphs, g='': default graph only"""
# Use subject partition with role='S', filter by o
rows = self.session.execute(self.get_entity_as_s_stmt, (collection, s, limit))
@ -909,10 +898,7 @@ class EntityCentricKnowledgeGraph:
continue
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is None or g == DEFAULT_GRAPH:
if d != DEFAULT_GRAPH:
continue
elif g != GRAPH_WILDCARD and d != g:
if g is not None and d != g:
continue
results.append(QuadResult(
@ -923,7 +909,7 @@ class EntityCentricKnowledgeGraph:
return results
def get_spo(self, collection, s, p, o, g=None, limit=10):
"""Query by subject, predicate, object (find which graphs)"""
"""Query by subject, predicate, object (find which graphs). g=None: all graphs, g='': default graph only"""
rows = self.session.execute(self.get_entity_as_s_p_stmt, (collection, s, p, limit))
results = []
@ -932,10 +918,7 @@ class EntityCentricKnowledgeGraph:
continue
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
if g is None or g == DEFAULT_GRAPH:
if d != DEFAULT_GRAPH:
continue
elif g != GRAPH_WILDCARD and d != g:
if g is not None and d != g:
continue
results.append(QuadResult(
@ -991,9 +974,9 @@ class EntityCentricKnowledgeGraph:
3. Delete entire entity partitions
4. Delete collection rows
"""
# Read all quads from collection table
# Read all quads from collection table (including type metadata for delete)
rows = self.session.execute(
f"SELECT d, s, p, o, otype FROM {self.collection_table} WHERE collection = %s",
f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s",
(collection,)
)
@ -1002,8 +985,11 @@ class EntityCentricKnowledgeGraph:
quads = []
for row in rows:
d, s, p, o, otype = row.d, row.s, row.p, row.o, row.otype
quads.append((d, s, p, o))
d, s, p, o = row.d, row.s, row.p, row.o
otype = row.otype
dtype = row.dtype if hasattr(row, 'dtype') else ''
lang = row.lang if hasattr(row, 'lang') else ''
quads.append((d, s, p, o, otype, dtype, lang))
# Subject and predicate are always entities
entities.add(s)
@ -1038,8 +1024,8 @@ class EntityCentricKnowledgeGraph:
batch = BatchStatement()
count = 0
for d, s, p, o in quads:
batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o))
for d, s, p, o, otype, dtype, lang in quads:
batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang))
count += 1
# Execute batch every 50 quads

View file

@ -84,14 +84,14 @@ class DocVectors:
dim=dimension,
)
doc_field = FieldSchema(
name="doc",
chunk_id_field = FieldSchema(
name="chunk_id",
dtype=DataType.VARCHAR,
max_length=65535,
)
schema = CollectionSchema(
fields = [pkey_field, vec_field, doc_field],
fields = [pkey_field, vec_field, chunk_id_field],
description = "Document embedding schema",
)
@ -119,17 +119,17 @@ class DocVectors:
self.collections[(dimension, user, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, doc, user, collection):
def insert(self, embeds, chunk_id, user, collection):
dim = len(embeds)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
data = [
{
"vector": embeds,
"doc": doc,
"chunk_id": chunk_id,
}
]
@ -138,7 +138,7 @@ class DocVectors:
data=data
)
def search(self, embeds, user, collection, fields=["doc"], limit=10):
def search(self, embeds, user, collection, fields=["chunk_id"], limit=10):
dim = len(embeds)

View file

@ -90,8 +90,14 @@ class EntityVectors:
max_length=65535,
)
chunk_id_field = FieldSchema(
name="chunk_id",
dtype=DataType.VARCHAR,
max_length=65535,
)
schema = CollectionSchema(
fields = [pkey_field, vec_field, entity_field],
fields = [pkey_field, vec_field, entity_field, chunk_id_field],
description = "Graph embedding schema",
)
@ -119,17 +125,18 @@ class EntityVectors:
self.collections[(dimension, user, collection)] = collection_name
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
def insert(self, embeds, entity, user, collection):
def insert(self, embeds, entity, user, collection, chunk_id=""):
dim = len(embeds)
if (dim, user, collection) not in self.collections:
self.init_collection(dim, user, collection)
data = [
{
"vector": embeds,
"entity": entity,
"chunk_id": chunk_id,
}
]

View file

@ -62,16 +62,17 @@ class Processor(FlowProcessor):
resp = await flow("embeddings-request").request(
EmbeddingsRequest(
text = v.chunk
texts=[v.chunk]
)
)
vectors = resp.vectors
# vectors[0] is the vector for the first (only) text
vector = resp.vectors[0] if resp.vectors else []
embeds = [
ChunkEmbeddings(
chunk=v.chunk,
vectors=vectors,
chunk_id=v.document_id,
vector=vector,
)
]

View file

@ -46,19 +46,21 @@ class Processor(EmbeddingsService):
else:
logger.debug(f"Using cached model: {model_name}")
async def on_embeddings(self, text, model=None):
async def on_embeddings(self, texts, model=None):
if not texts:
return []
use_model = model or self.default_model
# Reload model if it has changed
self._load_model(use_model)
vecs = self.embeddings.embed([text])
# FastEmbed processes the full batch efficiently
vecs = list(self.embeddings.embed(texts))
return [
v.tolist()
for v in vecs
]
# Return list of vectors, one per input text
return [v.tolist() for v in vecs]
@staticmethod
def add_args(parser):

View file

@ -58,22 +58,25 @@ class Processor(FlowProcessor):
v = msg.value()
logger.info(f"Indexing {v.metadata.id}...")
entities = []
try:
for entity in v.entities:
# Collect all contexts for batch embedding
contexts = [entity.context for entity in v.entities]
vectors = await flow("embeddings-request").embed(
text = entity.context
)
# Single batch embedding call
all_vectors = await flow("embeddings-request").embed(
texts=contexts
)
entities.append(
EntityEmbeddings(
entity=entity.entity,
vectors=vectors
)
# Pair results with entities
entities = [
EntityEmbeddings(
entity=entity.entity,
vector=vector,
chunk_id=entity.chunk_id, # Provenance: source chunk
)
for entity, vector in zip(v.entities, all_vectors)
]
# Send in batches to avoid oversized messages
for i in range(0, len(entities), self.batch_size):

View file

@ -30,16 +30,21 @@ class Processor(EmbeddingsService):
self.client = Client(host=ollama)
self.default_model = model
async def on_embeddings(self, text, model=None):
async def on_embeddings(self, texts, model=None):
if not texts:
return []
use_model = model or self.default_model
# Ollama handles batch input efficiently
embeds = self.client.embed(
model = use_model,
input = text
input = texts
)
return embeds.embeddings
# Return list of vectors, one per input text
return list(embeds.embeddings)
@staticmethod
def add_args(parser):

View file

@ -200,15 +200,23 @@ class Processor(CollectionConfigHandler, FlowProcessor):
embeddings_list = []
try:
for text, (index_name, index_value) in texts_to_embed.items():
vectors = await flow("embeddings-request").embed(text=text)
# Collect texts and metadata for batch embedding
texts = list(texts_to_embed.keys())
metadata = list(texts_to_embed.values())
# Single batch embedding call
all_vectors = await flow("embeddings-request").embed(texts=texts)
# Pair results with metadata
for text, (index_name, index_value), vector in zip(
texts, metadata, all_vectors
):
embeddings_list.append(
RowIndexEmbedding(
index_name=index_name,
index_value=index_value,
text=text,
vectors=vectors
vector=vector
)
)

View file

@ -6,11 +6,13 @@ import logging
from ....schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
from ....schema import EntityContext, EntityContexts
from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, SUBJECT_OF, DEFINITION
from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, DEFINITION
from ....base import FlowProcessor, ConsumerSpec, ProducerSpec
from ....base import AgentClientSpec
from ....provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE
from ....flow_version import __version__ as COMPONENT_VERSION
from ....template import PromptManager
# Module logger
@ -104,7 +106,7 @@ class Processor(FlowProcessor):
tpls = Triples(
metadata = Metadata(
id = metadata.id,
metadata = [],
root = metadata.root,
user = metadata.user,
collection = metadata.collection,
),
@ -117,7 +119,7 @@ class Processor(FlowProcessor):
ecs = EntityContexts(
metadata = Metadata(
id = metadata.id,
metadata = [],
root = metadata.root,
user = metadata.user,
collection = metadata.collection,
),
@ -183,24 +185,8 @@ class Processor(FlowProcessor):
logger.debug(f"Agent prompt: {prompt}")
async def handle(response):
logger.debug(f"Agent response: {response}")
if response.error is not None:
if response.error.message:
raise RuntimeError(str(response.error.message))
else:
raise RuntimeError(str(response.error))
if response.answer is not None:
return True
else:
return False
# Send to agent API
agent_response = await flow("agent-request").invoke(
recipient = handle,
question = prompt
)
@ -212,14 +198,22 @@ class Processor(FlowProcessor):
return
# Process extraction data
triples, entity_contexts = self.process_extraction_data(
extraction_data, v.metadata
)
triples, entity_contexts, extracted_triples = \
self.process_extraction_data(extraction_data, v.metadata)
# Generate subgraph provenance for extracted triples
if extracted_triples:
chunk_uri = v.metadata.id
sg_uri = subgraph_uri()
prov_triples = subgraph_provenance_triples(
subgraph_uri=sg_uri,
extracted_triples=extracted_triples,
chunk_uri=chunk_uri,
component_name=default_ident,
component_version=COMPONENT_VERSION,
)
triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
# Put document metadata into triples
for t in v.metadata.metadata:
triples.append(t)
# Emit outputs
if triples:
await self.emit_triples(flow("triples"), v.metadata, triples)
@ -241,8 +235,13 @@ class Processor(FlowProcessor):
Data is a flat list of objects with 'type' discriminator field:
- {"type": "definition", "entity": "...", "definition": "..."}
- {"type": "relationship", "subject": "...", "predicate": "...", "object": "...", "object-entity": bool}
Returns:
Tuple of (all_triples, entity_contexts, extracted_triples) where
extracted_triples contains only the core knowledge facts (for provenance).
"""
triples = []
extracted_triples = []
entity_contexts = []
# Categorize items by type
@ -262,26 +261,20 @@ class Processor(FlowProcessor):
))
# Add definition
triples.append(Triple(
definition_triple = Triple(
s = Term(type=IRI, iri=entity_uri),
p = Term(type=IRI, iri=DEFINITION),
o = Term(type=LITERAL, value=defn["definition"]),
))
# Add subject-of relationship to document
if metadata.id:
triples.append(Triple(
s = Term(type=IRI, iri=entity_uri),
p = Term(type=IRI, iri=SUBJECT_OF),
o = Term(type=IRI, iri=metadata.id),
))
)
triples.append(definition_triple)
extracted_triples.append(definition_triple)
# Create entity context for embeddings
entity_contexts.append(EntityContext(
entity=Term(type=IRI, iri=entity_uri),
context=defn["definition"]
))
# Process relationships
for rel in relationships:
@ -318,34 +311,15 @@ class Processor(FlowProcessor):
))
# Add the main relationship triple
triples.append(Triple(
relationship_triple = Triple(
s = subject_value,
p = predicate_value,
o = object_value
))
)
triples.append(relationship_triple)
extracted_triples.append(relationship_triple)
# Add subject-of relationships to document
if metadata.id:
triples.append(Triple(
s = subject_value,
p = Term(type=IRI, iri=SUBJECT_OF),
o = Term(type=IRI, iri=metadata.id),
))
triples.append(Triple(
s = predicate_value,
p = Term(type=IRI, iri=SUBJECT_OF),
o = Term(type=IRI, iri=metadata.id),
))
if rel.get("object-entity", True):
triples.append(Triple(
s = object_value,
p = Term(type=IRI, iri=SUBJECT_OF),
o = Term(type=IRI, iri=metadata.id),
))
return triples, entity_contexts
return triples, entity_contexts, extracted_triples
@staticmethod
def add_args(parser):

View file

@ -15,15 +15,16 @@ from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
logger = logging.getLogger(__name__)
from .... schema import EntityContext, EntityContexts
from .... schema import PromptRequest, PromptResponse
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
from .... base import PromptClientSpec, ParameterSpec
from .... provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE
from .... flow_version import __version__ as COMPONENT_VERSION
DEFINITION_VALUE = Term(type=IRI, iri=DEFINITION)
RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL)
SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF)
default_ident = "kg-extract-definitions"
default_concurrency = 1
default_triples_batch_size = 50
@ -75,6 +76,10 @@ class Processor(FlowProcessor):
)
)
# Optional flow parameters for provenance
self.register_specification(ParameterSpec("llm-model"))
self.register_specification(ParameterSpec("ontology"))
def to_uri(self, text):
part = text.replace(" ", "-").lower().encode("utf-8")
@ -126,12 +131,19 @@ class Processor(FlowProcessor):
raise e
triples = []
extracted_triples = []
entities = []
# FIXME: Putting metadata into triples store is duplicated in
# relationships extractor too
for t in v.metadata.metadata:
triples.append(t)
# Get chunk document ID for provenance linking
chunk_doc_id = v.document_id if v.document_id else v.metadata.id
chunk_uri = v.metadata.id # The URI form for the chunk
# Get optional provenance parameters
llm_model = flow("llm-model")
ontology_uri = flow("ontology")
# Note: Document metadata is now emitted once by librarian at processing
# initiation, so we don't need to duplicate it here.
for defn in defs:
@ -155,28 +167,43 @@ class Processor(FlowProcessor):
o=Term(type=LITERAL, value=s),
))
triples.append(Triple(
# The definition triple - this is the main extracted fact
definition_triple = Triple(
s=s_value, p=DEFINITION_VALUE, o=o_value
))
triples.append(Triple(
s=s_value,
p=SUBJECT_OF_VALUE,
o=Term(type=IRI, iri=v.metadata.id)
))
)
triples.append(definition_triple)
extracted_triples.append(definition_triple)
# Output entity name as context for direct name matching
# Include chunk_id for embedding provenance
entities.append(EntityContext(
entity=s_value,
context=s,
chunk_id=chunk_doc_id,
))
# Output definition as context for semantic matching
# Include chunk_id for embedding provenance
entities.append(EntityContext(
entity=s_value,
context=defn["definition"],
chunk_id=chunk_doc_id,
))
# Generate subgraph provenance once for all extracted triples
if extracted_triples:
sg_uri = subgraph_uri()
prov_triples = subgraph_provenance_triples(
subgraph_uri=sg_uri,
extracted_triples=extracted_triples,
chunk_uri=chunk_uri,
component_name=default_ident,
component_version=COMPONENT_VERSION,
llm_model=llm_model,
ontology_uri=ontology_uri,
)
triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
# Send triples in batches
for i in range(0, len(triples), self.triples_batch_size):
batch = triples[i:i + self.triples_batch_size]
@ -184,7 +211,7 @@ class Processor(FlowProcessor):
flow("triples"),
Metadata(
id=v.metadata.id,
metadata=[],
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),
@ -198,7 +225,7 @@ class Processor(FlowProcessor):
flow("entity-contexts"),
Metadata(
id=v.metadata.id,
metadata=[],
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),

View file

@ -23,6 +23,9 @@ from .ontology_selector import OntologySelector, OntologySubset
from .simplified_parser import parse_extraction_response
from .triple_converter import TripleConverter
from .... provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE
from .... flow_version import __version__ as COMPONENT_VERSION
logger = logging.getLogger(__name__)
default_ident = "kg-extract-ontology"
@ -148,8 +151,8 @@ class Processor(FlowProcessor):
# Detect embedding dimension by embedding a test string
logger.info("Detecting embedding dimension from embeddings service...")
test_embedding_response = await embeddings_client.embed("test")
test_embedding = test_embedding_response[0] # Extract from [[vector]]
test_embedding_response = await embeddings_client.embed(["test"])
test_embedding = test_embedding_response[0] # Extract first vector
dimension = len(test_embedding)
logger.info(f"Detected embedding dimension: {dimension}")
@ -306,15 +309,25 @@ class Processor(FlowProcessor):
flow, chunk, ontology_subset, prompt_variables
)
# Add metadata triples
for t in v.metadata.metadata:
triples.append(t)
# Generate subgraph provenance for extracted triples
if triples:
chunk_uri = v.metadata.id
sg_uri = subgraph_uri()
prov_triples = subgraph_provenance_triples(
subgraph_uri=sg_uri,
extracted_triples=triples,
chunk_uri=chunk_uri,
component_name=default_ident,
component_version=COMPONENT_VERSION,
)
# Generate ontology definition triples
ontology_triples = self.build_ontology_triples(ontology_subset)
# Combine extracted triples with ontology triples
# Combine extracted triples with ontology triples and provenance
all_triples = triples + ontology_triples
if triples:
all_triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
# Build entity contexts from all triples (including ontology elements)
entity_contexts = self.build_entity_contexts(all_triples)
@ -558,7 +571,7 @@ class Processor(FlowProcessor):
t = Triples(
metadata=Metadata(
id=metadata.id,
metadata=[],
root=metadata.root,
user=metadata.user,
collection=metadata.collection,
),
@ -571,7 +584,7 @@ class Processor(FlowProcessor):
ec = EntityContexts(
metadata=Metadata(
id=metadata.id,
metadata=[],
root=metadata.root,
user=metadata.user,
collection=metadata.collection,
),

View file

@ -153,16 +153,11 @@ class OntologyEmbedder:
# Get embeddings for batch
texts = [elem['text'] for elem in batch]
try:
# Call embedding service for each text
# Note: embed() returns 2D array [[vector]], so extract first element
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
embeddings_responses = await asyncio.gather(*embedding_tasks)
# Extract vectors from responses (each is [[vector]])
embeddings_list = [resp[0] for resp in embeddings_responses]
# Single batch embedding call - returns list of vectors
embeddings_response = await self.embedding_service.embed(texts)
# Convert to numpy array
embeddings = np.array(embeddings_list)
embeddings = np.array(embeddings_response)
# Log embedding shape for debugging
logger.debug(f"Embeddings shape: {embeddings.shape}, expected: ({len(batch)}, {self.vector_store.dimension})")
@ -218,8 +213,8 @@ class OntologyEmbedder:
return None
try:
# embed() returns 2D array [[vector]], extract first element
embedding_response = await self.embedding_service.embed(text)
# embed() with single text, extract first vector
embedding_response = await self.embedding_service.embed([text])
return np.array(embedding_response[0])
except Exception as e:
logger.error(f"Failed to embed text: {e}")
@ -239,12 +234,9 @@ class OntologyEmbedder:
return None
try:
# Call embed() for each text (returns [[vector]] per call)
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
embeddings_responses = await asyncio.gather(*embedding_tasks)
# Extract first vector from each response
embeddings_list = [resp[0] for resp in embeddings_responses]
return np.array(embeddings_list)
# Single batch embedding call - returns list of vectors
embeddings_response = await self.embedding_service.embed(texts)
return np.array(embeddings_response)
except Exception as e:
logger.error(f"Failed to embed texts: {e}")
return None

View file

@ -15,13 +15,15 @@ logger = logging.getLogger(__name__)
from .... schema import Chunk, Triple, Triples
from .... schema import Metadata, Term, IRI, LITERAL
from .... schema import PromptRequest, PromptResponse
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
from .... base import PromptClientSpec
from .... base import PromptClientSpec, ParameterSpec
from .... provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE
from .... flow_version import __version__ as COMPONENT_VERSION
RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL)
SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF)
default_ident = "kg-extract-relationships"
default_concurrency = 1
@ -65,6 +67,10 @@ class Processor(FlowProcessor):
)
)
# Optional flow parameters for provenance
self.register_specification(ParameterSpec("llm-model"))
self.register_specification(ParameterSpec("ontology"))
def to_uri(self, text):
part = text.replace(" ", "-").lower().encode("utf-8")
@ -108,11 +114,18 @@ class Processor(FlowProcessor):
raise e
triples = []
extracted_triples = []
# FIXME: Putting metadata into triples store is duplicated in
# relationships extractor too
for t in v.metadata.metadata:
triples.append(t)
# Get chunk document ID for provenance linking
chunk_doc_id = v.document_id if v.document_id else v.metadata.id
chunk_uri = v.metadata.id # The URI form for the chunk
# Get optional provenance parameters
llm_model = flow("llm-model")
ontology_uri = flow("ontology")
# Note: Document metadata is now emitted once by librarian at processing
# initiation, so we don't need to duplicate it here.
for rel in rels:
@ -140,11 +153,14 @@ class Processor(FlowProcessor):
else:
o_value = Term(type=LITERAL, value=str(o))
triples.append(Triple(
# The relationship triple - this is the main extracted fact
relationship_triple = Triple(
s=s_value,
p=p_value,
o=o_value
))
)
triples.append(relationship_triple)
extracted_triples.append(relationship_triple)
# Label for s
triples.append(Triple(
@ -168,20 +184,19 @@ class Processor(FlowProcessor):
o=Term(type=LITERAL, value=str(o))
))
# 'Subject of' for s
triples.append(Triple(
s=s_value,
p=SUBJECT_OF_VALUE,
o=Term(type=IRI, iri=v.metadata.id)
))
if rel["object-entity"]:
# 'Subject of' for o
triples.append(Triple(
s=o_value,
p=SUBJECT_OF_VALUE,
o=Term(type=IRI, iri=v.metadata.id)
))
# Generate subgraph provenance once for all extracted triples
if extracted_triples:
sg_uri = subgraph_uri()
prov_triples = subgraph_provenance_triples(
subgraph_uri=sg_uri,
extracted_triples=extracted_triples,
chunk_uri=chunk_uri,
component_name=default_ident,
component_version=COMPONENT_VERSION,
llm_model=llm_model,
ontology_uri=ontology_uri,
)
triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
# Send triples in batches
for i in range(0, len(triples), self.triples_batch_size):
@ -190,7 +205,7 @@ class Processor(FlowProcessor):
flow("triples"),
Metadata(
id=v.metadata.id,
metadata=[],
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),

View file

@ -272,7 +272,7 @@ class Processor(FlowProcessor):
extracted = ExtractedObject(
metadata=Metadata(
id=f"{v.metadata.id}:{schema_name}",
metadata=[],
root=v.metadata.root,
user=v.metadata.user,
collection=v.metadata.collection,
),

View file

@ -0,0 +1,65 @@
import asyncio
import base64
import logging
import uuid

from . librarian import LibrarianRequestor
# Module logger
logger = logging.getLogger(__name__)
class DocumentStreamExport:
    """
    Gateway endpoint that streams a librarian document to the HTTP client.

    The document is requested from the librarian with a "stream-document"
    operation; each response chunk is base64-decoded and written to the
    HTTP response as raw bytes.
    """

    # Chunk size requested from the librarian when the client gives none
    DEFAULT_CHUNK_SIZE = 1024 * 1024

    def __init__(self, backend):
        self.backend = backend

    async def process(self, data, error, ok, request):
        """
        Handle one document-stream request.

        Args:
            data: Unused by this endpoint
            error: Async callable taking a message; its result is returned
                when the request parameters are invalid
            ok: Async callable returning the response object to stream into
            request: Request whose ``query`` mapping supplies "user",
                "document-id" and optionally "chunk-size"

        Returns:
            The result of ``error(...)`` for invalid parameters, otherwise
            the response object after the stream has been terminated.
        """

        user = request.query.get("user")
        document_id = request.query.get("document-id")

        if not user or not document_id:
            return await error("Missing required parameters: user, document-id")

        # Validate chunk-size before any response is started so that a bad
        # value is reported through the error responder rather than
        # surfacing as an unhandled ValueError.
        try:
            chunk_size = int(
                request.query.get("chunk-size", self.DEFAULT_CHUNK_SIZE)
            )
        except (TypeError, ValueError):
            return await error(
                "Invalid parameter: chunk-size must be an integer"
            )

        if chunk_size <= 0:
            return await error(
                "Invalid parameter: chunk-size must be a positive integer"
            )

        response = await ok()

        try:

            lr = LibrarianRequestor(
                backend=self.backend,
                consumer="api-gateway-doc-stream-" + str(uuid.uuid4()),
                subscriber="api-gateway-doc-stream-" + str(uuid.uuid4()),
            )

            try:

                await lr.start()

                async def responder(resp, fin):
                    if "content" in resp:
                        # Librarian content is base64 encoded; decode it
                        # here so the client receives raw bytes.
                        chunk_data = base64.b64decode(resp["content"])
                        await response.write(chunk_data)

                await lr.process(
                    {
                        "operation": "stream-document",
                        "user": user,
                        "document-id": document_id,
                        "chunk-size": chunk_size,
                    },
                    responder
                )

            except Exception as e:
                # The response has already been started, so the failure can
                # only be logged; the stream is simply cut short.
                logger.error(f"Document stream exception: {e}", exc_info=True)

            finally:
                await lr.stop()

        finally:
            # Terminate the response even if requestor setup or shutdown
            # fails, so the client is not left waiting on an open stream.
            await response.write_eof()

        return response

View file

@ -45,6 +45,7 @@ from . rows_import import RowsImport
from . core_export import CoreExport
from . core_import import CoreImport
from . document_stream import DocumentStreamExport
from . mux import Mux
@ -135,6 +136,14 @@ class DispatcherManager:
def dispatch_core_import(self):
return DispatcherWrapper(self.process_core_import)
def dispatch_document_stream(self):
    """Return a DispatcherWrapper around the document-stream handler."""
    handler = self.process_document_stream
    return DispatcherWrapper(handler)
async def process_document_stream(self, data, error, ok, request):
    """
    Serve one document-stream request.

    A fresh DocumentStreamExport is built on this manager's backend for
    every request, and its result is returned unchanged.
    """
    exporter = DocumentStreamExport(self.backend)
    outcome = await exporter.process(data, error, ok, request)
    return outcome
async def process_core_import(self, data, error, ok, request):
ci = CoreImport(self.backend)

View file

@ -53,7 +53,6 @@ class RowsImport:
elt = ExtractedObject(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"].get("metadata", [])),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),

View file

@ -37,35 +37,37 @@ def serialize_triples(message):
return {
"metadata": {
"id": message.metadata.id,
"metadata": serialize_subgraph(message.metadata.metadata),
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"triples": serialize_subgraph(message.triples),
}
# Convert a graph-embeddings message into a plain dict: a "metadata"
# section copied from message.metadata plus one "entities" entry per
# element of message.entities (entity serialized with serialize_value).
# NOTE(review): this listing comes from a diff view in which removed and
# added lines are interleaved; confirm which of the "metadata"/"root" and
# "vectors"/"vector" keys exist in the current revision.
def serialize_graph_embeddings(message):
return {
"metadata": {
"id": message.metadata.id,
"metadata": serialize_subgraph(message.metadata.metadata),
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"entities": [
{
"vectors": entity.vectors,
"vector": entity.vector,
"entity": serialize_value(entity.entity),
}
for entity in message.entities
],
}
def serialize_entity_contexts(message):
return {
"metadata": {
"id": message.metadata.id,
"metadata": serialize_subgraph(message.metadata.metadata),
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
@ -78,18 +80,19 @@ def serialize_entity_contexts(message):
],
}
def serialize_document_embeddings(message):
return {
"metadata": {
"id": message.metadata.id,
"metadata": serialize_subgraph(message.metadata.metadata),
"root": message.metadata.root,
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"chunks": [
{
"vectors": chunk.vectors,
"chunk": chunk.chunk.decode("utf-8"),
"vector": chunk.vector,
"chunk_id": chunk.chunk_id,
}
for chunk in message.chunks
],

View file

@ -48,7 +48,7 @@ class TriplesImport:
elt = Triples(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
root=data["metadata"].get("root", ""),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),

View file

@ -64,6 +64,12 @@ class EndpointManager:
method = "GET",
dispatcher = dispatcher_manager.dispatch_core_export(),
),
StreamEndpoint(
endpoint_path = "/api/v1/document-stream",
auth = auth,
method = "GET",
dispatcher = dispatcher_manager.dispatch_document_stream(),
),
]
def add_routes(self, app):

View file

@ -3,9 +3,12 @@ from .. knowledge import hash
from .. exceptions import RequestError
from minio import Minio
from minio.datatypes import Part
import time
import io
import logging
from typing import Iterator, List, Tuple
from uuid import UUID
# Module logger
logger = logging.getLogger(__name__)
@ -78,3 +81,163 @@ class BlobStore:
return resp.read()
async def get_range(self, object_id, offset: int, length: int) -> bytes:
    """
    Read a slice of a stored object.

    Args:
        object_id: Identifier of the object; stored under "doc/<id>".
        offset: Byte position at which to start reading.
        length: Number of bytes requested.

    Returns:
        The bytes read from the object store for that range.
    """
    object_name = "doc/" + str(object_id)

    resp = self.client.get_object(
        bucket_name=self.bucket_name,
        object_name=object_name,
        offset=offset,
        length=length,
    )

    # Always hand the connection back, even if the read fails.
    try:
        payload = resp.read()
    finally:
        resp.close()
        resp.release_conn()

    return payload
async def get_size(self, object_id) -> int:
    """Return the stored size of an object, using a stat call only."""
    object_name = "doc/" + str(object_id)
    info = self.client.stat_object(
        bucket_name=self.bucket_name,
        object_name=object_name,
    )
    return info.size
def get_stream(self, object_id, chunk_size: int = 1024 * 1024) -> Iterator[bytes]:
    """
    Lazily iterate over a stored document's content.

    The object is read piece by piece so that callers never need to hold
    the whole document in memory.

    Args:
        object_id: The UUID of the document object
        chunk_size: Maximum number of bytes per yielded piece (default 1MB)

    Yields:
        Consecutive pieces of the document content as bytes
    """
    object_name = "doc/" + str(object_id)

    resp = self.client.get_object(
        bucket_name=self.bucket_name,
        object_name=object_name,
    )

    try:
        # An empty read marks the end of the object.
        while chunk := resp.read(chunk_size):
            yield chunk
    finally:
        resp.close()
        resp.release_conn()

    logger.debug("Stream complete")
def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
    """
    Start a multipart upload for a new object.

    Args:
        object_id: The UUID for the new object
        kind: MIME type of the document, sent as the Content-Type header

    Returns:
        The S3 upload_id identifying this multipart upload session
    """
    # The minio client only exposes multipart initiation as a private
    # method, so it is called directly here.
    s3_upload_id = self.client._create_multipart_upload(
        bucket_name=self.bucket_name,
        object_name="doc/" + str(object_id),
        headers={"Content-Type": kind},
    )

    logger.info(f"Created multipart upload {s3_upload_id} for {object_id}")

    return s3_upload_id
def upload_part(
    self,
    object_id: UUID,
    upload_id: str,
    part_number: int,
    data: bytes
) -> str:
    """
    Send one part of a multipart upload to the object store.

    Args:
        object_id: The UUID of the object being uploaded
        upload_id: The S3 upload_id from create_multipart_upload
        part_number: Part number (1-indexed, as per S3 spec)
        data: The bytes making up this part

    Returns:
        The ETag for this part (needed for complete_multipart_upload)
    """
    target = "doc/" + str(object_id)
    size_header = {"Content-Length": str(len(data))}

    part_etag = self.client._upload_part(
        bucket_name=self.bucket_name,
        object_name=target,
        data=data,
        headers=size_header,
        upload_id=upload_id,
        part_number=part_number,
    )

    logger.debug(f"Uploaded part {part_number} for {object_id}, etag={part_etag}")

    return part_etag
def complete_multipart_upload(
    self,
    object_id: UUID,
    upload_id: str,
    parts: List[Tuple[int, str]]
) -> None:
    """
    Finish a multipart upload so the parts become the final object.

    The object store assembles the parts itself; no document data passes
    through this client during completion.

    Args:
        object_id: The UUID of the object
        upload_id: The S3 upload_id from create_multipart_upload
        parts: List of (part_number, etag) tuples in order
    """
    # minio expects Part objects rather than plain tuples.
    part_objects = []
    for part_number, etag in parts:
        part_objects.append(Part(part_number, etag))

    self.client._complete_multipart_upload(
        bucket_name=self.bucket_name,
        object_name="doc/" + str(object_id),
        upload_id=upload_id,
        parts=part_objects,
    )

    logger.info(f"Completed multipart upload for {object_id}")
def abort_multipart_upload(self, object_id: UUID, upload_id: str) -> None:
    """
    Cancel a multipart upload so the store can discard its parts.

    Args:
        object_id: The UUID of the object
        upload_id: The S3 upload_id from create_multipart_upload
    """
    target = "doc/" + str(object_id)

    self.client._abort_multipart_upload(
        bucket_name=self.bucket_name,
        object_name=target,
        upload_id=upload_id,
    )

    logger.info(f"Aborted multipart upload {upload_id} for {object_id}")

View file

@ -1,17 +1,24 @@
from .. schema import LibrarianRequest, LibrarianResponse, Error, Triple
from .. schema import UploadSession
from .. knowledge import hash
from .. exceptions import RequestError
from .. tables.library import LibraryTableStore
from . blob_store import BlobStore
import base64
import json
import logging
import math
import time
import uuid
# Module logger
logger = logging.getLogger(__name__)
# Default chunk size for multipart uploads
DEFAULT_CHUNK_SIZE = 2 * 1024 * 1024 # 2MB default
class Librarian:
def __init__(
@ -20,6 +27,7 @@ class Librarian:
object_store_endpoint, object_store_access_key, object_store_secret_key,
bucket_name, keyspace, load_document,
object_store_use_ssl=False, object_store_region=None,
min_chunk_size=1, # Default: no minimum (for Garage)
):
self.blob_store = BlobStore(
@ -32,6 +40,7 @@ class Librarian:
)
self.load_document = load_document
self.min_chunk_size = min_chunk_size
async def add_document(self, request):
@ -66,13 +75,7 @@ class Librarian:
logger.debug("Add complete")
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
return LibrarianResponse()
async def remove_document(self, request):
@ -84,6 +87,21 @@ class Librarian:
):
raise RuntimeError("Document does not exist")
# First, cascade delete all child documents
children = await self.table_store.list_children(request.document_id)
for child in children:
logger.debug(f"Cascade deleting child document {child.id}")
try:
child_object_id = await self.table_store.get_document_object_id(
child.user,
child.id
)
await self.blob_store.remove(child_object_id)
await self.table_store.remove_document(child.user, child.id)
except Exception as e:
logger.warning(f"Failed to delete child document {child.id}: {e}")
# Now remove the parent document
object_id = await self.table_store.get_document_object_id(
request.user,
request.document_id
@ -100,13 +118,7 @@ class Librarian:
logger.debug("Remove complete")
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
return LibrarianResponse()
async def update_document(self, request):
@ -124,13 +136,7 @@ class Librarian:
logger.debug("Update complete")
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
return LibrarianResponse()
async def get_document_metadata(self, request):
@ -147,8 +153,6 @@ class Librarian:
error = None,
document_metadata = doc,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
async def get_document_content(self, request):
@ -170,8 +174,6 @@ class Librarian:
error = None,
document_metadata = None,
content = base64.b64encode(content),
document_metadatas = None,
processing_metadatas = None,
)
async def add_processing(self, request):
@ -217,13 +219,7 @@ class Librarian:
logger.debug("Add complete")
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
return LibrarianResponse()
async def remove_processing(self, request):
@ -243,24 +239,23 @@ class Librarian:
logger.debug("Remove complete")
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = None,
)
return LibrarianResponse()
# List the documents stored for request.user.
# Unless the request carries a truthy `include_children` attribute, entries
# that have a parent_id or whose document_type is "answer" are dropped, so
# only top-level source documents are returned in `document_metadatas`.
async def list_documents(self, request):
docs = await self.table_store.list_documents(request.user)
# Filter out child documents and answer documents by default
# getattr with a default tolerates requests that lack the attribute
include_children = getattr(request, 'include_children', False)
if not include_children:
docs = [
doc for doc in docs
if not doc.parent_id # Only include top-level documents
and doc.document_type != "answer" # Exclude GraphRAG answers
]
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = docs,
processing_metadatas = None,
)
async def list_processing(self, request):
@ -268,10 +263,444 @@ class Librarian:
procs = await self.table_store.list_processing(request.user)
return LibrarianResponse(
error = None,
document_metadata = None,
content = None,
document_metadatas = None,
processing_metadatas = procs,
)
# Chunked upload operations
async def begin_upload(self, request):
    """
    Open a chunked upload session for a new document.

    Validates the request, starts a multipart upload in the blob store and
    records the session (sizes, chunk count, serialized document metadata)
    in the table store.

    Returns a LibrarianResponse carrying the new upload_id, the effective
    chunk_size and the total_chunks expected.

    Raises RequestError for an unsupported document kind, an already
    existing document, a non-positive total_size, or a chunk size below
    the configured minimum.
    """

    meta = request.document_metadata

    logger.info(f"Beginning chunked upload for document {meta.id}")

    if meta.kind not in ("text/plain", "application/pdf"):
        raise RequestError("Invalid document kind: " + meta.kind)

    already_present = await self.table_store.document_exists(
        meta.user, meta.id
    )
    if already_present:
        raise RequestError("Document already exists")

    # Size validation
    total_size = request.total_size
    if total_size <= 0:
        raise RequestError("total_size must be positive")

    # Fall back to the default when the caller gave no usable chunk size
    if request.chunk_size > 0:
        chunk_size = request.chunk_size
    else:
        chunk_size = DEFAULT_CHUNK_SIZE

    if chunk_size < self.min_chunk_size:
        raise RequestError(
            f"Chunk size {chunk_size} is below minimum {self.min_chunk_size}"
        )

    total_chunks = math.ceil(total_size / chunk_size)

    # Session identifier and blob identifier
    upload_id = str(uuid.uuid4())
    object_id = uuid.uuid4()

    s3_upload_id = self.blob_store.create_multipart_upload(
        object_id, meta.kind
    )

    # The document metadata is kept as JSON until the upload completes
    serialized_meta = json.dumps({
        "id": meta.id,
        "time": meta.time,
        "kind": meta.kind,
        "title": meta.title,
        "comments": meta.comments,
        "user": meta.user,
        "tags": meta.tags,
    })

    await self.table_store.create_upload_session(
        upload_id=upload_id,
        user=meta.user,
        document_id=meta.id,
        document_metadata=serialized_meta,
        s3_upload_id=s3_upload_id,
        object_id=object_id,
        total_size=total_size,
        chunk_size=chunk_size,
        total_chunks=total_chunks,
    )

    logger.info(f"Created upload session {upload_id} with {total_chunks} chunks")

    return LibrarianResponse(
        error=None,
        upload_id=upload_id,
        chunk_size=chunk_size,
        total_chunks=total_chunks,
    )
async def upload_chunk(self, request):
    """
    Upload a single chunk of a document.

    Forwards the chunk to the blob store as one multipart part, records the
    part's ETag in the upload session and reports upload progress.

    Args:
        request: Librarian request carrying ``upload_id``, ``user``,
            ``chunk_index`` (0-based) and base64-encoded ``content``.

    Returns:
        LibrarianResponse with the chunk index plus chunk and byte
        progress counters.

    Raises:
        RequestError: If the session does not exist, belongs to another
            user, or the chunk index is out of range.
    """

    logger.debug(f"Uploading chunk {request.chunk_index} for upload {request.upload_id}")

    # Get session
    session = await self.table_store.get_upload_session(request.upload_id)
    if session is None:
        raise RequestError("Upload session not found or expired")

    # Validate ownership
    if session["user"] != request.user:
        raise RequestError("Not authorized to upload to this session")

    # Validate chunk index
    if request.chunk_index < 0 or request.chunk_index >= session["total_chunks"]:
        raise RequestError(
            f"Invalid chunk index {request.chunk_index}, "
            f"must be 0-{session['total_chunks']-1}"
        )

    # Decode content
    content = base64.b64decode(request.content)

    # Upload to S3 (part numbers are 1-indexed in S3)
    part_number = request.chunk_index + 1
    etag = self.blob_store.upload_part(
        object_id=session["object_id"],
        upload_id=session["s3_upload_id"],
        part_number=part_number,
        data=content,
    )

    # Update session with chunk info
    await self.table_store.update_upload_session_chunk(
        upload_id=request.upload_id,
        chunk_index=request.chunk_index,
        etag=etag,
    )

    # The session snapshot was read before this chunk was recorded, so
    # merge this chunk's index into the known indices. Counting distinct
    # indices reports a re-uploaded chunk once and never counts the
    # current chunk twice.
    received_indices = set(session["chunks_received"])
    received_indices.add(request.chunk_index)
    num_chunks_received = len(received_indices)

    # The last chunk may be shorter than chunk_size, so cap at total_size
    bytes_received = min(
        num_chunks_received * session["chunk_size"],
        session["total_size"],
    )

    logger.debug(f"Chunk {request.chunk_index} uploaded, {num_chunks_received}/{session['total_chunks']} complete")

    return LibrarianResponse(
        error=None,
        upload_id=request.upload_id,
        chunk_index=request.chunk_index,
        chunks_received=num_chunks_received,
        total_chunks=session["total_chunks"],
        bytes_received=bytes_received,
        total_bytes=session["total_size"],
    )
async def complete_upload(self, request):
    """
    Finish a chunked upload and register the resulting document.

    Checks that every expected chunk has been received, completes the
    multipart upload in the blob store, rebuilds the document metadata
    saved in the session, adds the document to the table store and finally
    removes the session.

    Raises RequestError when the session is missing, belongs to another
    user, or chunks are still missing.
    """

    logger.info(f"Completing upload {request.upload_id}")

    session = await self.table_store.get_upload_session(request.upload_id)
    if session is None:
        raise RequestError("Upload session not found or expired")

    # Only the owner may complete the upload
    if session["user"] != request.user:
        raise RequestError("Not authorized to complete this upload")

    # Every chunk must be present before the parts can be assembled
    received = session["chunks_received"]
    expected = session["total_chunks"]
    if len(received) != expected:
        missing = [i for i in range(expected) if i not in received]
        suffix = '...' if len(missing) > 10 else ''
        raise RequestError(f"Missing chunks: {missing[:10]}{suffix}")

    # Parts in ascending order; S3 part numbers are 1-indexed
    parts = []
    for chunk_index, etag in sorted(received.items()):
        parts.append((chunk_index + 1, etag))

    self.blob_store.complete_multipart_upload(
        object_id=session["object_id"],
        upload_id=session["s3_upload_id"],
        parts=parts,
    )

    # Rebuild the document metadata stored as JSON when the upload began
    stored = json.loads(session["document_metadata"])

    from .. schema import DocumentMetadata
    doc_metadata = DocumentMetadata(
        id=stored["id"],
        time=stored.get("time", int(time.time())),
        kind=stored["kind"],
        title=stored.get("title", ""),
        comments=stored.get("comments", ""),
        user=stored["user"],
        tags=stored.get("tags", []),
        metadata=[],  # Triples not supported in chunked upload yet
    )

    await self.table_store.add_document(doc_metadata, session["object_id"])

    # The session is no longer needed once the document exists
    await self.table_store.delete_upload_session(request.upload_id)

    logger.info(f"Upload {request.upload_id} completed, document {doc_metadata.id} created")

    return LibrarianResponse(
        error=None,
        document_id=doc_metadata.id,
        object_id=str(session["object_id"]),
    )
async def abort_upload(self, request):
    """
    Abandon a chunked upload.

    Aborts the multipart upload in the blob store and deletes the session
    record. Raises RequestError when the session is missing or is owned by
    a different user.
    """

    upload_id = request.upload_id

    logger.info(f"Aborting upload {upload_id}")

    session = await self.table_store.get_upload_session(upload_id)
    if session is None:
        raise RequestError("Upload session not found or expired")

    # Only the owner may abort
    if session["user"] != request.user:
        raise RequestError("Not authorized to abort this upload")

    # Drop the parts held by the object store first, then the session row
    self.blob_store.abort_multipart_upload(
        object_id=session["object_id"],
        upload_id=session["s3_upload_id"],
    )
    await self.table_store.delete_upload_session(upload_id)

    logger.info(f"Upload {upload_id} aborted")

    return LibrarianResponse(error=None)
async def get_upload_status(self, request):
    """
    Report progress of a chunked upload.

    A missing session is reported with upload_state "expired" rather than
    as an error. Otherwise the response lists received and missing chunk
    indices together with chunk and byte counters. Raises RequestError
    when the session belongs to another user.
    """

    logger.debug(f"Getting status for upload {request.upload_id}")

    session = await self.table_store.get_upload_session(request.upload_id)

    if session is None:
        return LibrarianResponse(
            error=None,
            upload_id=request.upload_id,
            upload_state="expired",
        )

    # Only the owner may inspect the upload
    if session["user"] != request.user:
        raise RequestError("Not authorized to view this upload")

    received = session["chunks_received"]
    total_chunks = session["total_chunks"]
    total_size = session["total_size"]

    received_list = sorted(received)
    missing_list = []
    for index in range(total_chunks):
        if index not in received:
            missing_list.append(index)

    # Estimate from whole chunks, capped because the final chunk may be short
    bytes_received = len(received) * session["chunk_size"]
    if bytes_received > total_size:
        bytes_received = total_size

    return LibrarianResponse(
        error=None,
        upload_id=request.upload_id,
        upload_state="in-progress",
        received_chunks=received_list,
        missing_chunks=missing_list,
        chunks_received=len(received),
        total_chunks=total_chunks,
        bytes_received=bytes_received,
        total_bytes=total_size,
    )
async def list_uploads(self, request):
    """
    Return the upload sessions recorded for the requesting user.

    Each session record from the table store is converted into an
    UploadSession and returned in the response's upload_sessions field.
    """

    logger.debug(f"Listing uploads for user {request.user}")

    records = await self.table_store.list_upload_sessions(request.user)

    upload_sessions = []
    for record in records:
        upload_sessions.append(
            UploadSession(
                upload_id=record["upload_id"],
                document_id=record["document_id"],
                document_metadata_json=record.get("document_metadata", ""),
                total_size=record["total_size"],
                chunk_size=record["chunk_size"],
                total_chunks=record["total_chunks"],
                chunks_received=record["chunks_received"],
                created_at=str(record.get("created_at", "")),
            )
        )

    return LibrarianResponse(
        error=None,
        upload_sessions=upload_sessions,
    )
# Child document operations
async def add_child_document(self, request):
    """
    Store a document that is derived from an existing parent document.

    The request's document metadata must name a parent_id that exists for
    the same user, and the child's own id must not exist yet. When the
    caller leaves document_type empty or set to "source", it is changed to
    "extracted" before the content and metadata are stored.

    Raises RequestError when parent_id is absent, the parent does not
    exist, or the child document already exists.
    """

    meta = request.document_metadata

    logger.info(f"Adding child document {meta.id} "
                f"for parent {meta.parent_id}")

    if not meta.parent_id:
        raise RequestError("parent_id is required for child documents")

    # The parent has to exist for this user
    parent_known = await self.table_store.document_exists(
        meta.user, meta.parent_id
    )
    if not parent_known:
        raise RequestError(
            f"Parent document {meta.parent_id} does not exist"
        )

    child_known = await self.table_store.document_exists(
        meta.user, meta.id
    )
    if child_known:
        raise RequestError("Document already exists")

    # Default the type when the caller did not choose one.
    # Valid types: "page", "chunk", or "extracted" (legacy)
    if not meta.document_type or meta.document_type == "source":
        meta.document_type = "extracted"

    # Fresh identifier for the stored blob
    object_id = uuid.uuid4()

    logger.debug("Adding blob...")

    raw_content = base64.b64decode(request.content)
    await self.blob_store.add(object_id, raw_content, meta.kind)

    logger.debug("Adding to table...")

    await self.table_store.add_document(meta, object_id)

    logger.debug("Add child document complete")

    return LibrarianResponse(
        error=None,
        document_id=meta.id,
    )
async def list_children(self, request):
    """
    Return the child documents recorded for request.document_id.
    """

    parent_id = request.document_id

    logger.debug(f"Listing children for parent {parent_id}")

    child_docs = await self.table_store.list_children(parent_id)

    return LibrarianResponse(
        error=None,
        document_metadatas=child_docs,
    )
async def stream_document(self, request):
    """
    Stream document content in chunks.

    This is an async generator that yields document content in smaller
    chunks, allowing memory-efficient processing of large documents. Each
    yielded response includes chunk_index and total_chunks for tracking
    progress, and the last one has is_final set.

    An empty (zero-byte) document produces exactly one response with empty
    content and is_final set, so the requester always receives a
    terminating message.

    Args:
        request: Librarian request carrying ``user``, ``document_id`` and
            ``chunk_size`` (a non-positive value selects the 1MB default).

    Yields:
        LibrarianResponse objects with base64-encoded content and progress
        counters.

    Raises:
        RequestError: If the chunk size is below the configured minimum.
    """

    logger.debug(f"Streaming document {request.document_id}")

    # Streaming default (1MB). Named differently from the module-level
    # DEFAULT_CHUNK_SIZE so that constant is not shadowed inside this method.
    default_stream_chunk_size = 1024 * 1024

    chunk_size = request.chunk_size if request.chunk_size > 0 else default_stream_chunk_size
    if chunk_size < self.min_chunk_size:
        raise RequestError(
            f"Chunk size {chunk_size} is below minimum {self.min_chunk_size}"
        )

    object_id = await self.table_store.get_document_object_id(
        request.user,
        request.document_id
    )

    # Get size via stat (no content download)
    total_size = await self.blob_store.get_size(object_id)

    if total_size <= 0:
        # Nothing to fetch: emit a single empty, final chunk instead of
        # ending silently without any response.
        yield LibrarianResponse(
            error=None,
            content=base64.b64encode(b""),
            chunk_index=0,
            chunks_received=1,
            total_chunks=1,
            bytes_received=0,
            total_bytes=0,
            is_final=True,
        )
        return

    total_chunks = math.ceil(total_size / chunk_size)

    # Stream all chunks
    for chunk_index in range(total_chunks):

        # Calculate byte range; the final chunk may be shorter
        offset = chunk_index * chunk_size
        length = min(chunk_size, total_size - offset)

        # Fetch only the requested range
        chunk_content = await self.blob_store.get_range(object_id, offset, length)

        is_last = (chunk_index == total_chunks - 1)

        logger.debug(f"Streaming chunk {chunk_index + 1}/{total_chunks}, "
                     f"bytes {offset}-{offset + length} of {total_size}")

        yield LibrarianResponse(
            error=None,
            content=base64.b64encode(chunk_content),
            chunk_index=chunk_index,
            chunks_received=chunk_index + 1,
            total_chunks=total_chunks,
            bytes_received=offset + length,
            total_bytes=total_size,
            is_final=is_last,
        )

View file

@ -23,9 +23,14 @@ from .. schema import config_request_queue, config_response_queue
from .. schema import Document, Metadata
from .. schema import TextDocument, Metadata
from .. schema import Triples
from .. exceptions import RequestError
from .. provenance import (
document_uri, document_triples, get_vocabulary_triples,
)
from . librarian import Librarian
from . collection_manager import CollectionManager
@ -47,6 +52,7 @@ default_object_store_secret_key = "object-password"
default_object_store_use_ssl = False
default_object_store_region = None
default_cassandra_host = "cassandra"
default_min_chunk_size = 1 # No minimum by default (for Garage)
bucket_name = "library"
@ -100,6 +106,11 @@ class Processor(AsyncProcessor):
default_object_store_region
)
min_chunk_size = params.get(
"min_chunk_size",
default_min_chunk_size
)
cassandra_host = params.get("cassandra_host")
cassandra_username = params.get("cassandra_username")
cassandra_password = params.get("cassandra_password")
@ -226,6 +237,7 @@ class Processor(AsyncProcessor):
load_document = self.load_document,
object_store_use_ssl = object_store_use_ssl,
object_store_region = object_store_region,
min_chunk_size = min_chunk_size,
)
self.collection_manager = CollectionManager(
@ -271,6 +283,70 @@ class Processor(AsyncProcessor):
pass
# Threshold for sending document_id instead of inline content (2MB)
STREAMING_THRESHOLD = 2 * 1024 * 1024
async def emit_document_provenance(self, document, processing, triples_queue):
    """
    Emit document provenance metadata to the knowledge graph.

    This emits, as a single Triples message:
    1. Vocabulary bootstrap triples (idempotent, safe to re-emit)
    2. Document metadata as PROV-O triples
    3. Any metadata triples already attached to the document

    Args:
        document: Document metadata; ``id``, ``title``, ``kind`` and
            ``metadata`` are read.
        processing: Processing record supplying ``user`` and ``collection``.
        triples_queue: Queue the Triples message is published to.
    """
    logger.debug(f"Emitting document provenance for {document.id}")

    doc_uri = document_uri(document.id)

    # NOTE: PDF page count is not known at this point; it is only
    # determined during extraction, so it is not part of these triples.
    prov_triples = document_triples(
        doc_uri=doc_uri,
        title=document.title or None,
        mime_type=document.kind,
    )

    # Include any existing metadata triples from the document
    if document.metadata:
        prov_triples.extend(document.metadata)

    # Vocabulary triples first, then the document's own triples
    all_triples = get_vocabulary_triples() + prov_triples

    # A short-lived publisher is used for this one message
    triples_pub = Publisher(
        self.pubsub, triples_queue, schema=Triples
    )

    try:
        await triples_pub.start()

        triples_msg = Triples(
            metadata=Metadata(
                id=doc_uri,
                root=document.id,
                user=processing.user,
                collection=processing.collection,
            ),
            triples=all_triples,
        )

        await triples_pub.send(None, triples_msg)

        logger.debug(f"Emitted {len(all_triples)} provenance triples for {document.id}")

    finally:
        # Always release the publisher, even if start or send failed
        await triples_pub.stop()
async def load_document(self, document, processing, content):
logger.debug("Ready for document processing...")
@ -291,27 +367,64 @@ class Processor(AsyncProcessor):
q = flow["interfaces"][kind]
if kind == "text-load":
doc = TextDocument(
metadata = Metadata(
id = document.id,
metadata = document.metadata,
user = processing.user,
collection = processing.collection
),
text = content,
# Emit document provenance to knowledge graph
if "triples-store" in flow["interfaces"]:
await self.emit_document_provenance(
document, processing, flow["interfaces"]["triples-store"]
)
if kind == "text-load":
# For large text documents, send document_id for streaming retrieval
if len(content) >= self.STREAMING_THRESHOLD:
logger.info(f"Text document {document.id} is large ({len(content)} bytes), "
f"sending document_id for streaming retrieval")
doc = TextDocument(
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
text = b"", # Empty, receiver will fetch via librarian
)
else:
doc = TextDocument(
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
text = content,
)
schema = TextDocument
else:
doc = Document(
metadata = Metadata(
id = document.id,
metadata = document.metadata,
user = processing.user,
collection = processing.collection
),
data = base64.b64encode(content).decode("utf-8")
)
# For large PDF documents, send document_id for streaming retrieval
# instead of embedding the entire content in the message
if len(content) >= self.STREAMING_THRESHOLD:
logger.info(f"Document {document.id} is large ({len(content)} bytes), "
f"sending document_id for streaming retrieval")
doc = Document(
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
document_id = document.id,
data = b"", # Empty data, receiver will fetch via API
)
else:
doc = Document(
metadata = Metadata(
id = document.id,
root = document.id,
user = processing.user,
collection = processing.collection
),
data = base64.b64encode(content).decode("utf-8")
)
schema = Document
logger.debug(f"Submitting to queue {q}...")
@ -361,6 +474,17 @@ class Processor(AsyncProcessor):
"remove-processing": self.librarian.remove_processing,
"list-documents": self.librarian.list_documents,
"list-processing": self.librarian.list_processing,
# Chunked upload operations
"begin-upload": self.librarian.begin_upload,
"upload-chunk": self.librarian.upload_chunk,
"complete-upload": self.librarian.complete_upload,
"abort-upload": self.librarian.abort_upload,
"get-upload-status": self.librarian.get_upload_status,
"list-uploads": self.librarian.list_uploads,
# Child document and streaming operations
"add-child-document": self.librarian.add_child_document,
"list-children": self.librarian.list_children,
"stream-document": self.librarian.stream_document,
}
if v.operation not in impls:
@ -380,6 +504,15 @@ class Processor(AsyncProcessor):
try:
# Handle streaming operations specially
if v.operation == "stream-document":
async for resp in self.librarian.stream_document(v):
await self.librarian_response_producer.send(
resp, properties={"id": id}
)
return
# Non-streaming operations
resp = await self.process_request(v)
await self.librarian_response_producer.send(
@ -393,7 +526,7 @@ class Processor(AsyncProcessor):
error = Error(
type = "request-error",
message = str(e),
)
),
)
await self.librarian_response_producer.send(
@ -406,7 +539,7 @@ class Processor(AsyncProcessor):
error = Error(
type = "unexpected-error",
message = str(e),
)
),
)
await self.librarian_response_producer.send(
@ -538,6 +671,14 @@ class Processor(AsyncProcessor):
help='Object storage region (optional)',
)
parser.add_argument(
'--min-chunk-size',
type=int,
default=default_min_chunk_size,
help=f'Minimum chunk size in bytes for uploads/downloads '
f'(default: {default_min_chunk_size})',
)
add_cassandra_args(parser)
def run():

View file

@ -55,11 +55,13 @@ class Processor(LlmService):
self.max_output = max_output
self.default_model = model
def build_prompt(self, system, content, temperature=None, stream=False):
def build_prompt(self, system, content, temperature=None, stream=False, model=None):
# Use provided temperature or fall back to default
effective_temperature = temperature if temperature is not None else self.temperature
model_name = model or self.default_model
data = {
"model": model_name,
"messages": [
{
"role": "system", "content": system
@ -100,7 +102,8 @@ class Processor(LlmService):
raise TooManyRequests()
if resp.status_code != 200:
raise RuntimeError("LLM failure")
logger.error(f"Azure API error: status={resp.status_code}, body={resp.text}")
raise RuntimeError(f"LLM failure: HTTP {resp.status_code}")
result = resp.json()
@ -121,7 +124,8 @@ class Processor(LlmService):
prompt = self.build_prompt(
system,
prompt,
effective_temperature
effective_temperature,
model=model_name
)
response = self.call_llm(prompt)
@ -174,7 +178,7 @@ class Processor(LlmService):
logger.debug(f"Using temperature: {effective_temperature}")
try:
body = self.build_prompt(system, prompt, effective_temperature, stream=True)
body = self.build_prompt(system, prompt, effective_temperature, stream=True, model=model_name)
url = self.endpoint
api_key = self.token
@ -190,7 +194,11 @@ class Processor(LlmService):
raise TooManyRequests()
if response.status_code != 200:
raise RuntimeError("LLM failure")
logger.error(f"Azure API error: status={response.status_code}, body={response.text}")
raise RuntimeError(f"LLM failure: HTTP {response.status_code}")
total_input_tokens = 0
total_output_tokens = 0
total_input_tokens = 0
total_output_tokens = 0
@ -279,6 +287,12 @@ class Processor(LlmService):
help=f'LLM max output tokens (default: {default_max_output})'
)
parser.add_argument(
'-m', '--model',
default=default_model,
help=f'LLM model name (default: {default_model})'
)
def run():
    """Entry point: hand the default identifier and module docstring to Processor.launch."""
    Processor.launch(default_ident, __doc__)

View file

@ -1,13 +1,13 @@
"""
Document embeddings query service. Input is vector, output is an array
of chunks
of chunk_ids
"""
import logging
from .... direct.milvus_doc_embeddings import DocVectors
from .... schema import DocumentEmbeddingsResponse
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
@ -35,24 +35,31 @@ class Processor(DocumentEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit
)
chunks = []
for vec in msg.vectors:
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit
)
for r in resp:
chunk = r["entity"]["doc"]
chunks.append(chunk)
for r in resp:
chunk_id = r["entity"]["chunk_id"]
# Milvus returns distance, convert to similarity score
distance = r.get("distance", 0.0)
score = 1.0 - distance if distance else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
return chunks

View file

@ -1,7 +1,7 @@
"""
Document embeddings query service. Input is vector, output is an array
of chunks. Pinecone implementation.
of chunk_ids. Pinecone implementation.
"""
import logging
@ -11,6 +11,7 @@ import os
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import ChunkMatch
from .... base import DocumentEmbeddingsQueryService
# Module logger
@ -51,36 +52,41 @@ class Processor(DocumentEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
dim = len(vec)
# Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist")
return []
index = self.pinecone.Index(index_name)
results = index.query(
vector=vec,
top_k=msg.limit,
include_values=False,
include_metadata=True
)
chunks = []
for vec in msg.vectors:
dim = len(vec)
# Use dimension suffix in index name
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - skip if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist, skipping this vector")
continue
index = self.pinecone.Index(index_name)
results = index.query(
vector=vec,
top_k=msg.limit,
include_values=False,
include_metadata=True
)
for r in results.matches:
doc = r.metadata["doc"]
chunks.append(doc)
for r in results.matches:
chunk_id = r.metadata["chunk_id"]
score = r.score if hasattr(r, 'score') else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
return chunks

View file

@ -1,7 +1,7 @@
"""
Document embeddings query service. Input is vector, output is an array
of chunks
of chunk_ids
"""
import logging
@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import DocumentEmbeddingsResponse
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
from .... schema import Error
from .... base import DocumentEmbeddingsQueryService
@ -69,29 +69,34 @@ class Processor(DocumentEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
return []
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit,
with_payload=True,
).points
chunks = []
for vec in msg.vectors:
# Use dimension suffix in collection name
dim = len(vec)
collection = f"d_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, returning empty results")
continue
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit,
with_payload=True,
).points
for r in search_result:
ent = r.payload["doc"]
chunks.append(ent)
for r in search_result:
chunk_id = r.payload["chunk_id"]
score = r.score if hasattr(r, 'score') else 0.0
chunks.append(ChunkMatch(
chunk_id=chunk_id,
score=score,
))
return chunks

View file

@ -7,7 +7,7 @@ entities
import logging
from .... direct.milvus_graph_embeddings import EntityVectors
from .... schema import GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService):
try:
entity_set = set()
entities = []
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
for vec in msg.vectors:
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit * 2
)
resp = self.vecstore.search(
vec,
msg.user,
msg.collection,
limit=msg.limit * 2
)
entity_set = set()
entities = []
for r in resp:
ent = r["entity"]["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
for r in resp:
ent = r["entity"]["entity"]
# Milvus returns distance, convert to similarity score
distance = r.get("distance", 0.0)
score = 1.0 - distance if distance else 0.0
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# De-dupe entities, keep highest score
if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
if len(entities) >= msg.limit:
break
logger.debug("Send response...")
return entities

View file

@ -11,7 +11,7 @@ import os
from pinecone import Pinecone, ServerlessSpec
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
from .... schema import GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Handle zero limit case
if msg.limit <= 0:
return []
dim = len(vec)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - return empty if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist")
return []
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
results = index.query(
vector=vec,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
entity_set = set()
entities = []
for vec in msg.vectors:
for r in results.matches:
ent = r.metadata["entity"]
score = r.score if hasattr(r, 'score') else 0.0
dim = len(vec)
# Use dimension suffix in index name
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
# Check if index exists - skip if not
if not self.pinecone.has_index(index_name):
logger.info(f"Index {index_name} does not exist, skipping this vector")
continue
index = self.pinecone.Index(index_name)
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
results = index.query(
vector=vec,
top_k=msg.limit * 2,
include_values=False,
include_metadata=True
)
for r in results.matches:
ent = r.metadata["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# De-dupe entities, keep highest score
if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
if len(entities) >= msg.limit:
break
return entities

View file

@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams
from .... schema import GraphEmbeddingsResponse
from .... schema import GraphEmbeddingsResponse, EntityMatch
from .... schema import Error, Term, IRI, LITERAL
from .... base import GraphEmbeddingsQueryService
@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService):
try:
vec = msg.vector
if not vec:
return []
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist")
return []
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) unique entities
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
entity_set = set()
entities = []
for vec in msg.vectors:
for r in search_result:
ent = r.payload["entity"]
score = r.score if hasattr(r, 'score') else 0.0
# Use dimension suffix in collection name
dim = len(vec)
collection = f"t_{msg.user}_{msg.collection}_{dim}"
# Check if collection exists - return empty if not
if not self.collection_exists(collection):
logger.info(f"Collection {collection} does not exist, skipping this vector")
continue
# Heuristic hack, get (2*limit), so that we have more chance
# of getting (limit) entities
search_result = self.qdrant.query_points(
collection_name=collection,
query=vec,
limit=msg.limit * 2,
with_payload=True,
).points
for r in search_result:
ent = r.payload["entity"]
# De-dupe entities
if ent not in entity_set:
entity_set.add(ent)
entities.append(ent)
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
# De-dupe entities, keep highest score
if ent not in entity_set:
entity_set.add(ent)
entities.append(EntityMatch(
entity=self.create_value(ent),
score=score,
))
# Keep adding entities until limit
if len(entity_set) >= msg.limit: break
ents2 = []
for ent in entities:
ents2.append(self.create_value(ent))
entities = ents2
if len(entities) >= msg.limit:
break
logger.debug("Send response...")
return entities

View file

@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
default_ident = "row-embeddings-query"
default_store_uri = 'http://localhost:6333'
default_concurrency = 10
class Processor(FlowProcessor):
@ -31,6 +32,7 @@ class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
concurrency = params.get("concurrency", default_concurrency)
store_uri = params.get("store_uri", default_store_uri)
api_key = params.get("api_key", None)
@ -47,7 +49,8 @@ class Processor(FlowProcessor):
ConsumerSpec(
name="request",
schema=RowEmbeddingsRequest,
handler=self.on_message
handler=self.on_message,
concurrency=concurrency,
)
)
@ -93,7 +96,9 @@ class Processor(FlowProcessor):
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
"""Execute row embeddings query"""
matches = []
vec = request.vector
if not vec:
return []
# Find the collection for this user/collection/schema
qdrant_collection = self.find_collection(
@ -105,47 +110,47 @@ class Processor(FlowProcessor):
f"No Qdrant collection found for "
f"{request.user}/{request.collection}/{request.schema_name}"
)
return []
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
matches = []
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
return matches
for vec in request.vectors:
try:
# Build optional filter for index_name
query_filter = None
if request.index_name:
query_filter = Filter(
must=[
FieldCondition(
key="index_name",
match=MatchValue(value=request.index_name)
)
]
)
# Query Qdrant
search_result = self.qdrant.query_points(
collection_name=qdrant_collection,
query=vec,
limit=request.limit,
with_payload=True,
query_filter=query_filter,
).points
# Convert to RowIndexMatch objects
for point in search_result:
payload = point.payload or {}
match = RowIndexMatch(
index_name=payload.get("index_name", ""),
index_value=payload.get("index_value", []),
text=payload.get("text", ""),
score=point.score if hasattr(point, 'score') else 0.0
)
matches.append(match)
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
return matches
except Exception as e:
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
raise
async def on_message(self, msg, consumer, flow):
"""Handle incoming query request"""
@ -203,6 +208,13 @@ class Processor(FlowProcessor):
help='API key for Qdrant (default: None)'
)
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Number of concurrent requests (default: {default_concurrency})'
)
def run():
"""Entry point for row-embeddings-query-qdrant command"""

View file

@ -30,6 +30,7 @@ from ... graphql import GraphQLSchemaBuilder, SortDirection
logger = logging.getLogger(__name__)
default_ident = "rows-query"
default_concurrency = 10
class Processor(FlowProcessor):
@ -37,6 +38,7 @@ class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
concurrency = params.get("concurrency", default_concurrency)
# Get Cassandra parameters
cassandra_host = params.get("cassandra_host")
@ -69,7 +71,8 @@ class Processor(FlowProcessor):
ConsumerSpec(
name="request",
schema=RowsQueryRequest,
handler=self.on_message
handler=self.on_message,
concurrency=concurrency,
)
)
@ -517,6 +520,13 @@ class Processor(FlowProcessor):
help='Configuration type prefix for schemas (default: schema)'
)
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Number of concurrent requests (default: {default_concurrency})'
)
def run():
"""Entry point for rows-query-cassandra command"""

View file

@ -6,11 +6,14 @@ null. Output is a list of quads.
import logging
import json
from cassandra.query import SimpleStatement
from .... direct.cassandra_kg import (
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
EntityCentricKnowledgeGraph, DEFAULT_GRAPH
)
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
from .... schema import Term, Triple, IRI, LITERAL
from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK
from .... base import TriplesQueryService
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
@ -20,6 +23,36 @@ logger = logging.getLogger(__name__)
default_ident = "triples-query"
def serialize_triple(triple):
    """
    Encode a Triple as a JSON string suitable for use as a query value.

    The layout produced here has to stay in step with the format used when
    triples are stored, otherwise lookups on quoted triples will not match.

    Args:
        triple: object exposing ``s``, ``p`` and ``o`` terms, or None.

    Returns:
        A JSON string with keys ``s``, ``p``, ``o`` (insertion order), or
        None when ``triple`` is None.
    """
    if triple is None:
        return None

    def encode(term):
        # Absent terms are carried through as JSON null.
        if term is None:
            return None

        kind = term.type
        encoded = {"type": kind}

        if kind == IRI:
            encoded["iri"] = term.iri
        elif kind == LITERAL:
            encoded["value"] = term.value
            # Datatype and language are only written when non-empty.
            for key in ("datatype", "language"):
                extra = getattr(term, key)
                if extra:
                    encoded[key] = extra
        elif kind == BLANK:
            encoded["id"] = term.id
        elif kind == TRIPLE:
            # NOTE(review): the nested triple is embedded as an already
            # serialized JSON *string*, not as a nested object.
            encoded["triple"] = serialize_triple(term.triple)

        return encoded

    return json.dumps(
        {part: encode(getattr(triple, part)) for part in ("s", "p", "o")}
    )
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
@ -28,42 +61,88 @@ def get_term_value(term):
return term.iri
elif term.type == LITERAL:
return term.value
elif term.type == TRIPLE:
# Serialize nested triple to JSON (must match storage format)
return serialize_triple(term.triple)
else:
# For blank nodes or other types, use id or value
return term.id or term.value
def create_term(value, otype=None, dtype=None, lang=None):
def deserialize_term(term_dict):
    """
    Rebuild a Term from its JSON (dict) representation.

    Accepts every term shape that serialize_triple() can emit: IRI,
    literal, blank node and nested (quoted) triple.

    Args:
        term_dict: dict with a ``type`` key plus type-specific fields, or
            None.

    Returns:
        A Term, or None when ``term_dict`` is None.  Anything that cannot
        be interpreted is returned as a LITERAL Term holding
        ``str(term_dict)``.
    """
    if term_dict is None:
        return None

    term_type = term_dict.get("type", "")

    if term_type == IRI:
        return Term(type=IRI, iri=term_dict.get("iri", ""))

    if term_type == LITERAL:
        return Term(
            type=LITERAL,
            value=term_dict.get("value", ""),
            datatype=term_dict.get("datatype", ""),
            language=term_dict.get("language", "")
        )

    if term_type == BLANK:
        # serialize_triple() writes blank nodes as {"type": BLANK, "id": ...};
        # without this branch they degraded to a stringified-dict literal.
        return Term(type=BLANK, id=term_dict.get("id", ""))

    if term_type == TRIPLE:
        nested = term_dict.get("triple")

        # serialize_triple() embeds nested triples as a JSON *string*;
        # decode it so both the string and the plain-dict form are accepted.
        if isinstance(nested, str):
            try:
                nested = json.loads(nested)
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse nested triple JSON: {e}")
                nested = None

        # Recursive for nested triples
        if isinstance(nested, dict) and nested:
            return Term(
                type=TRIPLE,
                triple=Triple(
                    s=deserialize_term(nested.get("s")),
                    p=deserialize_term(nested.get("p")),
                    o=deserialize_term(nested.get("o")),
                )
            )

    # Fallback: unknown type, or a triple term with no usable payload
    return Term(type=LITERAL, value=str(term_dict))
def create_term(value, term_type=None, datatype=None, language=None):
"""
Create a Term from a string value, optionally using type metadata.
Args:
value: The string value
otype: Object type - 'u' (URI), 'l' (literal), 't' (triple)
dtype: XSD datatype (for literals)
lang: Language tag (for literals)
term_type: 'u' (IRI), 'l' (literal), 't' (triple)
datatype: XSD datatype for literals
language: Language tag for literals
If otype is provided, uses it to determine Term type.
Otherwise falls back to URL detection heuristic.
If term_type is provided, uses it to determine Term type.
Otherwise falls back to URL detection heuristic for object values.
"""
if otype is not None:
if otype == 'u':
return Term(type=IRI, iri=value)
elif otype == 'l':
return Term(
type=LITERAL,
value=value,
datatype=dtype or "",
language=lang or ""
)
elif otype == 't':
# Triple/reification - treat as IRI for now
return Term(type=IRI, iri=value)
else:
# Unknown otype, fall back to heuristic
pass
if term_type == 'u':
return Term(type=IRI, iri=value)
elif term_type == 'l':
return Term(
type=LITERAL,
value=value,
datatype=datatype or "",
language=language or ""
)
elif term_type == 't':
# Triple/reification - parse JSON and create nested Triple
try:
triple_data = json.loads(value) if isinstance(value, str) else value
if isinstance(triple_data, dict):
return Term(
type=TRIPLE,
triple=Triple(
s=deserialize_term(triple_data.get("s")),
p=deserialize_term(triple_data.get("p")),
o=deserialize_term(triple_data.get("o")),
)
)
except (json.JSONDecodeError, TypeError) as e:
logger.warning(f"Failed to parse triple JSON: {e}")
# Fallback if parsing fails
return Term(type=LITERAL, value=str(value))
elif term_type is not None:
# Unknown term_type, fall back to heuristic
pass
# Heuristic fallback for backwards compatibility
# Heuristic fallback for backwards compatibility (object values only)
if value.startswith("http://") or value.startswith("https://"):
return Term(type=IRI, iri=value)
else:
@ -98,28 +177,30 @@ class Processor(TriplesQueryService):
self.cassandra_password = password
self.table = None
def ensure_connection(self, user):
    """
    Make sure ``self.tg`` is bound to the keyspace named after ``user``.

    ``self.table`` remembers the keyspace of the current connection; a new
    EntityCentricKnowledgeGraph is only built when the requested user
    differs from it.
    """
    # Already connected to this user's keyspace: nothing to do.
    if user == self.table:
        return

    options = {"hosts": self.cassandra_host, "keyspace": user}

    # Credentials are passed only when both halves are configured.
    if self.cassandra_username and self.cassandra_password:
        options["username"] = self.cassandra_username
        options["password"] = self.cassandra_password

    self.tg = EntityCentricKnowledgeGraph(**options)
    self.table = user
async def query_triples(self, query):
try:
user = query.user
if user != self.table:
# Use factory function to select implementation
KGClass = EntityCentricKnowledgeGraph
if self.cassandra_username and self.cassandra_password:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=query.user,
username=self.cassandra_username, password=self.cassandra_password
)
else:
self.tg = KGClass(
hosts=self.cassandra_host,
keyspace=query.user,
)
self.table = user
self.ensure_connection(query.user)
# Extract values from query
s_val = get_term_value(query.s)
@ -127,13 +208,13 @@ class Processor(TriplesQueryService):
o_val = get_term_value(query.o)
g_val = query.g # Already a string or None
# Helper to extract object metadata from result row
def get_o_metadata(t):
"""Extract otype/dtype/lang from result row if available"""
otype = getattr(t, 'otype', None)
dtype = getattr(t, 'dtype', None)
lang = getattr(t, 'lang', None)
return otype, dtype, lang
def get_object_metadata(row):
"""Extract term type metadata from result row"""
return (
getattr(row, 'otype', None),
getattr(row, 'dtype', None),
getattr(row, 'lang', None),
)
quads = []
@ -148,8 +229,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, p_val, o_val, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
else:
# SP specified
resp = self.tg.get_sp(
@ -158,8 +239,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, p_val, t.o, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, p_val, t.o, g, term_type, datatype, language))
else:
if o_val is not None:
# SO specified
@ -169,8 +250,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, t.p, o_val, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
else:
# S only
resp = self.tg.get_s(
@ -179,8 +260,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((s_val, t.p, t.o, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((s_val, t.p, t.o, g, term_type, datatype, language))
else:
if p_val is not None:
if o_val is not None:
@ -191,8 +272,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, p_val, o_val, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
else:
# P only
resp = self.tg.get_p(
@ -201,8 +282,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, p_val, t.o, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, p_val, t.o, g, term_type, datatype, language))
else:
if o_val is not None:
# O only
@ -212,8 +293,8 @@ class Processor(TriplesQueryService):
)
for t in resp:
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, t.p, o_val, g, otype, dtype, lang))
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
else:
# Nothing specified - get all
resp = self.tg.get_all(
@ -223,16 +304,24 @@ class Processor(TriplesQueryService):
for t in resp:
# Note: quads_by_collection uses 'd' for graph field
g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
otype, dtype, lang = get_o_metadata(t)
quads.append((t.s, t.p, t.o, g, otype, dtype, lang))
# Filter by graph
# g_val=None means all graphs (no filter)
# g_val="" means default graph only
# otherwise filter to specific named graph
if g_val is not None:
if g != g_val:
continue
term_type, datatype, language = get_object_metadata(t)
quads.append((t.s, t.p, t.o, g, term_type, datatype, language))
# Convert to Triple objects (with g field)
# Use otype/dtype/lang for proper Term reconstruction if available
# s and p are always IRIs in RDF
# Object uses term_type/datatype/language metadata from database
triples = [
Triple(
s=create_term(q[0]),
p=create_term(q[1]),
o=create_term(q[2], otype=q[4], dtype=q[5], lang=q[6]),
s=create_term(q[0], term_type='u'),
p=create_term(q[1], term_type='u'),
o=create_term(q[2], term_type=q[4], datatype=q[5], language=q[6]),
g=q[3] if q[3] != DEFAULT_GRAPH else None
)
for q in quads
@ -245,6 +334,104 @@ class Processor(TriplesQueryService):
logger.error(f"Exception querying triples: {e}", exc_info=True)
raise e
async def query_triples_stream(self, query):
    """
    Stream query results as (batch, is_final) tuples.

    An unconstrained pattern (no s, p or o) is read from the collection
    table with the statement's fetch_size set to the batch size, so rows
    are pulled page by page.  Any constrained pattern is delegated to
    _fallback_stream(), which runs the ordinary query and slices it.

    Non-positive ``query.batch_size`` / ``query.limit`` select the
    defaults of 20 and 10000.  Exactly one tuple with is_final=True is
    yielded on the paged path, always last, possibly with an empty batch.
    """
    try:
        self.ensure_connection(query.user)

        batch_size = query.batch_size if query.batch_size > 0 else 20
        limit = query.limit if query.limit > 0 else 10000

        # Extract query pattern
        s_val = get_term_value(query.s)
        p_val = get_term_value(query.p)
        o_val = get_term_value(query.o)
        g_val = query.g

        unconstrained = s_val is None and p_val is None and o_val is None

        if not unconstrained:
            # Specific patterns: run the regular query and batch afterwards
            async for batch, is_final in self._fallback_stream(query, batch_size):
                yield batch, is_final
            return

        # Get all - collection table, paged via fetch_size
        cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"
        statement = SimpleStatement(cql, fetch_size=batch_size)
        rows = self.tg.session.execute(statement, [query.collection])

        pending = []
        emitted = 0

        for row in rows:

            if emitted >= limit:
                break

            graph = getattr(row, 'd', DEFAULT_GRAPH)

            # Graph filter:
            #   g_val is None -> every graph
            #   g_val == ""   -> default graph only
            #   otherwise     -> that named graph only
            if g_val is not None and graph != g_val:
                continue

            # s and p are always IRIs in RDF; the object is rebuilt from
            # the type metadata columns when the row carries them.
            pending.append(Triple(
                s=create_term(row.s, term_type='u'),
                p=create_term(row.p, term_type='u'),
                o=create_term(
                    row.o,
                    term_type=getattr(row, 'otype', None),
                    datatype=getattr(row, 'dtype', None),
                    language=getattr(row, 'lang', None),
                ),
                g=None if graph == DEFAULT_GRAPH else graph
            ))
            emitted += 1

            # A full batch goes out immediately, never flagged as final
            if len(pending) >= batch_size:
                yield pending, False
                pending = []

        # Closing batch: leftover rows, an empty result set, or an empty
        # list after an exact batch boundary
        yield pending, True

    except Exception as e:
        logger.error(f"Exception in streaming query: {e}", exc_info=True)
        raise e
async def _fallback_stream(self, query, batch_size):
"""Fallback to non-streaming query with post-hoc batching."""
triples = await self.query_triples(query)
for i in range(0, len(triples), batch_size):
batch = triples[i:i + batch_size]
is_final = (i + batch_size >= len(triples))
yield batch, is_final
if len(triples) == 0:
yield [], True
@staticmethod
def add_args(parser):

View file

@ -1,6 +1,22 @@
import asyncio
import logging
import uuid
from datetime import datetime
# Provenance imports
from trustgraph.provenance import (
docrag_question_uri,
docrag_grounding_uri,
docrag_exploration_uri,
docrag_synthesis_uri,
docrag_question_triples,
grounding_triples,
docrag_exploration_triples,
docrag_synthesis_triples,
set_graph,
GRAPH_RETRIEVAL,
)
# Module logger
logger = logging.getLogger(__name__)
@ -19,41 +35,106 @@ class Query:
self.verbose = verbose
self.doc_limit = doc_limit
async def get_vector(self, query):
async def extract_concepts(self, query):
    """
    Ask the "extract-concepts" prompt for the key concepts in ``query``.

    A string reply is read one concept per line, with blank lines and
    surrounding whitespace dropped.  If that leaves nothing (or the reply
    is not a string) the raw query is used as the only concept.

    Returns:
        list[str]: at least one concept.
    """
    reply = await self.rag.prompt_client.prompt(
        "extract-concepts",
        variables={"query": query}
    )

    if isinstance(reply, str):
        trimmed = (entry.strip() for entry in reply.strip().split('\n'))
        concepts = [entry for entry in trimmed if entry]
    else:
        concepts = []

    # Fallback to raw query if no concepts extracted
    concepts = concepts or [query]

    if self.verbose:
        logger.debug(f"Extracted concepts: {concepts}")

    return concepts
async def get_vectors(self, concepts):
    """
    Compute embeddings for a list of concepts.

    Args:
        concepts: list of concept strings, passed to the embeddings
            client in a single call.

    Returns:
        Whatever the embeddings client returns for ``concepts``.
    """
    if self.verbose:
        logger.debug("Computing embeddings...")

    # Single embed call over the concept list.  The leftover
    # ``embed(query)`` call referenced a name that no longer exists in
    # this signature and has been removed.
    qembeds = await self.rag.embeddings_client.embed(concepts)

    if self.verbose:
        logger.debug("Embeddings computed")

    return qembeds
async def get_docs(self, concepts):
    """
    Get documents (chunks) matching the extracted concepts.

    Each concept is embedded, the embeddings store is queried once per
    concept vector (concurrently), matches are de-duplicated by chunk ID
    in first-seen order, and the chunk text is then fetched through
    ``self.rag.fetch_chunk``.

    Args:
        concepts: list of concept strings.

    Returns:
        tuple: (docs, chunk_ids) where:
            - docs: list of document content strings
            - chunk_ids: list of chunk IDs that were successfully fetched
        Both lists are index-aligned; chunks whose fetch fails are logged
        and left out of both.
    """
    vectors = await self.get_vectors(concepts)

    # Nothing to search with; also keeps the division below safe.
    if not vectors:
        return [], []

    if self.verbose:
        logger.debug("Getting chunks from embeddings store...")

    # Split the overall budget across concepts, at least one match each
    per_concept_limit = max(
        1, self.doc_limit // len(vectors)
    )

    async def query_concept(vec):
        return await self.rag.doc_embeddings_client.query(
            vector=vec, limit=per_concept_limit,
            user=self.user, collection=self.collection,
        )

    # Query chunk matches for each concept concurrently
    results = await asyncio.gather(
        *[query_concept(v) for v in vectors]
    )

    # Deduplicate chunk matches by chunk_id (empty IDs are dropped here)
    seen = set()
    chunk_matches = []
    for matches in results:
        for match in matches:
            if match.chunk_id and match.chunk_id not in seen:
                seen.add(match.chunk_id)
                chunk_matches.append(match)

    if self.verbose:
        logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")

    # Fetch chunk content from Garage
    docs = []
    chunk_ids = []
    for match in chunk_matches:
        try:
            content = await self.rag.fetch_chunk(match.chunk_id, self.user)
        except Exception as e:
            # Best effort: one unreadable chunk must not fail the query
            logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")
            continue
        docs.append(content)
        chunk_ids.append(match.chunk_id)

    if self.verbose:
        logger.debug("Documents fetched:")
        for doc in docs:
            logger.debug(f"  {doc[:100]}...")

    return docs, chunk_ids
class DocumentRag:
def __init__(
self, prompt_client, embeddings_client, doc_embeddings_client,
fetch_chunk,
verbose=False,
):
@ -62,6 +143,7 @@ class DocumentRag:
self.prompt_client = prompt_client
self.embeddings_client = embeddings_client
self.doc_embeddings_client = doc_embeddings_client
self.fetch_chunk = fetch_chunk
if self.verbose:
logger.debug("DocumentRag initialized")
@ -69,17 +151,69 @@ class DocumentRag:
async def query(
self, query, user="trustgraph", collection="default",
doc_limit=20, streaming=False, chunk_callback=None,
explain_callback=None, save_answer_callback=None,
):
"""
Execute a Document RAG query with optional explainability tracking.
Args:
query: The query string
user: User identifier
collection: Collection identifier
doc_limit: Max chunks to retrieve
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for explainability
save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
Returns:
str: The synthesized answer text
"""
if self.verbose:
logger.debug("Constructing prompt...")
# Generate explainability URIs upfront
session_id = str(uuid.uuid4())
q_uri = docrag_question_uri(session_id)
gnd_uri = docrag_grounding_uri(session_id)
exp_uri = docrag_exploration_uri(session_id)
syn_uri = docrag_synthesis_uri(session_id)
timestamp = datetime.utcnow().isoformat() + "Z"
# Emit question explainability immediately
if explain_callback:
q_triples = set_graph(
docrag_question_triples(q_uri, query, timestamp),
GRAPH_RETRIEVAL
)
await explain_callback(q_triples, q_uri)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose,
doc_limit=doc_limit
)
docs = await q.get_docs(query)
# Extract concepts from query (grounding step)
concepts = await q.extract_concepts(query)
# Emit grounding explainability after concept extraction
if explain_callback:
gnd_triples = set_graph(
grounding_triples(gnd_uri, q_uri, concepts),
GRAPH_RETRIEVAL
)
await explain_callback(gnd_triples, gnd_uri)
docs, chunk_ids = await q.get_docs(concepts)
# Emit exploration explainability after chunks retrieved
if explain_callback:
exp_triples = set_graph(
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
GRAPH_RETRIEVAL
)
await explain_callback(exp_triples, exp_uri)
if self.verbose:
logger.debug("Invoking LLM...")
@ -87,12 +221,21 @@ class DocumentRag:
logger.debug(f"Query: {query}")
if streaming and chunk_callback:
# Accumulate chunks for answer storage while forwarding to callback
accumulated_chunks = []
async def accumulating_callback(chunk, end_of_stream):
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
resp = await self.prompt_client.document_prompt(
query=query,
documents=docs,
streaming=True,
chunk_callback=chunk_callback
chunk_callback=accumulating_callback
)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
resp = await self.prompt_client.document_prompt(
query=query,
@ -102,5 +245,33 @@ class DocumentRag:
if self.verbose:
logger.debug("Query processing complete")
# Emit synthesis explainability after answer generated
if explain_callback:
synthesis_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian
if save_answer_callback and answer_text:
synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
try:
await save_answer_callback(synthesis_doc_id, answer_text)
if self.verbose:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None
syn_triples = set_graph(
docrag_synthesis_triples(
syn_uri, exp_uri,
document_id=synthesis_doc_id,
),
GRAPH_RETRIEVAL
)
await explain_callback(syn_triples, syn_uri)
if self.verbose:
logger.debug(f"Emitted explain for session {session_id}")
return resp

View file

@ -4,17 +4,30 @@ Simple RAG service, performs query using document RAG an LLM.
Input is query, output is response.
"""
import asyncio
import base64
import logging
import uuid
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples, Metadata
from ... provenance import GRAPH_RETRIEVAL
from . document_rag import DocumentRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import DocumentEmbeddingsClientSpec
from ... base import Consumer, Producer
from ... base import ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "document-rag"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor):
@ -69,6 +82,161 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
ProducerSpec(
name = "explainability",
schema = Triples,
)
)
# Librarian client for fetching chunk content from Garage
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_requests = {}
async def start(self):
    """Start the base processor, then the librarian producer and consumer."""
    await super(Processor, self).start()
    # The request producer and response consumer are created in __init__
    # but only begin operating once started here.
    await self.librarian_request_producer.start()
    await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
    """Route a librarian response to the coroutine waiting for it.

    The response is matched to its request through the ``id`` message
    property, which keys ``self.pending_requests``.

    Args:
        msg: Incoming message; ``msg.value()`` is the response and
            ``msg.properties()`` carries the request ``id``.
        consumer: Unused; part of the consumer handler signature.
        flow: Unused; part of the consumer handler signature.
    """
    response = msg.value()
    request_id = msg.properties().get("id")

    future = self.pending_requests.pop(request_id, None)

    if future is None:
        logger.warning(f"Received unexpected librarian response: {request_id}")
        return

    # The waiter may have been cancelled (or timed out) while the reply
    # was in flight; set_result() on a finished future raises
    # InvalidStateError, so late replies are dropped instead.
    if future.done():
        logger.debug(f"Dropping late librarian response: {request_id}")
        return

    future.set_result(response)
async def fetch_chunk_content(self, chunk_id, user, timeout=120):
    """Fetch a chunk's text content via the librarian service.

    Sends a ``get-document-content`` request tagged with a fresh request
    ID, then waits for ``on_librarian_response`` to resolve the future
    registered under that ID.

    Args:
        chunk_id: ID of the chunk document to fetch.
        user: User the request is made for.
        timeout: Seconds to wait for the librarian's reply.

    Returns:
        The response content, base64-decoded and then decoded as UTF-8.

    Raises:
        RuntimeError: If the librarian replies with an error, or no
            reply arrives within ``timeout`` seconds.
    """
    request_id = str(uuid.uuid4())

    request = LibrarianRequest(
        operation="get-document-content",
        document_id=chunk_id,
        user=user,
    )

    # Register the future before sending so the reply always has
    # something to resolve.
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )
        response = await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError as err:
        raise RuntimeError(f"Timeout fetching chunk {chunk_id}") from err
    finally:
        # Drop the bookkeeping entry on every exit path (send failure,
        # cancellation, timeout), not just on timeout, so abandoned
        # futures do not accumulate.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error: {response.error.type}: {response.error.message}"
        )

    # Content is base64 encoded
    content = response.content
    if isinstance(content, str):
        content = content.encode('utf-8')
    return base64.b64decode(content).decode("utf-8")
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
    """
    Save answer content to the librarian.

    Sends an ``add-document`` request carrying the base64-encoded answer
    text and waits for the matching response.

    Args:
        doc_id: ID for the answer document
        user: User ID
        content: Answer text content
        title: Optional title; defaults to "DocumentRAG Answer"
        timeout: Request timeout in seconds

    Returns:
        The document ID on success

    Raises:
        RuntimeError: If the librarian replies with an error, or no
            reply arrives within ``timeout`` seconds.
    """
    request_id = str(uuid.uuid4())

    doc_metadata = DocumentMetadata(
        id=doc_id,
        user=user,
        kind="text/plain",
        title=title or "DocumentRAG Answer",
        document_type="answer",
    )

    request = LibrarianRequest(
        operation="add-document",
        document_id=doc_id,
        document_metadata=doc_metadata,
        content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
        user=user,
    )

    # Register the future before sending so the reply always has
    # something to resolve.
    future = asyncio.get_running_loop().create_future()
    self.pending_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )
        response = await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError as err:
        raise RuntimeError(f"Timeout saving answer document {doc_id}") from err
    finally:
        # Drop the bookkeeping entry on every exit path (send failure,
        # cancellation, timeout), not just on timeout.
        self.pending_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error saving answer: {response.error.type}: {response.error.message}"
        )

    return doc_id
async def on_request(self, msg, consumer, flow):
try:
@ -77,6 +245,7 @@ class Processor(FlowProcessor):
embeddings_client = flow("embeddings-request"),
doc_embeddings_client = flow("document-embeddings-request"),
prompt_client = flow("prompt-request"),
fetch_chunk = self.fetch_chunk_content,
verbose=True,
)
@ -92,6 +261,39 @@ class Processor(FlowProcessor):
else:
doc_limit = self.doc_limit
# Real-time explainability callback - emits triples and IDs as they're generated
# Triples are stored in the user's collection with a named graph (urn:graph:retrieval)
async def send_explainability(triples, explain_id):
# Send triples to explainability queue - stores in same collection with named graph
await flow("explainability").send(Triples(
metadata=Metadata(
id=explain_id,
user=v.user,
collection=v.collection, # Store in user's collection
),
triples=triples,
))
# Send explain ID and graph to response queue
await flow("response").send(
DocumentRagResponse(
response=None,
explain_id=explain_id,
explain_graph=GRAPH_RETRIEVAL,
message_type="explain",
),
properties={"id": id}
)
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
doc_id=doc_id,
user=v.user,
content=answer_text,
title=f"DocumentRAG Answer: {v.query[:50]}...",
)
# Check if streaming is requested
if v.streaming:
# Define async callback for streaming chunks
@ -101,6 +303,7 @@ class Processor(FlowProcessor):
DocumentRagResponse(
response=chunk,
end_of_stream=end_of_stream,
message_type="chunk",
error=None
),
properties={"id": id}
@ -115,6 +318,18 @@ class Processor(FlowProcessor):
doc_limit=doc_limit,
streaming=True,
chunk_callback=send_chunk,
explain_callback=send_explainability,
save_answer_callback=save_answer,
)
# Send end_of_session to signal entire session is complete
await flow("response").send(
DocumentRagResponse(
response=None,
end_of_session=True,
message_type="end",
),
properties={"id": id}
)
else:
# Non-streaming path (existing behavior)
@ -122,7 +337,9 @@ class Processor(FlowProcessor):
v.query,
user=v.user,
collection=v.collection,
doc_limit=doc_limit
doc_limit=doc_limit,
explain_callback=send_explainability,
save_answer_callback=save_answer,
)
await flow("response").send(

View file

@ -1,14 +1,57 @@
import asyncio
import hashlib
import json
import logging
import time
import uuid
from collections import OrderedDict
from datetime import datetime
from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE
# Provenance imports
from trustgraph.provenance import (
question_uri,
grounding_uri as make_grounding_uri,
exploration_uri as make_exploration_uri,
focus_uri as make_focus_uri,
synthesis_uri as make_synthesis_uri,
question_triples,
grounding_triples,
exploration_triples,
focus_triples,
synthesis_triples,
set_graph,
GRAPH_RETRIEVAL, GRAPH_SOURCE,
TG_CONTAINS, PROV_WAS_DERIVED_FROM,
)
# Module logger
logger = logging.getLogger(__name__)
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
def term_to_string(term):
    """Return the string form of a Term, or None when no term is given.

    IRI terms yield their ``iri``, literal terms their ``value``; any
    other term type falls back to the first truthy of ``iri``, ``value``
    and ``str(term)``.
    """
    if term is None:
        return None

    kind = term.type
    if kind == IRI:
        return term.iri
    if kind == LITERAL:
        return term.value

    # Fallback for term types other than IRI / LITERAL
    return term.iri or term.value or str(term)
def edge_id(s, p, o):
    """Return an 8-character hex ID derived from the edge (s, p, o).

    The ID is the first 8 hex digits of the SHA-256 of ``"s|p|o"``.
    """
    joined = "{}|{}|{}".format(s, p, o)
    digest = hashlib.sha256(joined.encode())
    return digest.hexdigest()[:8]
class LRUCacheWithTTL:
"""LRU cache with TTL for label caching
@ -67,12 +110,32 @@ class Query:
self.max_subgraph_size = max_subgraph_size
self.max_path_length = max_path_length
async def get_vector(self, query):
async def extract_concepts(self, query):
    """Ask the prompt service for the key concepts in a query.

    The "extract-concepts" prompt reply is read as one concept per
    non-blank line. If the reply is not a string, or contains no
    non-blank lines, the raw query is returned as the only concept.
    """
    reply = await self.rag.prompt_client.prompt(
        "extract-concepts",
        variables={"query": query}
    )

    if isinstance(reply, str):
        trimmed = (part.strip() for part in reply.strip().split('\n'))
        concepts = [part for part in trimmed if part]
    else:
        concepts = []

    if self.verbose:
        logger.debug(f"Extracted concepts: {concepts}")

    # Fall back to the raw query when nothing was extracted
    return concepts or [query]
async def get_vectors(self, concepts):
"""Embed multiple concepts concurrently."""
if self.verbose:
logger.debug("Computing embeddings...")
qembeds = await self.rag.embeddings_client.embed(query)
qembeds = await self.rag.embeddings_client.embed(concepts)
if self.verbose:
logger.debug("Done.")
@ -80,28 +143,55 @@ class Query:
return qembeds
async def get_entities(self, query):
"""
Extract concepts from query, embed them, and retrieve matching entities.
vectors = await self.get_vector(query)
Returns:
tuple: (entities, concepts) where entities is a list of entity URI
strings and concepts is the list of concept strings extracted
from the query.
"""
concepts = await self.extract_concepts(query)
vectors = await self.get_vectors(concepts)
if self.verbose:
logger.debug("Getting entities...")
entities = await self.rag.graph_embeddings_client.query(
vectors=vectors, limit=self.entity_limit,
user=self.user, collection=self.collection,
# Query entity matches for each concept concurrently
per_concept_limit = max(
1, self.entity_limit // len(vectors)
)
entities = [
str(e)
for e in entities
entity_tasks = [
self.rag.graph_embeddings_client.query(
vector=v, limit=per_concept_limit,
user=self.user, collection=self.collection,
)
for v in vectors
]
results = await asyncio.gather(*entity_tasks, return_exceptions=True)
# Deduplicate while preserving order
seen = set()
entities = []
for result in results:
if isinstance(result, Exception) or not result:
continue
for e in result:
entity = term_to_string(e.entity)
if entity not in seen:
seen.add(entity)
entities.append(entity)
if self.verbose:
logger.debug("Entities:")
for ent in entities:
logger.debug(f" {ent}")
return entities
return entities, concepts
async def maybe_label(self, e):
@ -117,6 +207,7 @@ class Query:
res = await self.rag.triples_client.query(
s=e, p=LABEL, o=None, limit=1,
user=self.user, collection=self.collection,
g="",
)
if len(res) == 0:
@ -128,26 +219,29 @@ class Query:
return label
async def execute_batch_triple_queries(self, entities, limit_per_entity):
"""Execute triple queries for multiple entities concurrently"""
"""Execute triple queries for multiple entities concurrently using streaming"""
tasks = []
for entity in entities:
# Create concurrent tasks for all 3 query types per entity
# Create concurrent streaming tasks for all 3 query types per entity
tasks.extend([
self.rag.triples_client.query(
self.rag.triples_client.query_stream(
s=entity, p=None, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection
user=self.user, collection=self.collection,
batch_size=20, g="",
),
self.rag.triples_client.query(
self.rag.triples_client.query_stream(
s=None, p=entity, o=None,
limit=limit_per_entity,
user=self.user, collection=self.collection
user=self.user, collection=self.collection,
batch_size=20, g="",
),
self.rag.triples_client.query(
self.rag.triples_client.query_stream(
s=None, p=None, o=entity,
limit=limit_per_entity,
user=self.user, collection=self.collection
user=self.user, collection=self.collection,
batch_size=20, g="",
)
])
@ -157,7 +251,7 @@ class Query:
# Combine all results
all_triples = []
for result in results:
if not isinstance(result, Exception):
if not isinstance(result, Exception) and result is not None:
all_triples.extend(result)
return all_triples
@ -220,8 +314,16 @@ class Query:
subgraph.update(batch_result)
async def get_subgraph(self, query):
"""
Get subgraph by extracting concepts, finding entities, and traversing.
entities = await self.get_entities(query)
Returns:
tuple: (subgraph, entities, concepts) where subgraph is a list of
(s, p, o) tuples, entities is the seed entity list, and concepts
is the extracted concept list.
"""
entities, concepts = await self.get_entities(query)
if self.verbose:
logger.debug("Getting subgraph...")
@ -229,7 +331,7 @@ class Query:
# Use optimized batch traversal instead of sequential processing
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
return list(subgraph)
return list(subgraph), entities, concepts
async def resolve_labels_batch(self, entities):
"""Resolve labels for multiple entities in parallel"""
@ -240,8 +342,17 @@ class Query:
return await asyncio.gather(*tasks, return_exceptions=True)
async def get_labelgraph(self, query):
"""
Get subgraph with labels resolved for display.
subgraph = await self.get_subgraph(query)
Returns:
tuple: (labeled_edges, uri_map, entities, concepts) where:
- labeled_edges: list of (label_s, label_p, label_o) tuples
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
- entities: list of seed entity URI strings
- concepts: list of concept strings extracted from query
"""
subgraph, entities, concepts = await self.get_subgraph(query)
# Filter out label triples
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
@ -263,28 +374,151 @@ class Query:
else:
label_map[entity] = entity # Fallback to entity itself
# Apply labels to subgraph
sg2 = []
# Apply labels to subgraph and build URI mapping
labeled_edges = []
uri_map = {} # Maps edge_id of labeled edge -> original URI triple
for s, p, o in filtered_subgraph:
labeled_triple = (
label_map.get(s, s),
label_map.get(p, p),
label_map.get(o, o)
)
sg2.append(labeled_triple)
labeled_edges.append(labeled_triple)
sg2 = sg2[0:self.max_subgraph_size]
# Map from labeled edge ID to original URIs
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
uri_map[labeled_eid] = (s, p, o)
labeled_edges = labeled_edges[0:self.max_subgraph_size]
if self.verbose:
logger.debug("Subgraph:")
for edge in sg2:
for edge in labeled_edges:
logger.debug(f" {str(edge)}")
if self.verbose:
logger.debug("Done.")
return sg2
return labeled_edges, uri_map, entities, concepts
async def trace_source_documents(self, edge_uris):
    """
    Trace selected edges back to their source documents via provenance.

    Follows the chain: edge -> subgraph (via tg:contains) -> chunk ->
    page -> document (via prov:wasDerivedFrom). The containment and
    derivation lookups are made in the GRAPH_SOURCE named graph.

    Args:
        edge_uris: List of (s, p, o) URI string tuples

    Returns:
        List of (s, p, o) string tuples: the metadata triples of every
        source document reached, excluding prov:wasDerivedFrom and
        rdf:type triples. Empty when no provenance could be traced.
    """

    # Step 1: Find subgraphs containing these edges via tg:contains
    # NOTE(review): the object is always wrapped as an IRI term, so edges
    # whose object is a literal presumably will not match - confirm.
    subgraph_tasks = [
        self.rag.triples_client.query(
            s=None, p=TG_CONTAINS,
            o=Term(
                type=TRIPLE,
                triple=SchemaTriple(
                    s=Term(type=IRI, iri=s),
                    p=Term(type=IRI, iri=p),
                    o=Term(type=IRI, iri=o),
                )
            ),
            limit=1,
            user=self.user, collection=self.collection,
            g=GRAPH_SOURCE,
        )
        for s, p, o in edge_uris
    ]

    subgraph_results = await asyncio.gather(
        *subgraph_tasks, return_exceptions=True
    )

    # Collect unique subgraph URIs; failed or empty lookups are skipped
    subgraph_uris = set()
    for result in subgraph_results:
        if isinstance(result, Exception) or not result:
            continue
        for triple in result:
            subgraph_uris.add(str(triple.s))

    if not subgraph_uris:
        return []

    # Step 2: Walk the prov:wasDerivedFrom chain upwards.
    # Each level: query ?entity prov:wasDerivedFrom ?parent.
    # An entity with no parent is taken to be a root document.
    current_uris = subgraph_uris
    doc_uris = set()

    for _ in range(4):  # Max depth: subgraph -> chunk -> page -> doc

        if not current_uris:
            break

        # Freeze the iteration order so results pair up with their URIs
        level_uris = list(current_uris)

        derivation_results = await asyncio.gather(
            *[
                self.rag.triples_client.query(
                    s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5,
                    user=self.user, collection=self.collection,
                    g=GRAPH_SOURCE,
                )
                for uri in level_uris
            ],
            return_exceptions=True
        )

        next_uris = set()
        for uri, result in zip(level_uris, derivation_results):
            if isinstance(result, Exception):
                # A failed lookup says nothing about parentage; treating
                # it as "no parent" would mislabel an intermediate node
                # (subgraph/chunk/page) as a source document.
                logger.warning(
                    f"Provenance lookup failed for {uri}: {result}"
                )
                continue
            if not result:
                doc_uris.add(uri)
                continue
            for triple in result:
                next_uris.add(str(triple.o))

        current_uris = next_uris - doc_uris

    if not doc_uris:
        return []

    # Step 3: Get all document metadata properties
    # Skip structural predicates that aren't useful context
    SKIP_PREDICATES = {
        PROV_WAS_DERIVED_FROM,
        "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
    }

    # NOTE(review): unlike the lookups above, no g= is passed here -
    # confirm the client's default graph scope is what is intended.
    metadata_tasks = [
        self.rag.triples_client.query(
            s=uri, p=None, o=None, limit=50,
            user=self.user, collection=self.collection,
        )
        for uri in doc_uris
    ]

    metadata_results = await asyncio.gather(
        *metadata_tasks, return_exceptions=True
    )

    doc_edges = []
    for result in metadata_results:
        if isinstance(result, Exception) or not result:
            continue
        for triple in result:
            p = str(triple.p)
            if p in SKIP_PREDICATES:
                continue
            doc_edges.append((
                str(triple.s), p, str(triple.o)
            ))

    return doc_edges
class GraphRag:
"""
CRITICAL SECURITY:
@ -316,12 +550,50 @@ class GraphRag:
async def query(
self, query, user = "trustgraph", collection = "default",
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
max_path_length = 2, streaming = False, chunk_callback = None,
max_path_length = 2, edge_limit = 25, streaming = False,
chunk_callback = None,
explain_callback = None, save_answer_callback = None,
):
"""
Execute a GraphRAG query with real-time explainability tracking.
Args:
query: The query string
user: User identifier
collection: Collection identifier
entity_limit: Max entities to retrieve
triple_limit: Max triples per entity
max_subgraph_size: Max edges in subgraph
max_path_length: Max hops from seed entities
streaming: Enable streaming LLM response
chunk_callback: async def callback(chunk, end_of_stream) for streaming
explain_callback: async def callback(triples, explain_id) for real-time explainability
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
Returns:
str: The synthesized answer text
"""
if self.verbose:
logger.debug("Constructing prompt...")
# Generate explainability URIs upfront
session_id = str(uuid.uuid4())
q_uri = question_uri(session_id)
gnd_uri = make_grounding_uri(session_id)
exp_uri = make_exploration_uri(session_id)
foc_uri = make_focus_uri(session_id)
syn_uri = make_synthesis_uri(session_id)
timestamp = datetime.utcnow().isoformat() + "Z"
# Emit question explainability immediately
if explain_callback:
q_triples = set_graph(
question_triples(q_uri, query, timestamp),
GRAPH_RETRIEVAL
)
await explain_callback(q_triples, q_uri)
q = Query(
rag = self, user = user, collection = collection,
verbose = self.verbose, entity_limit = entity_limit,
@ -330,24 +602,262 @@ class GraphRag:
max_path_length = max_path_length,
)
kg = await q.get_labelgraph(query)
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
# Emit grounding explain after concept extraction
if explain_callback:
gnd_triples = set_graph(
grounding_triples(gnd_uri, q_uri, concepts),
GRAPH_RETRIEVAL
)
await explain_callback(gnd_triples, gnd_uri)
# Emit exploration explain after graph retrieval completes
if explain_callback:
exp_triples = set_graph(
exploration_triples(
exp_uri, gnd_uri, len(kg),
entities=seed_entities,
),
GRAPH_RETRIEVAL
)
await explain_callback(exp_triples, exp_uri)
if self.verbose:
logger.debug("Invoking LLM...")
logger.debug(f"Knowledge graph: {kg}")
logger.debug(f"Query: {query}")
if streaming and chunk_callback:
resp = await self.prompt_client.kg_prompt(
query, kg,
streaming=True,
chunk_callback=chunk_callback
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
edge_map = {}
edges_with_ids = []
for s, p, o in kg:
eid = edge_id(s, p, o)
edge_map[eid] = (s, p, o)
edges_with_ids.append({
"id": eid,
"s": s,
"p": p,
"o": o
})
if self.verbose:
logger.debug(f"Built edge map with {len(edge_map)} edges")
# Step 1a: Edge Scoring - LLM scores edges for relevance
scoring_response = await self.prompt_client.prompt(
"kg-edge-scoring",
variables={
"query": query,
"knowledge": edges_with_ids
}
)
if self.verbose:
logger.debug(f"Edge scoring response: {scoring_response}")
# Parse scoring response to get edge IDs with scores
scored_edges = []
def parse_scored_edge(obj):
if isinstance(obj, dict) and "id" in obj and "score" in obj:
try:
score = int(obj["score"])
except (ValueError, TypeError):
score = 0
scored_edges.append({"id": obj["id"], "score": score})
if isinstance(scoring_response, list):
for obj in scoring_response:
parse_scored_edge(obj)
elif isinstance(scoring_response, str):
for line in scoring_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_scored_edge(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge scoring line: {line}"
)
# Select top N edges by score
scored_edges.sort(key=lambda x: x["score"], reverse=True)
top_edges = scored_edges[:edge_limit]
selected_ids = {e["id"] for e in top_edges}
if self.verbose:
logger.debug(
f"Scored {len(scored_edges)} edges, "
f"selected top {len(selected_ids)}"
)
# Filter to selected edges
selected_edges = []
for eid in selected_ids:
if eid in edge_map:
selected_edges.append(edge_map[eid])
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
selected_edges_with_ids = [
{"id": eid, "s": s, "p": p, "o": o}
for eid in selected_ids
if eid in edge_map
for s, p, o in [edge_map[eid]]
]
# Collect selected edge URIs for document tracing
selected_edge_uris = [
uri_map[eid]
for eid in selected_ids
if eid in uri_map
]
# Run reasoning and document tracing concurrently
reasoning_task = self.prompt_client.prompt(
"kg-edge-reasoning",
variables={
"query": query,
"knowledge": selected_edges_with_ids
}
)
doc_trace_task = q.trace_source_documents(selected_edge_uris)
reasoning_response, source_documents = await asyncio.gather(
reasoning_task, doc_trace_task, return_exceptions=True
)
# Handle exceptions from gather
if isinstance(reasoning_response, Exception):
logger.warning(
f"Edge reasoning failed: {reasoning_response}"
)
reasoning_response = ""
if isinstance(source_documents, Exception):
logger.warning(
f"Document tracing failed: {source_documents}"
)
source_documents = []
if self.verbose:
logger.debug(f"Edge reasoning response: {reasoning_response}")
# Parse reasoning response and build explainability data
reasoning_map = {}
def parse_reasoning(obj):
if isinstance(obj, dict) and "id" in obj:
reasoning_map[obj["id"]] = obj.get("reasoning", "")
if isinstance(reasoning_response, list):
for obj in reasoning_response:
parse_reasoning(obj)
elif isinstance(reasoning_response, str):
for line in reasoning_response.strip().split('\n'):
line = line.strip()
if not line:
continue
try:
parse_reasoning(json.loads(line))
except json.JSONDecodeError:
logger.warning(
f"Failed to parse edge reasoning line: {line}"
)
selected_edges_with_reasoning = []
for eid in selected_ids:
if eid in uri_map:
uri_s, uri_p, uri_o = uri_map[eid]
selected_edges_with_reasoning.append({
"edge": (uri_s, uri_p, uri_o),
"reasoning": reasoning_map.get(eid, ""),
})
if self.verbose:
logger.debug(f"Filtered to {len(selected_edges)} edges")
# Emit focus explain after edge selection completes
if explain_callback:
foc_triples = set_graph(
focus_triples(
foc_uri, exp_uri, selected_edges_with_reasoning, session_id
),
GRAPH_RETRIEVAL
)
await explain_callback(foc_triples, foc_uri)
# Step 2: Synthesis - LLM generates answer from selected edges only
selected_edge_dicts = [
{"s": s, "p": p, "o": o}
for s, p, o in selected_edges
]
# Add source document metadata as knowledge edges
for s, p, o in source_documents:
selected_edge_dicts.append({
"s": s, "p": p, "o": o,
})
synthesis_variables = {
"query": query,
"knowledge": selected_edge_dicts,
}
if streaming and chunk_callback:
# Accumulate chunks for answer storage while forwarding to callback
accumulated_chunks = []
async def accumulating_callback(chunk, end_of_stream):
accumulated_chunks.append(chunk)
await chunk_callback(chunk, end_of_stream)
await self.prompt_client.prompt(
"kg-synthesis",
variables=synthesis_variables,
streaming=True,
chunk_callback=accumulating_callback
)
# Combine all chunks into full response
resp = "".join(accumulated_chunks)
else:
resp = await self.prompt_client.kg_prompt(query, kg)
resp = await self.prompt_client.prompt(
"kg-synthesis",
variables=synthesis_variables,
)
if self.verbose:
logger.debug("Query processing complete")
# Emit synthesis explain after synthesis completes
if explain_callback:
synthesis_doc_id = None
answer_text = resp if resp else ""
# Save answer to librarian
if save_answer_callback and answer_text:
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
try:
await save_answer_callback(synthesis_doc_id, answer_text)
if self.verbose:
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
except Exception as e:
logger.warning(f"Failed to save answer to librarian: {e}")
synthesis_doc_id = None
syn_triples = set_graph(
synthesis_triples(
syn_uri, foc_uri,
document_id=synthesis_doc_id,
),
GRAPH_RETRIEVAL
)
await explain_callback(syn_triples, syn_uri)
if self.verbose:
logger.debug(f"Emitted explain for session {session_id}")
return resp

View file

@ -4,18 +4,29 @@ Simple RAG service, performs query using graph RAG an LLM.
Input is query, output is response.
"""
import asyncio
import base64
import logging
import uuid
from ... schema import GraphRagQuery, GraphRagResponse, Error
from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... provenance import GRAPH_RETRIEVAL
from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
# Module logger
logger = logging.getLogger(__name__)
default_ident = "graph-rag"
default_concurrency = 1
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor):
@ -28,6 +39,7 @@ class Processor(FlowProcessor):
triple_limit = params.get("triple_limit", 30)
max_subgraph_size = params.get("max_subgraph_size", 150)
max_path_length = params.get("max_path_length", 2)
edge_limit = params.get("edge_limit", 25)
super(Processor, self).__init__(
**params | {
@ -37,6 +49,7 @@ class Processor(FlowProcessor):
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
"max_path_length": max_path_length,
"edge_limit": edge_limit,
}
)
@ -44,6 +57,7 @@ class Processor(FlowProcessor):
self.default_triple_limit = triple_limit
self.default_max_subgraph_size = max_subgraph_size
self.default_max_path_length = max_path_length
self.default_edge_limit = edge_limit
# CRITICAL SECURITY: NEVER share data between users or collections
# Each user/collection combination MUST have isolated data access
@ -93,10 +107,163 @@ class Processor(FlowProcessor):
)
)
self.register_specification(
ProducerSpec(
name = "explainability",
schema = Triples,
)
)
# Librarian client for storing answer content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_librarian_requests = {}
logger.info("Graph RAG service initialized")
async def start(self):
    """
    Start the processor and its librarian messaging endpoints.

    The base class is started first, followed by the librarian request
    producer and then the librarian response consumer.
    """
    await super(Processor, self).start()

    request_producer = self.librarian_request_producer
    await request_producer.start()

    response_consumer = self.librarian_response_consumer
    await response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
    """
    Hand a librarian reply to the coroutine waiting for it.

    Replies are matched to requests through the "id" message property,
    which is looked up in self.pending_librarian_requests.  A matched
    entry is always removed from that dict.

    Args:
        msg: Incoming message; value() yields the response and
            properties() carries the request ID
        consumer: Not used by this handler
        flow: Not used by this handler
    """
    response = msg.value()
    request_id = msg.properties().get("id")

    future = None
    if request_id:
        future = self.pending_librarian_requests.pop(request_id, None)

    if future is None:
        logger.warning(f"Received unexpected librarian response: {request_id}")
        return

    # The waiter may already have given up: asyncio.wait_for() cancels the
    # future on timeout before the waiter removes its entry.  Calling
    # set_result() on a finished future raises InvalidStateError, which
    # would surface as a failure of this handler, so the late reply is
    # dropped instead.
    if future.done():
        return

    future.set_result(response)
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
    """
    Save answer content to the librarian.

    Sends an "add-document" request carrying the base64-encoded UTF-8
    content and waits for the reply that carries the same request ID.

    Args:
        doc_id: ID for the answer document
        user: User ID
        content: Answer text content
        title: Optional title; "GraphRAG Answer" is used when not given
        timeout: Request timeout in seconds

    Returns:
        The document ID on success

    Raises:
        RuntimeError: If no reply arrives within the timeout, or the
            librarian reply carries an error
    """
    request_id = str(uuid.uuid4())

    doc_metadata = DocumentMetadata(
        id=doc_id,
        user=user,
        kind="text/plain",
        title=title or "GraphRAG Answer",
        document_type="answer",
    )

    request = LibrarianRequest(
        operation="add-document",
        document_id=doc_id,
        document_metadata=doc_metadata,
        content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
        user=user,
    )

    # Register the future before sending so the reply cannot arrive
    # ahead of it.  get_running_loop() is used because this code always
    # runs inside a coroutine; get_event_loop() is deprecated there.
    future = asyncio.get_running_loop().create_future()
    self.pending_librarian_requests[request_id] = future

    try:
        await self.librarian_request_producer.send(
            request, properties={"id": request_id}
        )
        response = await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        raise RuntimeError(f"Timeout saving answer document {doc_id}")
    finally:
        # Always drop the entry: on success the response handler has
        # already removed it, but a failed send, a timeout or a
        # cancellation would otherwise leave it behind forever.
        self.pending_librarian_requests.pop(request_id, None)

    if response.error:
        raise RuntimeError(
            f"Librarian error saving answer: {response.error.type}: {response.error.message}"
        )

    return doc_id
async def on_request(self, msg, consumer, flow):
try:
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.info(f"Handling input {id}...")
# Track explainability refs for end_of_session signaling
explainability_refs_emitted = []
# Real-time explainability callback - emits triples and IDs as they're generated
# Triples are stored in the user's collection with a named graph (urn:graph:retrieval)
async def send_explainability(triples, explain_id):
# Send triples to explainability queue - stores in same collection with named graph
await flow("explainability").send(Triples(
metadata=Metadata(
id=explain_id,
user=v.user,
collection=v.collection, # Store in user's collection, not separate explainability collection
),
triples=triples,
))
# Send explain ID and graph to response queue
await flow("response").send(
GraphRagResponse(
message_type="explain",
explain_id=explain_id,
explain_graph=GRAPH_RETRIEVAL,
),
properties={"id": id}
)
explainability_refs_emitted.append(explain_id)
# CRITICAL SECURITY: Create new GraphRag instance per request
# This ensures proper isolation between users and collections
# Flow clients are request-scoped and must not be shared
@ -108,13 +275,6 @@ class Processor(FlowProcessor):
verbose=True,
)
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
logger.info(f"Handling input {id}...")
if v.entity_limit:
entity_limit = v.entity_limit
else:
@ -135,6 +295,20 @@ class Processor(FlowProcessor):
else:
max_path_length = self.default_max_path_length
if v.edge_limit:
edge_limit = v.edge_limit
else:
edge_limit = self.default_edge_limit
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text):
await self.save_answer_content(
doc_id=doc_id,
user=v.user,
content=answer_text,
title=f"GraphRAG Answer: {v.query[:50]}...",
)
# Check if streaming is requested
if v.streaming:
# Define async callback for streaming chunks
@ -142,6 +316,7 @@ class Processor(FlowProcessor):
async def send_chunk(chunk, end_of_stream):
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response=chunk,
end_of_stream=end_of_stream,
error=None
@ -149,34 +324,52 @@ class Processor(FlowProcessor):
properties={"id": id}
)
# Query with streaming enabled
# All chunks (including final one with end_of_stream=True) are sent via callback
await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
streaming = True,
chunk_callback = send_chunk,
)
else:
# Non-streaming path (existing behavior)
# Query with streaming and real-time explain
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
streaming = True,
chunk_callback = send_chunk,
explain_callback = send_explainability,
save_answer_callback = save_answer,
)
else:
# Non-streaming path with real-time explain
response = await rag.query(
query = v.query, user = v.user, collection = v.collection,
entity_limit = entity_limit, triple_limit = triple_limit,
max_subgraph_size = max_subgraph_size,
max_path_length = max_path_length,
edge_limit = edge_limit,
explain_callback = send_explainability,
save_answer_callback = save_answer,
)
# Send chunk with response
await flow("response").send(
GraphRagResponse(
response = response,
end_of_stream = True,
error = None
message_type="chunk",
response=response,
end_of_stream=True,
error=None,
),
properties = {"id": id}
properties={"id": id}
)
# Send final message to close session
await flow("response").send(
GraphRagResponse(
message_type="chunk",
response="",
end_of_session=True,
),
properties={"id": id}
)
logger.info("Request processing complete")
except Exception as e:
@ -185,22 +378,18 @@ class Processor(FlowProcessor):
logger.debug("Sending error response...")
# Send error response with end_of_stream flag if streaming was requested
error_response = GraphRagResponse(
response = None,
error = Error(
type = "graph-rag-error",
message = str(e),
),
)
# If streaming was requested, indicate stream end
if v.streaming:
error_response.end_of_stream = True
# Send error response and close session
await flow("response").send(
error_response,
properties = {"id": id}
GraphRagResponse(
message_type="chunk",
error=Error(
type="graph-rag-error",
message=str(e),
),
end_of_stream=True,
end_of_session=True,
),
properties={"id": id}
)
@staticmethod
@ -243,6 +432,9 @@ class Processor(FlowProcessor):
help=f'Default max path length (default: 2)'
)
# Note: Explainability triples are now stored in the user's collection
# with the named graph urn:graph:retrieval (no separate collection needed)
def run():
    """Entry point: launch the Processor with the default identifier and this module's docstring."""
    Processor.launch(default_ident, __doc__)

View file

@ -37,14 +37,14 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for emb in message.chunks:
if emb.chunk is None or emb.chunk == b"": continue
chunk_id = emb.chunk_id
if chunk_id == "":
continue
chunk = emb.chunk.decode("utf-8")
if chunk == "": continue
for vec in emb.vectors:
vec = emb.vector
if vec:
self.vecstore.insert(
vec, chunk,
vec, chunk_id,
message.metadata.user,
message.metadata.collection
)

View file

@ -101,40 +101,41 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for emb in message.chunks:
if emb.chunk is None or emb.chunk == b"": continue
chunk_id = emb.chunk_id
if chunk_id == "":
continue
chunk = emb.chunk.decode("utf-8")
if chunk == "": continue
vec = emb.vector
if not vec:
continue
for vec in emb.vectors:
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
index = self.pinecone.Index(index_name)
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
records = [
{
"id": vector_id,
"values": vec,
"metadata": { "chunk_id": chunk_id },
}
]
records = [
{
"id": vector_id,
"values": vec,
"metadata": { "doc": chunk },
}
]
index.upsert(
vectors = records,
)
index.upsert(
vectors = records,
)
@staticmethod
def add_args(parser):

View file

@ -52,41 +52,44 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
for emb in message.chunks:
chunk = emb.chunk.decode("utf-8")
if chunk == "": return
chunk_id = emb.chunk_id
if chunk_id == "":
continue
for vec in emb.vectors:
vec = emb.vector
if not vec:
continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"doc": chunk,
}
)
]
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"chunk_id": chunk_id,
}
)
]
)
@staticmethod
def add_args(parser):

View file

@ -53,11 +53,13 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
entity_value = get_term_value(entity.entity)
if entity_value != "" and entity_value is not None:
for vec in entity.vectors:
vec = entity.vector
if vec:
self.vecstore.insert(
vec, entity_value,
message.metadata.user,
message.metadata.collection
message.metadata.collection,
chunk_id=entity.chunk_id or "",
)
@staticmethod

View file

@ -119,35 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity_value == "" or entity_value is None:
continue
for vec in entity.vectors:
vec = entity.vector
if not vec:
continue
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Create index name with dimension suffix for lazy creation
dim = len(vec)
index_name = (
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
# Lazily create index if it doesn't exist (but only if authorized in config)
if not self.pinecone.has_index(index_name):
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
self.create_index(index_name, dim)
index = self.pinecone.Index(index_name)
index = self.pinecone.Index(index_name)
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
# Generate unique ID for each vector
vector_id = str(uuid.uuid4())
records = [
{
"id": vector_id,
"values": vec,
"metadata": { "entity": entity_value },
}
]
metadata = {"entity": entity_value}
if entity.chunk_id:
metadata["chunk_id"] = entity.chunk_id
index.upsert(
vectors = records,
)
records = [
{
"id": vector_id,
"values": vec,
"metadata": metadata,
}
]
index.upsert(
vectors = records,
)
@staticmethod
def add_args(parser):

View file

@ -71,38 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
if entity_value == "" or entity_value is None:
continue
for vec in entity.vectors:
vec = entity.vector
if not vec:
continue
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Create collection name with dimension suffix for lazy creation
dim = len(vec)
collection = (
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
)
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
self.qdrant.upsert(
# Lazily create collection if it doesn't exist (but only if authorized in config)
if not self.qdrant.collection_exists(collection):
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
self.qdrant.create_collection(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload={
"entity": entity_value,
}
)
]
vectors_config=VectorParams(
size=dim,
distance=Distance.COSINE
)
)
payload = {
"entity": entity_value,
}
if entity.chunk_id:
payload["chunk_id"] = entity.chunk_id
self.qdrant.upsert(
collection_name=collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vec,
payload=payload,
)
]
)
@staticmethod
def add_args(parser):

View file

@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor):
qdrant_collection = None
for row_emb in embeddings.embeddings:
if not row_emb.vectors:
vector = row_emb.vector
if not vector:
logger.warning(
f"No vectors for index {row_emb.index_name} - skipping"
f"No vector for index {row_emb.index_name} - skipping"
)
continue
# Use first vector (there may be multiple from different models)
for vector in row_emb.vectors:
dimension = len(vector)
dimension = len(vector)
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
)
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
# Create/get collection name (lazily on first vector)
if qdrant_collection is None:
qdrant_collection = self.get_collection_name(
user, collection, schema_name, dimension
)
embeddings_written += 1
self.ensure_collection(qdrant_collection, dimension)
# Write to Qdrant
self.qdrant.upsert(
collection_name=qdrant_collection,
points=[
PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"index_name": row_emb.index_name,
"index_value": row_emb.index_value,
"text": row_emb.text
}
)
]
)
embeddings_written += 1
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")

View file

@ -9,6 +9,7 @@ import os
import argparse
import time
import logging
import json
from .... direct.cassandra_kg import (
EntityCentricKnowledgeGraph, DEFAULT_GRAPH
@ -25,6 +26,37 @@ logger = logging.getLogger(__name__)
default_ident = "triples-write"
def serialize_triple(triple):
    """
    Encode a Triple as a JSON string for storage.

    The subject, predicate and object are each turned into a dict tagged
    with the term's type.  A term that is itself a triple is encoded by
    a recursive call, so it appears in the output as a nested JSON
    string.  Returns None when triple is None.
    """
    if triple is None:
        return None

    def encode(term):
        # Absent terms are kept as JSON null
        if term is None:
            return None

        kind = term.type
        encoded = {"type": kind}

        if kind == IRI:
            encoded["iri"] = term.iri
            return encoded

        if kind == LITERAL:
            encoded["value"] = term.value
            # Datatype and language are only written when set
            datatype = term.datatype
            if datatype:
                encoded["datatype"] = datatype
            language = term.language
            if language:
                encoded["language"] = language
            return encoded

        if kind == BLANK:
            encoded["id"] = term.id
            return encoded

        if kind == TRIPLE:
            encoded["triple"] = serialize_triple(term.triple)

        return encoded

    return json.dumps(
        {role: encode(getattr(triple, role)) for role in ("s", "p", "o")}
    )
def get_term_value(term):
"""Extract the string value from a Term"""
if term is None:
@ -33,6 +65,9 @@ def get_term_value(term):
return term.iri
elif term.type == LITERAL:
return term.value
elif term.type == TRIPLE:
# Serialize nested triple as JSON
return serialize_triple(term.triple)
else:
# For blank nodes or other types, use id or value
return term.id or term.value

View file

@ -114,7 +114,7 @@ class KnowledgeTableStore:
entity_embeddings list<
tuple<
tuple<text, boolean>,
list<list<double>>
list<double>
>
>,
PRIMARY KEY ((user, document_id), id)
@ -140,7 +140,7 @@ class KnowledgeTableStore:
chunks list<
tuple<
blob,
list<list<double>>
list<double>
>
>,
PRIMARY KEY ((user, document_id), id)
@ -218,16 +218,6 @@ class KnowledgeTableStore:
when = int(time.time() * 1000)
if m.metadata.metadata:
metadata = [
(
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
)
for v in m.metadata.metadata
]
else:
metadata = []
triples = [
(
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
@ -243,8 +233,8 @@ class KnowledgeTableStore:
self.insert_triples_stmt,
(
uuid.uuid4(), m.metadata.user,
m.metadata.id, when,
metadata, triples,
m.metadata.root or m.metadata.id, when,
[], triples,
)
)
@ -259,20 +249,10 @@ class KnowledgeTableStore:
when = int(time.time() * 1000)
if m.metadata.metadata:
metadata = [
(
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
)
for v in m.metadata.metadata
]
else:
metadata = []
entities = [
(
term_to_tuple(v.entity),
v.vectors
v.vector
)
for v in m.entities
]
@ -285,8 +265,8 @@ class KnowledgeTableStore:
self.insert_graph_embeddings_stmt,
(
uuid.uuid4(), m.metadata.user,
m.metadata.id, when,
metadata, entities,
m.metadata.root or m.metadata.id, when,
[], entities,
)
)
@ -301,20 +281,10 @@ class KnowledgeTableStore:
when = int(time.time() * 1000)
if m.metadata.metadata:
metadata = [
(
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
)
for v in m.metadata.metadata
]
else:
metadata = []
chunks = [
(
v.chunk,
v.vectors,
v.chunk_id,
v.vector,
)
for v in m.chunks
]
@ -327,8 +297,8 @@ class KnowledgeTableStore:
self.insert_document_embeddings_stmt,
(
uuid.uuid4(), m.metadata.user,
m.metadata.id, when,
metadata, chunks,
m.metadata.root or m.metadata.id, when,
[], chunks,
)
)
@ -423,18 +393,6 @@ class KnowledgeTableStore:
for row in resp:
if row[2]:
metadata = [
Triple(
s = tuple_to_term(elt[0], elt[1]),
p = tuple_to_term(elt[2], elt[3]),
o = tuple_to_term(elt[4], elt[5]),
)
for elt in row[2]
]
else:
metadata = []
if row[3]:
triples = [
Triple(
@ -453,7 +411,6 @@ class KnowledgeTableStore:
id = document_id,
user = user,
collection = "default", # FIXME: What to put here?
metadata = metadata,
),
triples = triples
)
@ -482,18 +439,6 @@ class KnowledgeTableStore:
for row in resp:
if row[2]:
metadata = [
Triple(
s = tuple_to_term(elt[0], elt[1]),
p = tuple_to_term(elt[2], elt[3]),
o = tuple_to_term(elt[4], elt[5]),
)
for elt in row[2]
]
else:
metadata = []
if row[3]:
entities = [
EntityEmbeddings(
@ -511,7 +456,6 @@ class KnowledgeTableStore:
id = document_id,
user = user,
collection = "default", # FIXME: What to put here?
metadata = metadata,
),
entities = entities
)

View file

@ -112,6 +112,34 @@ class LibraryTableStore:
ON document (object_id)
""");
# Add parent_id and document_type columns for child document support
logger.debug("document table parent_id column...")
try:
self.cassandra.execute("""
ALTER TABLE document ADD parent_id text
""");
except Exception as e:
# Column may already exist
if "already exists" not in str(e).lower() and "Invalid column name" not in str(e):
logger.debug(f"parent_id column may already exist: {e}")
try:
self.cassandra.execute("""
ALTER TABLE document ADD document_type text
""");
except Exception as e:
# Column may already exist
if "already exists" not in str(e).lower() and "Invalid column name" not in str(e):
logger.debug(f"document_type column may already exist: {e}")
logger.debug("document parent index...")
self.cassandra.execute("""
CREATE INDEX IF NOT EXISTS document_parent
ON document (parent_id)
""");
logger.debug("processing table...")
self.cassandra.execute("""
@ -127,6 +155,32 @@ class LibraryTableStore:
);
""");
logger.debug("upload_session table...")
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS upload_session (
upload_id text PRIMARY KEY,
user text,
document_id text,
document_metadata text,
s3_upload_id text,
object_id uuid,
total_size bigint,
chunk_size int,
total_chunks int,
chunks_received map<int, text>,
created_at timestamp,
updated_at timestamp
) WITH default_time_to_live = 86400;
""");
logger.debug("upload_session user index...")
self.cassandra.execute("""
CREATE INDEX IF NOT EXISTS upload_session_user
ON upload_session (user)
""");
logger.info("Cassandra schema OK.")
def prepare_statements(self):
@ -136,9 +190,10 @@ class LibraryTableStore:
(
id, user, time,
kind, title, comments,
metadata, tags, object_id
metadata, tags, object_id,
parent_id, document_type
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""")
self.update_document_stmt = self.cassandra.prepare("""
@ -149,7 +204,8 @@ class LibraryTableStore:
""")
self.get_document_stmt = self.cassandra.prepare("""
SELECT time, kind, title, comments, metadata, tags, object_id
SELECT time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ? AND id = ?
""")
@ -168,14 +224,16 @@ class LibraryTableStore:
self.list_document_stmt = self.cassandra.prepare("""
SELECT
id, time, kind, title, comments, metadata, tags, object_id
id, time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ?
""")
self.list_document_by_tag_stmt = self.cassandra.prepare("""
SELECT
id, time, kind, title, comments, metadata, tags, object_id
id, time, kind, title, comments, metadata, tags, object_id,
parent_id, document_type
FROM document
WHERE user = ? AND tags CONTAINS ?
ALLOW FILTERING
@ -210,6 +268,57 @@ class LibraryTableStore:
WHERE user = ?
""")
# Upload session prepared statements
self.insert_upload_session_stmt = self.cassandra.prepare("""
INSERT INTO upload_session
(
upload_id, user, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, chunks_received, created_at, updated_at
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""")
self.get_upload_session_stmt = self.cassandra.prepare("""
SELECT
upload_id, user, document_id, document_metadata,
s3_upload_id, object_id, total_size, chunk_size,
total_chunks, chunks_received, created_at, updated_at
FROM upload_session
WHERE upload_id = ?
""")
self.update_upload_session_chunk_stmt = self.cassandra.prepare("""
UPDATE upload_session
SET chunks_received = chunks_received + ?,
updated_at = ?
WHERE upload_id = ?
""")
self.delete_upload_session_stmt = self.cassandra.prepare("""
DELETE FROM upload_session
WHERE upload_id = ?
""")
self.list_upload_sessions_stmt = self.cassandra.prepare("""
SELECT
upload_id, document_id, document_metadata,
total_size, chunk_size, total_chunks,
chunks_received, created_at, updated_at
FROM upload_session
WHERE user = ?
""")
# Child document queries
self.list_children_stmt = self.cassandra.prepare("""
SELECT
id, user, time, kind, title, comments, metadata, tags,
object_id, parent_id, document_type
FROM document
WHERE parent_id = ?
ALLOW FILTERING
""")
async def document_exists(self, user, id):
resp = self.cassandra.execute(
@ -236,6 +345,10 @@ class LibraryTableStore:
for v in document.metadata
]
# Get parent_id and document_type from document, defaulting if not set
parent_id = getattr(document, 'parent_id', '') or ''
document_type = getattr(document, 'document_type', 'source') or 'source'
while True:
try:
@ -245,7 +358,8 @@ class LibraryTableStore:
(
document.id, document.user, int(document.time * 1000),
document.kind, document.title, document.comments,
metadata, document.tags, object_id
metadata, document.tags, object_id,
parent_id, document_type
)
)
@ -349,9 +463,58 @@ class LibraryTableStore:
p=tuple_to_term(m[2], m[3]),
o=tuple_to_term(m[4], m[5])
)
for m in row[5]
for m in (row[5] or [])
],
tags = row[6] if row[6] else [],
parent_id = row[8] if row[8] else "",
document_type = row[9] if row[9] else "source",
)
for row in resp
]
logger.debug("Done")
return lst
async def list_children(self, parent_id):
"""List all child documents for a given parent document ID."""
logger.debug(f"List children for parent {parent_id}")
while True:
try:
resp = self.cassandra.execute(
self.list_children_stmt,
(parent_id,)
)
break
except Exception as e:
logger.error("Exception occurred", exc_info=True)
raise e
lst = [
DocumentMetadata(
id = row[0],
user = row[1],
time = int(time.mktime(row[2].timetuple())),
kind = row[3],
title = row[4],
comments = row[5],
metadata = [
Triple(
s=tuple_to_term(m[0], m[1]),
p=tuple_to_term(m[2], m[3]),
o=tuple_to_term(m[4], m[5])
)
for m in (row[6] or [])
],
tags = row[7] if row[7] else [],
parent_id = row[9] if row[9] else "",
document_type = row[10] if row[10] else "source",
)
for row in resp
]
@ -394,9 +557,11 @@ class LibraryTableStore:
p=tuple_to_term(m[2], m[3]),
o=tuple_to_term(m[4], m[5])
)
for m in row[4]
for m in (row[4] or [])
],
tags = row[5] if row[5] else [],
parent_id = row[7] if row[7] else "",
document_type = row[8] if row[8] else "source",
)
logger.debug("Done")
@ -532,3 +697,152 @@ class LibraryTableStore:
logger.debug("Done")
return lst
# Upload session methods
async def create_upload_session(
    self,
    upload_id,
    user,
    document_id,
    document_metadata,
    s3_upload_id,
    object_id,
    total_size,
    chunk_size,
    total_chunks,
):
    """
    Insert a new upload session row for a chunked upload.

    The session starts with an empty chunks_received map; created_at
    and updated_at are both set to the current time in milliseconds.
    A failure of the insert is logged and re-raised.
    """
    logger.info(f"Creating upload session {upload_id}")

    timestamp = int(time.time() * 1000)
    values = (
        upload_id, user, document_id, document_metadata,
        s3_upload_id, object_id, total_size, chunk_size,
        total_chunks, {}, timestamp, timestamp,
    )

    try:
        self.cassandra.execute(self.insert_upload_session_stmt, values)
    except Exception:
        logger.error("Exception occurred", exc_info=True)
        raise

    logger.debug("Upload session created")
async def get_upload_session(self, upload_id):
    """
    Look up an upload session by ID.

    Returns:
        A dict of the session's columns built from the first matching
        row, with chunks_received replaced by an empty dict when the
        stored value is empty, or None when no row matches.
    """
    logger.debug(f"Get upload session {upload_id}")

    try:
        rows = self.cassandra.execute(
            self.get_upload_session_stmt,
            (upload_id,)
        )
    except Exception:
        logger.error("Exception occurred", exc_info=True)
        raise

    # Only the first row is of interest
    row = next(iter(rows), None)
    if row is None:
        return None

    # Same order as the SELECT column list of get_upload_session_stmt
    columns = (
        "upload_id", "user", "document_id", "document_metadata",
        "s3_upload_id", "object_id", "total_size", "chunk_size",
        "total_chunks", "chunks_received", "created_at", "updated_at",
    )
    session = {name: row[pos] for pos, name in enumerate(columns)}
    session["chunks_received"] = session["chunks_received"] or {}

    logger.debug("Done")
    return session
async def update_upload_session_chunk(self, upload_id, chunk_index, etag):
    """
    Record a successfully uploaded chunk on its session.

    Adds the chunk_index -> etag entry to the session's chunks_received
    map and sets updated_at to the current time in milliseconds.  A
    failure of the update is logged and re-raised.
    """
    logger.debug(f"Update upload session {upload_id} chunk {chunk_index}")

    timestamp = int(time.time() * 1000)

    try:
        self.cassandra.execute(
            self.update_upload_session_chunk_stmt,
            ({chunk_index: etag}, timestamp, upload_id),
        )
    except Exception:
        logger.error("Exception occurred", exc_info=True)
        raise

    logger.debug("Chunk recorded")
async def delete_upload_session(self, upload_id):
    """
    Remove the upload session row with the given ID.

    A failure of the delete is logged and re-raised.
    """
    logger.info(f"Deleting upload session {upload_id}")

    key = (upload_id,)
    try:
        self.cassandra.execute(self.delete_upload_session_stmt, key)
    except Exception:
        logger.error("Exception occurred", exc_info=True)
        raise

    logger.debug("Upload session deleted")
async def list_upload_sessions(self, user):
    """
    List the upload sessions belonging to a user.

    Returns:
        A list of dicts, one per session.  Unlike get_upload_session,
        "chunks_received" here is the number of recorded chunks rather
        than the map itself.
    """
    logger.debug(f"List upload sessions for {user}")

    try:
        rows = self.cassandra.execute(
            self.list_upload_sessions_stmt,
            (user,)
        )
    except Exception:
        logger.error("Exception occurred", exc_info=True)
        raise

    sessions = [
        {
            "upload_id": row[0],
            "document_id": row[1],
            "document_metadata": row[2],
            "total_size": row[3],
            "chunk_size": row[4],
            "total_chunks": row[5],
            # An empty or missing map counts as zero chunks
            "chunks_received": len(row[6] or {}),
            "created_at": row[7],
            "updated_at": row[8],
        }
        for row in rows
    ]

    logger.debug("Done")
    return sessions

View file

@ -0,0 +1 @@
# Tool service implementations

View file

@ -0,0 +1,2 @@
# Joke tool service
from .service import run

View file

@ -0,0 +1,204 @@
"""
Joke Tool Service - An example dynamic tool service.
This service demonstrates the tool service integration by:
- Using the 'user' field to personalize responses
- Using config params (style) to customize joke style
- Using arguments (topic) to generate topic-specific jokes
Example tool-service config:
{
"id": "joke-service",
"topic": "joke",
"config-params": [
{"name": "style", "required": false}
]
}
Example tool config:
{
"type": "tool-service",
"name": "tell-joke",
"description": "Tell a joke on a given topic",
"service": "joke-service",
"style": "pun",
"arguments": [
{
"name": "topic",
"type": "string",
"description": "The topic for the joke (e.g., programming, animals, food)"
}
]
}
"""
import random
import logging
from ... base import DynamicToolService
# Module logger
logger = logging.getLogger(__name__)
default_ident = "joke-service"
default_topic = "joke"
# Joke database organized by topic and style
JOKES = {
"programming": {
"pun": [
"Why do programmers prefer dark mode? Because light attracts bugs!",
"Why do Java developers wear glasses? Because they can't C#!",
"A SQL query walks into a bar, walks up to two tables and asks... 'Can I join you?'",
"Why was the JavaScript developer sad? Because he didn't Node how to Express himself!",
],
"dad-joke": [
"I told my computer I needed a break, and now it won't stop sending me Kit-Kat ads.",
"My son asked me to explain what a linked list is. I said 'I'll tell you, and then I'll tell you again, and again...'",
"I asked my computer for a joke about UDP. I'm not sure if it got it.",
],
"one-liner": [
"There are only 10 types of people: those who understand binary and those who don't.",
"A programmer's wife tells him: 'Go to the store and get a loaf of bread. If they have eggs, get a dozen.' He returns with 12 loaves.",
"99 little bugs in the code, 99 little bugs. Take one down, patch it around, 127 little bugs in the code.",
],
},
"llama": {
"pun": [
"Why did the llama get a ticket? Because he was caught spitting in a no-spitting zone!",
"What do you call a llama who's a great musician? A llama del Rey!",
"Why did the llama cross the road? To prove he wasn't a chicken!",
],
"dad-joke": [
"What did the llama say when he got kicked out of the zoo? 'Alpaca my bags!'",
"Why don't llamas ever get lost? Because they always know the way to the Andes!",
"What do you call a llama with no legs? A woolly rug!",
],
"one-liner": [
"Llamas are great at meditation. They're always saying 'Dalai Llama.'",
"I asked a llama for directions. He said 'No probllama!'",
"Never trust a llama. They're always up to something woolly.",
],
},
"animals": {
"pun": [
"What do you call a fish without eyes? A fsh!",
"Why don't scientists trust atoms? Because they make up everything... just like that cat who blamed the dog!",
"What do you call a bear with no teeth? A gummy bear!",
],
"dad-joke": [
"I tried to catch some fog earlier. I mist. My dog wasn't impressed either.",
"What do you call a dog that does magic tricks? A Labracadabrador!",
"Why do cows wear bells? Because their horns don't work!",
],
"one-liner": [
"I'm reading a book about anti-gravity. It's impossible to put down, unlike my cat.",
"A horse walks into a bar. The bartender asks 'Why the long face?'",
"What's orange and sounds like a parrot? A carrot!",
],
},
"food": {
"pun": [
"I'm on a seafood diet. I see food and I eat it!",
"Why did the tomato turn red? Because it saw the salad dressing!",
"What do you call cheese that isn't yours? Nacho cheese!",
],
"dad-joke": [
"I used to hate facial hair, but then it grew on me. Speaking of growing, have you tried my garden salad?",
"Why don't eggs tell jokes? They'd crack each other up!",
"I told my wife she was drawing her eyebrows too high. She looked surprised, then made me a sandwich.",
],
"one-liner": [
"I'm reading a book about submarines and sandwiches. It's a sub-genre.",
"Broken puppets for sale. No strings attached. Also, free spaghetti!",
"I ordered a chicken and an egg online. I'll let you know which comes first.",
],
},
"default": {
"pun": [
"Time flies like an arrow. Fruit flies like a banana!",
"I used to be a banker, but I lost interest.",
"I'm reading a book on the history of glue. I can't put it down!",
],
"dad-joke": [
"I'm afraid for the calendar. Its days are numbered.",
"I only know 25 letters of the alphabet. I don't know y.",
"Did you hear about the claustrophobic astronaut? He just needed a little space.",
],
"one-liner": [
"I told my wife she was drawing her eyebrows too high. She looked surprised.",
"I'm not lazy, I'm on energy-saving mode.",
"Parallel lines have so much in common. It's a shame they'll never meet.",
],
},
}
class Processor(DynamicToolService):
    """
    Joke tool service that demonstrates the tool service integration.
    """

    # Joke styles that can be requested through the "style" config value
    _STYLES = ("pun", "dad-joke", "one-liner")

    # Joke category -> substrings of the requested topic that select it.
    # Checked in order; the first category with a match wins.
    _TOPIC_KEYWORDS = (
        ("programming", ("program", "code", "computer", "software")),
        ("llama", ("llama",)),
        ("animals", ("animal", "dog", "cat", "bird")),
        ("food", ("food", "eat", "cook", "drink")),
    )

    def __init__(self, **params):
        super(Processor, self).__init__(**params)
        logger.info("Joke service initialized")

    async def invoke(self, user, config, arguments):
        """
        Generate a joke based on the topic and style.

        Args:
            user: The user requesting the joke
            config: Config values including 'style' (pun, dad-joke, one-liner);
                a missing or unrecognised style is replaced by a random one
            arguments: Arguments including 'topic' (programming, animals, food);
                a missing, null or unmatched topic selects the default jokes

        Returns:
            A personalized joke string
        """
        # Tolerate callers that pass no config / arguments at all
        config = config or {}
        arguments = arguments or {}

        # Use the configured style when it is one we know, else pick one
        style = config.get("style")
        if style not in self._STYLES:
            style = random.choice(self._STYLES)

        # The topic may be absent, explicitly null, or not a string;
        # normalise all of those before matching
        topic = str(arguments.get("topic") or "").lower()

        # Map topic to our categories
        category = "default"
        for name, keywords in self._TOPIC_KEYWORDS:
            if any(keyword in topic for keyword in keywords):
                category = name
                break

        # Get jokes for this category and style
        jokes = JOKES.get(category, JOKES["default"]).get(style, JOKES["default"]["pun"])

        # Pick a random joke
        joke = random.choice(jokes)

        # Personalize the response
        response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"

        logger.debug(f"Generated joke for user={user}, style={style}, topic={topic}")

        return response

    @staticmethod
    def add_args(parser):
        """Register the base service arguments, with this service's default for --topic."""
        DynamicToolService.add_args(parser)
        # Override the topic default for this service.
        # NOTE(review): this walks argparse's private _actions list; the
        # --topic option itself is presumably registered by the base class.
        for action in parser._actions:
            if '--topic' in action.option_strings:
                action.default = default_topic
                break
def run():
    """Entry point: launch the Processor with the default identifier and this module's docstring."""
    Processor.launch(default_ident, __doc__)