mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-02 14:31:01 +02:00
Merge branch 'release/v2.1'
This commit is contained in:
commit
824f993985
266 changed files with 33195 additions and 5834 deletions
|
|
@ -10,7 +10,7 @@ description = "TrustGraph provides a means to run a pipeline of flexible AI proc
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"trustgraph-base>=2.0,<2.1",
|
||||
"trustgraph-base>=2.1,<2.2",
|
||||
"aiohttp",
|
||||
"anthropic",
|
||||
"scylla-driver",
|
||||
|
|
@ -27,7 +27,7 @@ dependencies = [
|
|||
"langchain-text-splitters",
|
||||
"mcp",
|
||||
"minio",
|
||||
"mistralai",
|
||||
"mistralai<2.0.0",
|
||||
"neo4j",
|
||||
"nltk",
|
||||
"ollama",
|
||||
|
|
@ -122,6 +122,7 @@ triples-write-falkordb = "trustgraph.storage.triples.falkordb:run"
|
|||
triples-write-memgraph = "trustgraph.storage.triples.memgraph:run"
|
||||
triples-write-neo4j = "trustgraph.storage.triples.neo4j:run"
|
||||
wikipedia-lookup = "trustgraph.external.wikipedia:run"
|
||||
joke-service = "trustgraph.tool_service.joke:run"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["trustgraph*"]
|
||||
|
|
|
|||
|
|
@ -2,11 +2,15 @@
|
|||
Simple agent infrastructure broadly implements the ReAct flow.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import functools
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -14,10 +18,30 @@ logger = logging.getLogger(__name__)
|
|||
from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
|
||||
from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
|
||||
from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
|
||||
from ... base import ProducerSpec
|
||||
from ... base import Consumer, Producer
|
||||
from ... base import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ... schema import AgentRequest, AgentResponse, AgentStep, Error
|
||||
from ... schema import Triples, Metadata
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl
|
||||
# Provenance imports for agent explainability
|
||||
from trustgraph.provenance import (
|
||||
agent_session_uri,
|
||||
agent_iteration_uri,
|
||||
agent_thought_uri,
|
||||
agent_observation_uri,
|
||||
agent_final_uri,
|
||||
agent_session_triples,
|
||||
agent_iteration_triples,
|
||||
agent_final_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
from . tools import KnowledgeQueryImpl, TextCompletionImpl, McpToolImpl, PromptImpl, StructuredQueryImpl, RowEmbeddingsQueryImpl, ToolServiceImpl
|
||||
from . agent_manager import AgentManager
|
||||
from ..tool_filter import validate_tool_config, filter_tools_by_group_and_state, get_next_state
|
||||
|
||||
|
|
@ -25,6 +49,8 @@ from . types import Final, Action, Tool, Argument
|
|||
|
||||
default_ident = "agent-manager"
|
||||
default_max_iterations = 10
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
class Processor(AgentService):
|
||||
|
||||
|
|
@ -51,6 +77,9 @@ class Processor(AgentService):
|
|||
additional_context="",
|
||||
)
|
||||
|
||||
# Track active tool service clients for cleanup
|
||||
self.tool_service_clients = {}
|
||||
|
||||
self.config_handlers.append(self.on_tools_config)
|
||||
|
||||
self.register_specification(
|
||||
|
|
@ -102,6 +131,123 @@ class Processor(AgentService):
|
|||
)
|
||||
)
|
||||
|
||||
# Explainability producer for agent provenance triples
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "explainability",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
# Librarian client for storing answer content
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend=self.pubsub,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
backend=self.pubsub,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self.on_librarian_response,
|
||||
metrics=librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_librarian_requests = {}
|
||||
|
||||
async def start(self):
|
||||
await super(Processor, self).start()
|
||||
await self.librarian_request_producer.start()
|
||||
await self.librarian_response_consumer.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
|
||||
"""Handle responses from the librarian service."""
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if request_id in self.pending_librarian_requests:
|
||||
future = self.pending_librarian_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
|
||||
"""
|
||||
Save answer content to the librarian.
|
||||
|
||||
Args:
|
||||
doc_id: ID for the answer document
|
||||
user: User ID
|
||||
content: Answer text content
|
||||
title: Optional title
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
The document ID on success
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind="text/plain",
|
||||
title=title or "Agent Answer",
|
||||
document_type="answer",
|
||||
)
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="add-document",
|
||||
document_id=doc_id,
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_librarian_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return doc_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_librarian_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout saving answer document {doc_id}")
|
||||
|
||||
async def on_tools_config(self, config, version):
|
||||
|
||||
logger.info(f"Loading configuration version {version}")
|
||||
|
|
@ -110,6 +256,16 @@ class Processor(AgentService):
|
|||
|
||||
tools = {}
|
||||
|
||||
# Load tool-service configurations first
|
||||
tool_services = {}
|
||||
if "tool-service" in config:
|
||||
for service_id, service_value in config["tool-service"].items():
|
||||
service_data = json.loads(service_value)
|
||||
tool_services[service_id] = service_data
|
||||
logger.debug(f"Loaded tool-service config: {service_id}")
|
||||
|
||||
logger.info(f"Loaded {len(tool_services)} tool-service configurations")
|
||||
|
||||
# Load tool configurations from the new location
|
||||
if "tool" in config:
|
||||
for tool_id, tool_value in config["tool"].items():
|
||||
|
|
@ -177,6 +333,59 @@ class Processor(AgentService):
|
|||
limit=int(data.get("limit", 10)) # Max results
|
||||
)
|
||||
arguments = RowEmbeddingsQueryImpl.get_arguments()
|
||||
elif impl_id == "tool-service":
|
||||
# Dynamic tool service - look up the service config
|
||||
service_ref = data.get("service")
|
||||
if not service_ref:
|
||||
raise RuntimeError(
|
||||
f"Tool {name} has type 'tool-service' but no 'service' reference"
|
||||
)
|
||||
if service_ref not in tool_services:
|
||||
raise RuntimeError(
|
||||
f"Tool {name} references unknown tool-service '{service_ref}'"
|
||||
)
|
||||
|
||||
service_config = tool_services[service_ref]
|
||||
request_queue = service_config.get("request-queue")
|
||||
response_queue = service_config.get("response-queue")
|
||||
if not request_queue or not response_queue:
|
||||
raise RuntimeError(
|
||||
f"Tool-service '{service_ref}' must define 'request-queue' and 'response-queue'"
|
||||
)
|
||||
|
||||
# Build config values from tool config
|
||||
# Extract any config params defined by the service
|
||||
config_params = service_config.get("config-params", [])
|
||||
config_values = {}
|
||||
for param in config_params:
|
||||
param_name = param.get("name") if isinstance(param, dict) else param
|
||||
if param_name in data:
|
||||
config_values[param_name] = data[param_name]
|
||||
elif isinstance(param, dict) and param.get("required", False):
|
||||
raise RuntimeError(
|
||||
f"Tool {name} missing required config param '{param_name}'"
|
||||
)
|
||||
|
||||
# Arguments come from tool config
|
||||
config_args = data.get("arguments", [])
|
||||
arguments = [
|
||||
Argument(
|
||||
name=arg.get("name"),
|
||||
type=arg.get("type"),
|
||||
description=arg.get("description")
|
||||
)
|
||||
for arg in config_args
|
||||
]
|
||||
|
||||
# Store queues for the implementation
|
||||
impl = functools.partial(
|
||||
ToolServiceImpl,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
config_values=config_values,
|
||||
arguments=arguments,
|
||||
processor=self,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Tool type {impl_id} not known"
|
||||
|
|
@ -219,6 +428,10 @@ class Processor(AgentService):
|
|||
# Check if streaming is enabled
|
||||
streaming = getattr(request, 'streaming', False)
|
||||
|
||||
# Generate or retrieve session ID for provenance tracking
|
||||
session_id = getattr(request, 'session_id', '') or str(uuid.uuid4())
|
||||
collection = getattr(request, 'collection', 'default')
|
||||
|
||||
if request.history:
|
||||
history = [
|
||||
Action(
|
||||
|
|
@ -232,6 +445,36 @@ class Processor(AgentService):
|
|||
else:
|
||||
history = []
|
||||
|
||||
# Calculate iteration number (1-based)
|
||||
iteration_num = len(history) + 1
|
||||
session_uri = agent_session_uri(session_id)
|
||||
|
||||
# On first iteration, emit session triples
|
||||
if iteration_num == 1:
|
||||
timestamp = datetime.utcnow().isoformat() + "Z"
|
||||
triples = set_graph(
|
||||
agent_session_triples(session_uri, request.question, timestamp),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=session_uri,
|
||||
user=request.user,
|
||||
collection=collection,
|
||||
),
|
||||
triples=triples,
|
||||
))
|
||||
logger.debug(f"Emitted session triples for {session_uri}")
|
||||
|
||||
# Send explain event for session
|
||||
if streaming:
|
||||
await respond(AgentResponse(
|
||||
chunk_type="explain",
|
||||
content="",
|
||||
explain_id=session_uri,
|
||||
explain_graph=GRAPH_RETRIEVAL,
|
||||
))
|
||||
|
||||
logger.info(f"Question: {request.question}")
|
||||
|
||||
if len(history) >= self.max_iterations:
|
||||
|
|
@ -381,6 +624,60 @@ class Processor(AgentService):
|
|||
else:
|
||||
f = json.dumps(act.final)
|
||||
|
||||
# Emit final answer provenance triples
|
||||
final_uri = agent_final_uri(session_id)
|
||||
# No iterations: link to question; otherwise: link to last iteration
|
||||
if iteration_num > 1:
|
||||
final_question_uri = None
|
||||
final_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
|
||||
else:
|
||||
final_question_uri = session_uri
|
||||
final_previous_uri = None
|
||||
|
||||
# Save answer to librarian
|
||||
answer_doc_id = None
|
||||
if f:
|
||||
answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
|
||||
try:
|
||||
await self.save_answer_content(
|
||||
doc_id=answer_doc_id,
|
||||
user=request.user,
|
||||
content=f,
|
||||
title=f"Agent Answer: {request.question[:50]}...",
|
||||
)
|
||||
logger.debug(f"Saved answer to librarian: {answer_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
answer_doc_id = None # Fall back to inline content
|
||||
|
||||
final_triples = set_graph(
|
||||
agent_final_triples(
|
||||
final_uri,
|
||||
question_uri=final_question_uri,
|
||||
previous_uri=final_previous_uri,
|
||||
document_id=answer_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=final_uri,
|
||||
user=request.user,
|
||||
collection=collection,
|
||||
),
|
||||
triples=final_triples,
|
||||
))
|
||||
logger.debug(f"Emitted final triples for {final_uri}")
|
||||
|
||||
# Send explain event for conclusion
|
||||
if streaming:
|
||||
await respond(AgentResponse(
|
||||
chunk_type="explain",
|
||||
content="",
|
||||
explain_id=final_uri,
|
||||
explain_graph=GRAPH_RETRIEVAL,
|
||||
))
|
||||
|
||||
if streaming:
|
||||
# Streaming format - send end-of-dialog marker
|
||||
# Answer chunks were already sent via answer() callback during parsing
|
||||
|
|
@ -413,8 +710,86 @@ class Processor(AgentService):
|
|||
|
||||
logger.debug("Send next...")
|
||||
|
||||
# Emit iteration provenance triples
|
||||
iteration_uri = agent_iteration_uri(session_id, iteration_num)
|
||||
# First iteration links to question, subsequent to previous
|
||||
if iteration_num > 1:
|
||||
iter_question_uri = None
|
||||
iter_previous_uri = agent_iteration_uri(session_id, iteration_num - 1)
|
||||
else:
|
||||
iter_question_uri = session_uri
|
||||
iter_previous_uri = None
|
||||
|
||||
# Save thought to librarian
|
||||
thought_doc_id = None
|
||||
if act.thought:
|
||||
thought_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
|
||||
try:
|
||||
await self.save_answer_content(
|
||||
doc_id=thought_doc_id,
|
||||
user=request.user,
|
||||
content=act.thought,
|
||||
title=f"Agent Thought: {act.name}",
|
||||
)
|
||||
logger.debug(f"Saved thought to librarian: {thought_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save thought to librarian: {e}")
|
||||
thought_doc_id = None
|
||||
|
||||
# Save observation to librarian
|
||||
observation_doc_id = None
|
||||
if act.observation:
|
||||
observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
|
||||
try:
|
||||
await self.save_answer_content(
|
||||
doc_id=observation_doc_id,
|
||||
user=request.user,
|
||||
content=act.observation,
|
||||
title=f"Agent Observation: {act.name}",
|
||||
)
|
||||
logger.debug(f"Saved observation to librarian: {observation_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save observation to librarian: {e}")
|
||||
observation_doc_id = None
|
||||
|
||||
thought_entity_uri = agent_thought_uri(session_id, iteration_num)
|
||||
observation_entity_uri = agent_observation_uri(session_id, iteration_num)
|
||||
|
||||
iter_triples = set_graph(
|
||||
agent_iteration_triples(
|
||||
iteration_uri,
|
||||
question_uri=iter_question_uri,
|
||||
previous_uri=iter_previous_uri,
|
||||
action=act.name,
|
||||
arguments=act.arguments,
|
||||
thought_uri=thought_entity_uri if thought_doc_id else None,
|
||||
thought_document_id=thought_doc_id,
|
||||
observation_uri=observation_entity_uri if observation_doc_id else None,
|
||||
observation_document_id=observation_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=iteration_uri,
|
||||
user=request.user,
|
||||
collection=collection,
|
||||
),
|
||||
triples=iter_triples,
|
||||
))
|
||||
logger.debug(f"Emitted iteration triples for {iteration_uri}")
|
||||
|
||||
# Send explain event for iteration
|
||||
if streaming:
|
||||
await respond(AgentResponse(
|
||||
chunk_type="explain",
|
||||
content="",
|
||||
explain_id=iteration_uri,
|
||||
explain_graph=GRAPH_RETRIEVAL,
|
||||
))
|
||||
|
||||
history.append(act)
|
||||
|
||||
|
||||
# Handle state transitions if tool execution was successful
|
||||
next_state = request.state
|
||||
if act.name in filtered_tools:
|
||||
|
|
@ -435,7 +810,9 @@ class Processor(AgentService):
|
|||
for h in history
|
||||
],
|
||||
user=request.user,
|
||||
collection=collection,
|
||||
streaming=streaming,
|
||||
session_id=session_id, # Pass session_id for provenance continuity
|
||||
)
|
||||
|
||||
await next(r)
|
||||
|
|
|
|||
|
|
@ -154,7 +154,8 @@ class RowEmbeddingsQueryImpl:
|
|||
logger.debug("Getting embeddings for row query...")
|
||||
|
||||
query_text = arguments.get("query")
|
||||
vectors = await embeddings_client.embed(query_text)
|
||||
all_vectors = await embeddings_client.embed([query_text])
|
||||
vector = all_vectors[0] if all_vectors else []
|
||||
|
||||
# Now query row embeddings
|
||||
client = self.context("row-embeddings-query-request")
|
||||
|
|
@ -164,7 +165,7 @@ class RowEmbeddingsQueryImpl:
|
|||
user = getattr(client, '_current_user', self.user or "trustgraph")
|
||||
|
||||
matches = await client.row_embeddings_query(
|
||||
vectors=vectors,
|
||||
vector=vector,
|
||||
schema_name=self.schema_name,
|
||||
user=user,
|
||||
collection=self.collection or "default",
|
||||
|
|
@ -202,3 +203,116 @@ class PromptImpl:
|
|||
id=self.template_id,
|
||||
variables=arguments
|
||||
)
|
||||
|
||||
|
||||
# This tool implementation invokes a dynamically configured tool service
|
||||
class ToolServiceImpl:
|
||||
"""
|
||||
Implementation for dynamically pluggable tool services.
|
||||
|
||||
Tool services are external Pulsar services that can be invoked as agent tools.
|
||||
The service is configured via a tool-service descriptor that defines the queues,
|
||||
and a tool descriptor that provides config values and argument definitions.
|
||||
"""
|
||||
|
||||
def __init__(self, context, request_queue, response_queue, config_values=None, arguments=None, processor=None):
|
||||
"""
|
||||
Initialize a tool service implementation.
|
||||
|
||||
Args:
|
||||
context: The context function (provides user info)
|
||||
request_queue: Full Pulsar topic for requests
|
||||
response_queue: Full Pulsar topic for responses
|
||||
config_values: Dict of config values (e.g., {"collection": "customers"})
|
||||
arguments: List of Argument objects defining the tool's parameters
|
||||
processor: The Processor instance (for pubsub access)
|
||||
"""
|
||||
self.context = context
|
||||
self.request_queue = request_queue
|
||||
self.response_queue = response_queue
|
||||
self.config_values = config_values or {}
|
||||
self.arguments = arguments or []
|
||||
self.processor = processor
|
||||
self._client = None
|
||||
|
||||
def get_arguments(self):
|
||||
return self.arguments
|
||||
|
||||
async def _get_or_create_client(self):
|
||||
"""Get or create the tool service client."""
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
|
||||
# Check if processor already has a client for this queue pair
|
||||
client_key = f"{self.request_queue}|{self.response_queue}"
|
||||
if client_key in self.processor.tool_service_clients:
|
||||
self._client = self.processor.tool_service_clients[client_key]
|
||||
return self._client
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from trustgraph.base.tool_service_client import ToolServiceClient
|
||||
from trustgraph.base.metrics import ProducerMetrics, SubscriberMetrics
|
||||
from trustgraph.schema import ToolServiceRequest, ToolServiceResponse
|
||||
import uuid
|
||||
|
||||
request_metrics = ProducerMetrics(
|
||||
processor=self.processor.id,
|
||||
flow="tool-service",
|
||||
name=self.request_queue
|
||||
)
|
||||
response_metrics = SubscriberMetrics(
|
||||
processor=self.processor.id,
|
||||
flow="tool-service",
|
||||
name=self.response_queue
|
||||
)
|
||||
|
||||
# Create unique subscription for responses
|
||||
subscription = f"{self.processor.id}--tool-service--{uuid.uuid4()}"
|
||||
|
||||
self._client = ToolServiceClient(
|
||||
backend=self.processor.pubsub,
|
||||
subscription=subscription,
|
||||
consumer_name=self.processor.id,
|
||||
request_topic=self.request_queue,
|
||||
request_schema=ToolServiceRequest,
|
||||
request_metrics=request_metrics,
|
||||
response_topic=self.response_queue,
|
||||
response_schema=ToolServiceResponse,
|
||||
response_metrics=response_metrics,
|
||||
)
|
||||
|
||||
# Start the client
|
||||
await self._client.start()
|
||||
|
||||
# Register for cleanup
|
||||
self.processor.tool_service_clients[client_key] = self._client
|
||||
|
||||
logger.debug(f"Created tool service client for {self.request_queue}")
|
||||
return self._client
|
||||
|
||||
async def invoke(self, **arguments):
|
||||
logger.debug(f"Tool service invocation: {self.request_queue}...")
|
||||
logger.debug(f"Config: {self.config_values}")
|
||||
logger.debug(f"Arguments: {arguments}")
|
||||
|
||||
# Get user from context if available
|
||||
user = "trustgraph"
|
||||
if hasattr(self.context, '_user'):
|
||||
user = self.context._user
|
||||
|
||||
# Get or create the client
|
||||
client = await self._get_or_create_client()
|
||||
|
||||
# Call the tool service
|
||||
response = await client.call(
|
||||
user=user,
|
||||
config=self.config_values,
|
||||
arguments=arguments,
|
||||
)
|
||||
|
||||
logger.debug(f"Tool service response: {response}")
|
||||
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
else:
|
||||
return json.dumps(response)
|
||||
|
|
|
|||
|
|
@ -8,14 +8,24 @@ import logging
|
|||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk
|
||||
from ... schema import TextDocument, Chunk, Metadata, Triples
|
||||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
from ... provenance import (
|
||||
derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
# Component identification for provenance
|
||||
COMPONENT_NAME = "chunker"
|
||||
COMPONENT_VERSION = "1.0.0"
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "chunker"
|
||||
|
||||
|
||||
class Processor(ChunkingService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
|
@ -23,7 +33,7 @@ class Processor(ChunkingService):
|
|||
id = params.get("id", default_ident)
|
||||
chunk_size = params.get("chunk_size", 2000)
|
||||
chunk_overlap = params.get("chunk_overlap", 100)
|
||||
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | { "id": id }
|
||||
)
|
||||
|
|
@ -62,6 +72,13 @@ class Processor(ChunkingService):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "triples",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Recursive chunker initialized")
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
|
|
@ -69,6 +86,9 @@ class Processor(ChunkingService):
|
|||
v = msg.value()
|
||||
logger.info(f"Chunking document {v.metadata.id}...")
|
||||
|
||||
# Get text content (fetches from librarian if needed)
|
||||
text = await self.get_document_text(v)
|
||||
|
||||
# Extract chunk parameters from flow (allows runtime override)
|
||||
chunk_size, chunk_overlap = await self.chunk_document(
|
||||
msg, consumer, flow,
|
||||
|
|
@ -90,25 +110,84 @@ class Processor(ChunkingService):
|
|||
is_separator_regex=False,
|
||||
)
|
||||
|
||||
texts = text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
)
|
||||
texts = text_splitter.create_documents([text])
|
||||
|
||||
# Get parent document ID for provenance linking
|
||||
# This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it
|
||||
parent_doc_id = v.document_id or v.metadata.id
|
||||
|
||||
# Track character offset for provenance
|
||||
char_offset = 0
|
||||
|
||||
for ix, chunk in enumerate(texts):
|
||||
chunk_index = ix + 1 # 1-indexed
|
||||
|
||||
logger.debug(f"Created chunk of size {len(chunk.page_content)}")
|
||||
|
||||
# Generate chunk document ID by appending /c{index} to parent
|
||||
# Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
|
||||
chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
|
||||
chunk_uri = chunk_doc_id # URI is same as document ID
|
||||
parent_uri = parent_doc_id
|
||||
|
||||
chunk_content = chunk.page_content.encode("utf-8")
|
||||
chunk_length = len(chunk.page_content)
|
||||
|
||||
# Save chunk to librarian as child document
|
||||
await self.save_child_document(
|
||||
doc_id=chunk_doc_id,
|
||||
parent_id=parent_doc_id,
|
||||
user=v.metadata.user,
|
||||
content=chunk_content,
|
||||
document_type="chunk",
|
||||
title=f"Chunk {chunk_index}",
|
||||
)
|
||||
|
||||
# Emit provenance triples (stored in source graph for separation from core knowledge)
|
||||
prov_triples = derived_entity_triples(
|
||||
entity_uri=chunk_uri,
|
||||
parent_uri=parent_uri,
|
||||
component_name=COMPONENT_NAME,
|
||||
component_version=COMPONENT_VERSION,
|
||||
label=f"Chunk {chunk_index}",
|
||||
chunk_index=chunk_index,
|
||||
char_offset=char_offset,
|
||||
char_length=chunk_length,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
await flow("triples").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=chunk_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
triples=set_graph(prov_triples, GRAPH_SOURCE),
|
||||
))
|
||||
|
||||
# Forward chunk ID + content (post-chunker optimization)
|
||||
r = Chunk(
|
||||
metadata=v.metadata,
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
metadata=Metadata(
|
||||
id=chunk_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
chunk=chunk_content,
|
||||
document_id=chunk_doc_id,
|
||||
)
|
||||
|
||||
__class__.chunk_metric.labels(
|
||||
id=consumer.id, flow=consumer.flow
|
||||
).observe(len(chunk.page_content))
|
||||
).observe(chunk_length)
|
||||
|
||||
await flow("output").send(r)
|
||||
|
||||
# Update character offset (approximate, doesn't account for overlap)
|
||||
char_offset += chunk_length - chunk_overlap
|
||||
|
||||
logger.debug("Document chunking complete")
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -133,4 +212,3 @@ class Processor(ChunkingService):
|
|||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,14 +8,24 @@ import logging
|
|||
from langchain_text_splitters import TokenTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk
|
||||
from ... schema import TextDocument, Chunk, Metadata, Triples
|
||||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
from ... provenance import (
|
||||
derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
# Component identification for provenance
|
||||
COMPONENT_NAME = "token-chunker"
|
||||
COMPONENT_VERSION = "1.0.0"
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "chunker"
|
||||
|
||||
|
||||
class Processor(ChunkingService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
|
@ -23,7 +33,7 @@ class Processor(ChunkingService):
|
|||
id = params.get("id", default_ident)
|
||||
chunk_size = params.get("chunk_size", 250)
|
||||
chunk_overlap = params.get("chunk_overlap", 15)
|
||||
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | { "id": id }
|
||||
)
|
||||
|
|
@ -61,6 +71,13 @@ class Processor(ChunkingService):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "triples",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Token chunker initialized")
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
|
|
@ -68,6 +85,9 @@ class Processor(ChunkingService):
|
|||
v = msg.value()
|
||||
logger.info(f"Chunking document {v.metadata.id}...")
|
||||
|
||||
# Get text content (fetches from librarian if needed)
|
||||
text = await self.get_document_text(v)
|
||||
|
||||
# Extract chunk parameters from flow (allows runtime override)
|
||||
chunk_size, chunk_overlap = await self.chunk_document(
|
||||
msg, consumer, flow,
|
||||
|
|
@ -88,25 +108,84 @@ class Processor(ChunkingService):
|
|||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
texts = text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
)
|
||||
texts = text_splitter.create_documents([text])
|
||||
|
||||
# Get parent document ID for provenance linking
|
||||
# This could be a page URI (doc/p3) or document URI (doc) - we don't need to parse it
|
||||
parent_doc_id = v.document_id or v.metadata.id
|
||||
|
||||
# Track token offset for provenance (approximate)
|
||||
token_offset = 0
|
||||
|
||||
for ix, chunk in enumerate(texts):
|
||||
chunk_index = ix + 1 # 1-indexed
|
||||
|
||||
logger.debug(f"Created chunk of size {len(chunk.page_content)}")
|
||||
|
||||
# Generate chunk document ID by appending /c{index} to parent
|
||||
# Works for both page URIs (doc/p3 -> doc/p3/c1) and doc URIs (doc -> doc/c1)
|
||||
chunk_doc_id = f"{parent_doc_id}/c{chunk_index}"
|
||||
chunk_uri = chunk_doc_id # URI is same as document ID
|
||||
parent_uri = parent_doc_id
|
||||
|
||||
chunk_content = chunk.page_content.encode("utf-8")
|
||||
chunk_length = len(chunk.page_content)
|
||||
|
||||
# Save chunk to librarian as child document
|
||||
await self.save_child_document(
|
||||
doc_id=chunk_doc_id,
|
||||
parent_id=parent_doc_id,
|
||||
user=v.metadata.user,
|
||||
content=chunk_content,
|
||||
document_type="chunk",
|
||||
title=f"Chunk {chunk_index}",
|
||||
)
|
||||
|
||||
# Emit provenance triples (stored in source graph for separation from core knowledge)
|
||||
prov_triples = derived_entity_triples(
|
||||
entity_uri=chunk_uri,
|
||||
parent_uri=parent_uri,
|
||||
component_name=COMPONENT_NAME,
|
||||
component_version=COMPONENT_VERSION,
|
||||
label=f"Chunk {chunk_index}",
|
||||
chunk_index=chunk_index,
|
||||
char_offset=token_offset, # Note: this is token offset, not char offset
|
||||
char_length=chunk_length,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
await flow("triples").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=chunk_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
triples=set_graph(prov_triples, GRAPH_SOURCE),
|
||||
))
|
||||
|
||||
# Forward chunk ID + content (post-chunker optimization)
|
||||
r = Chunk(
|
||||
metadata=v.metadata,
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
metadata=Metadata(
|
||||
id=chunk_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
chunk=chunk_content,
|
||||
document_id=chunk_doc_id,
|
||||
)
|
||||
|
||||
__class__.chunk_metric.labels(
|
||||
id=consumer.id, flow=consumer.flow
|
||||
).observe(len(chunk.page_content))
|
||||
).observe(chunk_length)
|
||||
|
||||
await flow("output").send(r)
|
||||
|
||||
# Update token offset (approximate, doesn't account for overlap)
|
||||
token_offset += chunk_size - chunk_overlap
|
||||
|
||||
logger.debug("Document chunking complete")
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -118,17 +197,16 @@ class Processor(ChunkingService):
|
|||
'-z', '--chunk-size',
|
||||
type=int,
|
||||
default=250,
|
||||
help=f'Chunk size (default: 250)'
|
||||
help=f'Chunk size in tokens (default: 250)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-v', '--chunk-overlap',
|
||||
type=int,
|
||||
default=15,
|
||||
help=f'Chunk overlap (default: 15)'
|
||||
help=f'Chunk overlap in tokens (default: 15)'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,21 +2,44 @@
|
|||
"""
|
||||
Simple decoder, accepts PDF documents on input, outputs pages from the
|
||||
PDF document as text as separate output objects.
|
||||
|
||||
Supports both inline document data and fetching from librarian via Pulsar
|
||||
for large documents.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
|
||||
from ... schema import Document, TextDocument, Metadata
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from ... schema import Triples
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ... provenance import (
|
||||
document_uri, page_uri, derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
)
|
||||
|
||||
# Component identification for provenance
|
||||
COMPONENT_NAME = "pdf-decoder"
|
||||
COMPONENT_VERSION = "1.0.0"
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "pdf-decoder"
|
||||
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
|
@ -44,8 +67,164 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "triples",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
# Librarian client for fetching document content
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor = id, flow = None, name = "librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend = self.pubsub,
|
||||
topic = librarian_request_q,
|
||||
schema = LibrarianRequest,
|
||||
metrics = librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor = id, flow = None, name = "librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup = self.taskgroup,
|
||||
backend = self.pubsub,
|
||||
flow = None,
|
||||
topic = librarian_response_q,
|
||||
subscriber = f"{id}-librarian",
|
||||
schema = LibrarianResponse,
|
||||
handler = self.on_librarian_response,
|
||||
metrics = librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_requests = {}
|
||||
|
||||
logger.info("PDF decoder initialized")
|
||||
|
||||
async def start(self):
|
||||
await super(Processor, self).start()
|
||||
await self.librarian_request_producer.start()
|
||||
await self.librarian_response_consumer.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
|
||||
"""Handle responses from the librarian service."""
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if request_id and request_id in self.pending_requests:
|
||||
future = self.pending_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def fetch_document_content(self, document_id, user, timeout=120):
|
||||
"""
|
||||
Fetch document content from librarian via Pulsar.
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="get-document-content",
|
||||
document_id=document_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return response.content
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout fetching document {document_id}")
|
||||
|
||||
async def save_child_document(self, doc_id, parent_id, user, content,
|
||||
document_type="page", title=None, timeout=120):
|
||||
"""
|
||||
Save a child document to the librarian.
|
||||
|
||||
Args:
|
||||
doc_id: ID for the new child document
|
||||
parent_id: ID of the parent document
|
||||
user: User ID
|
||||
content: Document content (bytes)
|
||||
document_type: Type of document ("page", "chunk", etc.)
|
||||
title: Optional title
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
The document ID on success
|
||||
"""
|
||||
import base64
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind="text/plain",
|
||||
title=title or doc_id,
|
||||
parent_id=parent_id,
|
||||
document_type=document_type,
|
||||
)
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="add-child-document",
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content).decode("utf-8"),
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error saving child document: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return doc_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout saving child document {doc_id}")
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
|
||||
logger.debug("PDF message received")
|
||||
|
|
@ -54,26 +233,102 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.info(f"Decoding PDF {v.metadata.id}...")
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
|
||||
with tempfile.NamedTemporaryFile(delete_on_close=False, suffix='.pdf') as fp:
|
||||
temp_path = fp.name
|
||||
|
||||
fp.write(base64.b64decode(v.data))
|
||||
fp.close()
|
||||
# Check if we should fetch from librarian or use inline data
|
||||
if v.document_id:
|
||||
# Fetch from librarian via Pulsar
|
||||
logger.info(f"Fetching document {v.document_id} from librarian...")
|
||||
fp.close()
|
||||
|
||||
with open(fp.name, mode='rb') as f:
|
||||
content = await self.fetch_document_content(
|
||||
document_id=v.document_id,
|
||||
user=v.metadata.user,
|
||||
)
|
||||
|
||||
loader = PyPDFLoader(fp.name)
|
||||
pages = loader.load()
|
||||
# Content is base64 encoded
|
||||
if isinstance(content, str):
|
||||
content = content.encode('utf-8')
|
||||
decoded_content = base64.b64decode(content)
|
||||
|
||||
for ix, page in enumerate(pages):
|
||||
with open(temp_path, 'wb') as f:
|
||||
f.write(decoded_content)
|
||||
|
||||
logger.debug(f"Processing page {ix}")
|
||||
logger.info(f"Fetched {len(decoded_content)} bytes from librarian")
|
||||
else:
|
||||
# Use inline data (backward compatibility)
|
||||
fp.write(base64.b64decode(v.data))
|
||||
fp.close()
|
||||
|
||||
r = TextDocument(
|
||||
metadata=v.metadata,
|
||||
text=page.page_content.encode("utf-8"),
|
||||
)
|
||||
loader = PyPDFLoader(temp_path)
|
||||
pages = loader.load()
|
||||
|
||||
await flow("output").send(r)
|
||||
# Get the source document ID
|
||||
source_doc_id = v.document_id or v.metadata.id
|
||||
|
||||
for ix, page in enumerate(pages):
|
||||
page_num = ix + 1 # 1-indexed page numbers
|
||||
|
||||
logger.debug(f"Processing page {page_num}")
|
||||
|
||||
# Generate page document ID
|
||||
page_doc_id = f"{source_doc_id}/p{page_num}"
|
||||
page_content = page.page_content.encode("utf-8")
|
||||
|
||||
# Save page as child document in librarian
|
||||
await self.save_child_document(
|
||||
doc_id=page_doc_id,
|
||||
parent_id=source_doc_id,
|
||||
user=v.metadata.user,
|
||||
content=page_content,
|
||||
document_type="page",
|
||||
title=f"Page {page_num}",
|
||||
)
|
||||
|
||||
# Emit provenance triples (stored in source graph for separation from core knowledge)
|
||||
doc_uri = document_uri(source_doc_id)
|
||||
pg_uri = page_uri(source_doc_id, page_num)
|
||||
|
||||
prov_triples = derived_entity_triples(
|
||||
entity_uri=pg_uri,
|
||||
parent_uri=doc_uri,
|
||||
component_name=COMPONENT_NAME,
|
||||
component_version=COMPONENT_VERSION,
|
||||
label=f"Page {page_num}",
|
||||
page_number=page_num,
|
||||
)
|
||||
|
||||
await flow("triples").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=pg_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
triples=set_graph(prov_triples, GRAPH_SOURCE),
|
||||
))
|
||||
|
||||
# Forward page document ID to chunker
|
||||
# Chunker will fetch content from librarian
|
||||
r = TextDocument(
|
||||
metadata=Metadata(
|
||||
id=pg_uri,
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
document_id=page_doc_id,
|
||||
text=b"", # Empty, chunker will fetch from librarian
|
||||
)
|
||||
|
||||
await flow("output").send(r)
|
||||
|
||||
# Clean up temp file
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
logger.debug("PDF decoding complete")
|
||||
|
||||
|
|
@ -81,7 +336,18 @@ class Processor(FlowProcessor):
|
|||
def add_args(parser):
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--librarian-request-queue',
|
||||
default=default_librarian_request_queue,
|
||||
help=f'Librarian request queue (default: {default_librarian_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--librarian-response-queue',
|
||||
default=default_librarian_response_queue,
|
||||
help=f'Librarian response queue (default: {default_librarian_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
||||
|
|
|
|||
|
|
@ -589,6 +589,8 @@ class EntityCentricKnowledgeGraph:
|
|||
|
||||
# quads_by_entity: primary data table
|
||||
# Every entity has a partition containing all quads it participates in
|
||||
# Clustering key includes dtype/lang to distinguish literals with same value
|
||||
# but different datatype or language tag (e.g., "thing" vs "thing"@en)
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.entity_table} (
|
||||
collection text,
|
||||
|
|
@ -601,11 +603,13 @@ class EntityCentricKnowledgeGraph:
|
|||
d text,
|
||||
dtype text,
|
||||
lang text,
|
||||
PRIMARY KEY ((collection, entity), role, p, otype, s, o, d)
|
||||
PRIMARY KEY ((collection, entity), role, p, otype, s, o, d, dtype, lang)
|
||||
);
|
||||
""")
|
||||
|
||||
# quads_by_collection: manifest for collection-level queries and deletion
|
||||
# Clustering key includes otype/dtype/lang to distinguish literals with same
|
||||
# value but different metadata (e.g., "thing" vs "thing"@en vs "thing"^^xsd:string)
|
||||
self.session.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.collection_table} (
|
||||
collection text,
|
||||
|
|
@ -616,7 +620,7 @@ class EntityCentricKnowledgeGraph:
|
|||
otype text,
|
||||
dtype text,
|
||||
lang text,
|
||||
PRIMARY KEY (collection, d, s, p, o)
|
||||
PRIMARY KEY (collection, d, s, p, o, otype, dtype, lang)
|
||||
);
|
||||
""")
|
||||
|
||||
|
|
@ -718,7 +722,7 @@ class EntityCentricKnowledgeGraph:
|
|||
)
|
||||
|
||||
self.delete_collection_row_stmt = self.session.prepare(
|
||||
f"DELETE FROM {self.collection_table} WHERE collection = ? AND d = ? AND s = ? AND p = ? AND o = ?"
|
||||
f"DELETE FROM {self.collection_table} WHERE collection = ? AND d = ? AND s = ? AND p = ? AND o = ? AND otype = ? AND dtype = ? AND lang = ?"
|
||||
)
|
||||
|
||||
logger.info("Prepared statements initialized for entity-centric schema")
|
||||
|
|
@ -797,7 +801,7 @@ class EntityCentricKnowledgeGraph:
|
|||
def get_s(self, collection, s, g=None, limit=10):
|
||||
"""
|
||||
Query by subject. Returns quads where s is the subject.
|
||||
g=None: default graph, g='*': all graphs
|
||||
g=None: all graphs, g='': default graph only, g='uri': specific graph
|
||||
"""
|
||||
rows = self.session.execute(self.get_entity_as_s_stmt, (collection, s, limit))
|
||||
|
||||
|
|
@ -805,10 +809,7 @@ class EntityCentricKnowledgeGraph:
|
|||
for row in rows:
|
||||
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||
# Filter by graph if specified
|
||||
if g is None or g == DEFAULT_GRAPH:
|
||||
if d != DEFAULT_GRAPH:
|
||||
continue
|
||||
elif g != GRAPH_WILDCARD and d != g:
|
||||
if g is not None and d != g:
|
||||
continue
|
||||
|
||||
results.append(QuadResult(
|
||||
|
|
@ -819,16 +820,13 @@ class EntityCentricKnowledgeGraph:
|
|||
return results
|
||||
|
||||
def get_p(self, collection, p, g=None, limit=10):
|
||||
"""Query by predicate"""
|
||||
"""Query by predicate. g=None: all graphs, g='': default graph only"""
|
||||
rows = self.session.execute(self.get_entity_as_p_stmt, (collection, p, limit))
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||
if g is None or g == DEFAULT_GRAPH:
|
||||
if d != DEFAULT_GRAPH:
|
||||
continue
|
||||
elif g != GRAPH_WILDCARD and d != g:
|
||||
if g is not None and d != g:
|
||||
continue
|
||||
|
||||
results.append(QuadResult(
|
||||
|
|
@ -839,16 +837,13 @@ class EntityCentricKnowledgeGraph:
|
|||
return results
|
||||
|
||||
def get_o(self, collection, o, g=None, limit=10):
|
||||
"""Query by object"""
|
||||
"""Query by object. g=None: all graphs, g='': default graph only"""
|
||||
rows = self.session.execute(self.get_entity_as_o_stmt, (collection, o, limit))
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||
if g is None or g == DEFAULT_GRAPH:
|
||||
if d != DEFAULT_GRAPH:
|
||||
continue
|
||||
elif g != GRAPH_WILDCARD and d != g:
|
||||
if g is not None and d != g:
|
||||
continue
|
||||
|
||||
results.append(QuadResult(
|
||||
|
|
@ -859,16 +854,13 @@ class EntityCentricKnowledgeGraph:
|
|||
return results
|
||||
|
||||
def get_sp(self, collection, s, p, g=None, limit=10):
|
||||
"""Query by subject and predicate"""
|
||||
"""Query by subject and predicate. g=None: all graphs, g='': default graph only"""
|
||||
rows = self.session.execute(self.get_entity_as_s_p_stmt, (collection, s, p, limit))
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||
if g is None or g == DEFAULT_GRAPH:
|
||||
if d != DEFAULT_GRAPH:
|
||||
continue
|
||||
elif g != GRAPH_WILDCARD and d != g:
|
||||
if g is not None and d != g:
|
||||
continue
|
||||
|
||||
results.append(QuadResult(
|
||||
|
|
@ -879,16 +871,13 @@ class EntityCentricKnowledgeGraph:
|
|||
return results
|
||||
|
||||
def get_po(self, collection, p, o, g=None, limit=10):
|
||||
"""Query by predicate and object"""
|
||||
"""Query by predicate and object. g=None: all graphs, g='': default graph only"""
|
||||
rows = self.session.execute(self.get_entity_as_o_p_stmt, (collection, o, p, limit))
|
||||
|
||||
results = []
|
||||
for row in rows:
|
||||
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||
if g is None or g == DEFAULT_GRAPH:
|
||||
if d != DEFAULT_GRAPH:
|
||||
continue
|
||||
elif g != GRAPH_WILDCARD and d != g:
|
||||
if g is not None and d != g:
|
||||
continue
|
||||
|
||||
results.append(QuadResult(
|
||||
|
|
@ -899,7 +888,7 @@ class EntityCentricKnowledgeGraph:
|
|||
return results
|
||||
|
||||
def get_os(self, collection, o, s, g=None, limit=10):
|
||||
"""Query by object and subject"""
|
||||
"""Query by object and subject. g=None: all graphs, g='': default graph only"""
|
||||
# Use subject partition with role='S', filter by o
|
||||
rows = self.session.execute(self.get_entity_as_s_stmt, (collection, s, limit))
|
||||
|
||||
|
|
@ -909,10 +898,7 @@ class EntityCentricKnowledgeGraph:
|
|||
continue
|
||||
|
||||
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||
if g is None or g == DEFAULT_GRAPH:
|
||||
if d != DEFAULT_GRAPH:
|
||||
continue
|
||||
elif g != GRAPH_WILDCARD and d != g:
|
||||
if g is not None and d != g:
|
||||
continue
|
||||
|
||||
results.append(QuadResult(
|
||||
|
|
@ -923,7 +909,7 @@ class EntityCentricKnowledgeGraph:
|
|||
return results
|
||||
|
||||
def get_spo(self, collection, s, p, o, g=None, limit=10):
|
||||
"""Query by subject, predicate, object (find which graphs)"""
|
||||
"""Query by subject, predicate, object (find which graphs). g=None: all graphs, g='': default graph only"""
|
||||
rows = self.session.execute(self.get_entity_as_s_p_stmt, (collection, s, p, limit))
|
||||
|
||||
results = []
|
||||
|
|
@ -932,10 +918,7 @@ class EntityCentricKnowledgeGraph:
|
|||
continue
|
||||
|
||||
d = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||
if g is None or g == DEFAULT_GRAPH:
|
||||
if d != DEFAULT_GRAPH:
|
||||
continue
|
||||
elif g != GRAPH_WILDCARD and d != g:
|
||||
if g is not None and d != g:
|
||||
continue
|
||||
|
||||
results.append(QuadResult(
|
||||
|
|
@ -991,9 +974,9 @@ class EntityCentricKnowledgeGraph:
|
|||
3. Delete entire entity partitions
|
||||
4. Delete collection rows
|
||||
"""
|
||||
# Read all quads from collection table
|
||||
# Read all quads from collection table (including type metadata for delete)
|
||||
rows = self.session.execute(
|
||||
f"SELECT d, s, p, o, otype FROM {self.collection_table} WHERE collection = %s",
|
||||
f"SELECT d, s, p, o, otype, dtype, lang FROM {self.collection_table} WHERE collection = %s",
|
||||
(collection,)
|
||||
)
|
||||
|
||||
|
|
@ -1002,8 +985,11 @@ class EntityCentricKnowledgeGraph:
|
|||
quads = []
|
||||
|
||||
for row in rows:
|
||||
d, s, p, o, otype = row.d, row.s, row.p, row.o, row.otype
|
||||
quads.append((d, s, p, o))
|
||||
d, s, p, o = row.d, row.s, row.p, row.o
|
||||
otype = row.otype
|
||||
dtype = row.dtype if hasattr(row, 'dtype') else ''
|
||||
lang = row.lang if hasattr(row, 'lang') else ''
|
||||
quads.append((d, s, p, o, otype, dtype, lang))
|
||||
|
||||
# Subject and predicate are always entities
|
||||
entities.add(s)
|
||||
|
|
@ -1038,8 +1024,8 @@ class EntityCentricKnowledgeGraph:
|
|||
batch = BatchStatement()
|
||||
count = 0
|
||||
|
||||
for d, s, p, o in quads:
|
||||
batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o))
|
||||
for d, s, p, o, otype, dtype, lang in quads:
|
||||
batch.add(self.delete_collection_row_stmt, (collection, d, s, p, o, otype, dtype, lang))
|
||||
count += 1
|
||||
|
||||
# Execute batch every 50 quads
|
||||
|
|
|
|||
|
|
@ -84,14 +84,14 @@ class DocVectors:
|
|||
dim=dimension,
|
||||
)
|
||||
|
||||
doc_field = FieldSchema(
|
||||
name="doc",
|
||||
chunk_id_field = FieldSchema(
|
||||
name="chunk_id",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [pkey_field, vec_field, doc_field],
|
||||
fields = [pkey_field, vec_field, chunk_id_field],
|
||||
description = "Document embedding schema",
|
||||
)
|
||||
|
||||
|
|
@ -119,17 +119,17 @@ class DocVectors:
|
|||
self.collections[(dimension, user, collection)] = collection_name
|
||||
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
|
||||
|
||||
def insert(self, embeds, doc, user, collection):
|
||||
def insert(self, embeds, chunk_id, user, collection):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"doc": doc,
|
||||
"chunk_id": chunk_id,
|
||||
}
|
||||
]
|
||||
|
||||
|
|
@ -138,7 +138,7 @@ class DocVectors:
|
|||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, user, collection, fields=["doc"], limit=10):
|
||||
def search(self, embeds, user, collection, fields=["chunk_id"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
|
|
|
|||
|
|
@ -90,8 +90,14 @@ class EntityVectors:
|
|||
max_length=65535,
|
||||
)
|
||||
|
||||
chunk_id_field = FieldSchema(
|
||||
name="chunk_id",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [pkey_field, vec_field, entity_field],
|
||||
fields = [pkey_field, vec_field, entity_field, chunk_id_field],
|
||||
description = "Graph embedding schema",
|
||||
)
|
||||
|
||||
|
|
@ -119,17 +125,18 @@ class EntityVectors:
|
|||
self.collections[(dimension, user, collection)] = collection_name
|
||||
logger.info(f"Created Milvus collection {collection_name} with dimension {dimension}")
|
||||
|
||||
def insert(self, embeds, entity, user, collection):
|
||||
def insert(self, embeds, entity, user, collection, chunk_id=""):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if (dim, user, collection) not in self.collections:
|
||||
self.init_collection(dim, user, collection)
|
||||
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"entity": entity,
|
||||
"chunk_id": chunk_id,
|
||||
}
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -62,16 +62,17 @@ class Processor(FlowProcessor):
|
|||
|
||||
resp = await flow("embeddings-request").request(
|
||||
EmbeddingsRequest(
|
||||
text = v.chunk
|
||||
texts=[v.chunk]
|
||||
)
|
||||
)
|
||||
|
||||
vectors = resp.vectors
|
||||
# vectors[0] is the vector for the first (only) text
|
||||
vector = resp.vectors[0] if resp.vectors else []
|
||||
|
||||
embeds = [
|
||||
ChunkEmbeddings(
|
||||
chunk=v.chunk,
|
||||
vectors=vectors,
|
||||
chunk_id=v.document_id,
|
||||
vector=vector,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -46,19 +46,21 @@ class Processor(EmbeddingsService):
|
|||
else:
|
||||
logger.debug(f"Using cached model: {model_name}")
|
||||
|
||||
async def on_embeddings(self, text, model=None):
|
||||
async def on_embeddings(self, texts, model=None):
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
use_model = model or self.default_model
|
||||
|
||||
# Reload model if it has changed
|
||||
self._load_model(use_model)
|
||||
|
||||
vecs = self.embeddings.embed([text])
|
||||
# FastEmbed processes the full batch efficiently
|
||||
vecs = list(self.embeddings.embed(texts))
|
||||
|
||||
return [
|
||||
v.tolist()
|
||||
for v in vecs
|
||||
]
|
||||
# Return list of vectors, one per input text
|
||||
return [v.tolist() for v in vecs]
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
|
|||
|
|
@ -58,22 +58,25 @@ class Processor(FlowProcessor):
|
|||
v = msg.value()
|
||||
logger.info(f"Indexing {v.metadata.id}...")
|
||||
|
||||
entities = []
|
||||
|
||||
try:
|
||||
|
||||
for entity in v.entities:
|
||||
# Collect all contexts for batch embedding
|
||||
contexts = [entity.context for entity in v.entities]
|
||||
|
||||
vectors = await flow("embeddings-request").embed(
|
||||
text = entity.context
|
||||
)
|
||||
# Single batch embedding call
|
||||
all_vectors = await flow("embeddings-request").embed(
|
||||
texts=contexts
|
||||
)
|
||||
|
||||
entities.append(
|
||||
EntityEmbeddings(
|
||||
entity=entity.entity,
|
||||
vectors=vectors
|
||||
)
|
||||
# Pair results with entities
|
||||
entities = [
|
||||
EntityEmbeddings(
|
||||
entity=entity.entity,
|
||||
vector=vector,
|
||||
chunk_id=entity.chunk_id, # Provenance: source chunk
|
||||
)
|
||||
for entity, vector in zip(v.entities, all_vectors)
|
||||
]
|
||||
|
||||
# Send in batches to avoid oversized messages
|
||||
for i in range(0, len(entities), self.batch_size):
|
||||
|
|
|
|||
|
|
@ -30,16 +30,21 @@ class Processor(EmbeddingsService):
|
|||
self.client = Client(host=ollama)
|
||||
self.default_model = model
|
||||
|
||||
async def on_embeddings(self, text, model=None):
|
||||
async def on_embeddings(self, texts, model=None):
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
use_model = model or self.default_model
|
||||
|
||||
# Ollama handles batch input efficiently
|
||||
embeds = self.client.embed(
|
||||
model = use_model,
|
||||
input = text
|
||||
input = texts
|
||||
)
|
||||
|
||||
return embeds.embeddings
|
||||
# Return list of vectors, one per input text
|
||||
return list(embeds.embeddings)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
|
|||
|
|
@ -200,15 +200,23 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
|||
embeddings_list = []
|
||||
|
||||
try:
|
||||
for text, (index_name, index_value) in texts_to_embed.items():
|
||||
vectors = await flow("embeddings-request").embed(text=text)
|
||||
# Collect texts and metadata for batch embedding
|
||||
texts = list(texts_to_embed.keys())
|
||||
metadata = list(texts_to_embed.values())
|
||||
|
||||
# Single batch embedding call
|
||||
all_vectors = await flow("embeddings-request").embed(texts=texts)
|
||||
|
||||
# Pair results with metadata
|
||||
for text, (index_name, index_value), vector in zip(
|
||||
texts, metadata, all_vectors
|
||||
):
|
||||
embeddings_list.append(
|
||||
RowIndexEmbedding(
|
||||
index_name=index_name,
|
||||
index_value=index_value,
|
||||
text=text,
|
||||
vectors=vectors
|
||||
vector=vector
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,13 @@ import logging
|
|||
from ....schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
|
||||
from ....schema import EntityContext, EntityContexts
|
||||
|
||||
from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, SUBJECT_OF, DEFINITION
|
||||
from ....rdf import TRUSTGRAPH_ENTITIES, RDF_LABEL, DEFINITION
|
||||
|
||||
from ....base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ....base import AgentClientSpec
|
||||
|
||||
from ....provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE
|
||||
from ....flow_version import __version__ as COMPONENT_VERSION
|
||||
from ....template import PromptManager
|
||||
|
||||
# Module logger
|
||||
|
|
@ -104,7 +106,7 @@ class Processor(FlowProcessor):
|
|||
tpls = Triples(
|
||||
metadata = Metadata(
|
||||
id = metadata.id,
|
||||
metadata = [],
|
||||
root = metadata.root,
|
||||
user = metadata.user,
|
||||
collection = metadata.collection,
|
||||
),
|
||||
|
|
@ -117,7 +119,7 @@ class Processor(FlowProcessor):
|
|||
ecs = EntityContexts(
|
||||
metadata = Metadata(
|
||||
id = metadata.id,
|
||||
metadata = [],
|
||||
root = metadata.root,
|
||||
user = metadata.user,
|
||||
collection = metadata.collection,
|
||||
),
|
||||
|
|
@ -183,24 +185,8 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.debug(f"Agent prompt: {prompt}")
|
||||
|
||||
async def handle(response):
|
||||
|
||||
logger.debug(f"Agent response: {response}")
|
||||
|
||||
if response.error is not None:
|
||||
if response.error.message:
|
||||
raise RuntimeError(str(response.error.message))
|
||||
else:
|
||||
raise RuntimeError(str(response.error))
|
||||
|
||||
if response.answer is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
# Send to agent API
|
||||
agent_response = await flow("agent-request").invoke(
|
||||
recipient = handle,
|
||||
question = prompt
|
||||
)
|
||||
|
||||
|
|
@ -212,14 +198,22 @@ class Processor(FlowProcessor):
|
|||
return
|
||||
|
||||
# Process extraction data
|
||||
triples, entity_contexts = self.process_extraction_data(
|
||||
extraction_data, v.metadata
|
||||
)
|
||||
triples, entity_contexts, extracted_triples = \
|
||||
self.process_extraction_data(extraction_data, v.metadata)
|
||||
|
||||
# Generate subgraph provenance for extracted triples
|
||||
if extracted_triples:
|
||||
chunk_uri = v.metadata.id
|
||||
sg_uri = subgraph_uri()
|
||||
prov_triples = subgraph_provenance_triples(
|
||||
subgraph_uri=sg_uri,
|
||||
extracted_triples=extracted_triples,
|
||||
chunk_uri=chunk_uri,
|
||||
component_name=default_ident,
|
||||
component_version=COMPONENT_VERSION,
|
||||
)
|
||||
triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
|
||||
|
||||
# Put document metadata into triples
|
||||
for t in v.metadata.metadata:
|
||||
triples.append(t)
|
||||
|
||||
# Emit outputs
|
||||
if triples:
|
||||
await self.emit_triples(flow("triples"), v.metadata, triples)
|
||||
|
|
@ -241,8 +235,13 @@ class Processor(FlowProcessor):
|
|||
Data is a flat list of objects with 'type' discriminator field:
|
||||
- {"type": "definition", "entity": "...", "definition": "..."}
|
||||
- {"type": "relationship", "subject": "...", "predicate": "...", "object": "...", "object-entity": bool}
|
||||
|
||||
Returns:
|
||||
Tuple of (all_triples, entity_contexts, extracted_triples) where
|
||||
extracted_triples contains only the core knowledge facts (for provenance).
|
||||
"""
|
||||
triples = []
|
||||
extracted_triples = []
|
||||
entity_contexts = []
|
||||
|
||||
# Categorize items by type
|
||||
|
|
@ -262,26 +261,20 @@ class Processor(FlowProcessor):
|
|||
))
|
||||
|
||||
# Add definition
|
||||
triples.append(Triple(
|
||||
definition_triple = Triple(
|
||||
s = Term(type=IRI, iri=entity_uri),
|
||||
p = Term(type=IRI, iri=DEFINITION),
|
||||
o = Term(type=LITERAL, value=defn["definition"]),
|
||||
))
|
||||
|
||||
# Add subject-of relationship to document
|
||||
if metadata.id:
|
||||
triples.append(Triple(
|
||||
s = Term(type=IRI, iri=entity_uri),
|
||||
p = Term(type=IRI, iri=SUBJECT_OF),
|
||||
o = Term(type=IRI, iri=metadata.id),
|
||||
))
|
||||
)
|
||||
triples.append(definition_triple)
|
||||
extracted_triples.append(definition_triple)
|
||||
|
||||
# Create entity context for embeddings
|
||||
entity_contexts.append(EntityContext(
|
||||
entity=Term(type=IRI, iri=entity_uri),
|
||||
context=defn["definition"]
|
||||
))
|
||||
|
||||
|
||||
# Process relationships
|
||||
for rel in relationships:
|
||||
|
||||
|
|
@ -318,34 +311,15 @@ class Processor(FlowProcessor):
|
|||
))
|
||||
|
||||
# Add the main relationship triple
|
||||
triples.append(Triple(
|
||||
relationship_triple = Triple(
|
||||
s = subject_value,
|
||||
p = predicate_value,
|
||||
o = object_value
|
||||
))
|
||||
)
|
||||
triples.append(relationship_triple)
|
||||
extracted_triples.append(relationship_triple)
|
||||
|
||||
# Add subject-of relationships to document
|
||||
if metadata.id:
|
||||
triples.append(Triple(
|
||||
s = subject_value,
|
||||
p = Term(type=IRI, iri=SUBJECT_OF),
|
||||
o = Term(type=IRI, iri=metadata.id),
|
||||
))
|
||||
|
||||
triples.append(Triple(
|
||||
s = predicate_value,
|
||||
p = Term(type=IRI, iri=SUBJECT_OF),
|
||||
o = Term(type=IRI, iri=metadata.id),
|
||||
))
|
||||
|
||||
if rel.get("object-entity", True):
|
||||
triples.append(Triple(
|
||||
s = object_value,
|
||||
p = Term(type=IRI, iri=SUBJECT_OF),
|
||||
o = Term(type=IRI, iri=metadata.id),
|
||||
))
|
||||
|
||||
return triples, entity_contexts
|
||||
return triples, entity_contexts, extracted_triples
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
|
|||
|
|
@ -15,15 +15,16 @@ from .... schema import Chunk, Triple, Triples, Metadata, Term, IRI, LITERAL
|
|||
logger = logging.getLogger(__name__)
|
||||
from .... schema import EntityContext, EntityContexts
|
||||
from .... schema import PromptRequest, PromptResponse
|
||||
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL, SUBJECT_OF
|
||||
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION, RDF_LABEL
|
||||
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base import PromptClientSpec
|
||||
from .... base import PromptClientSpec, ParameterSpec
|
||||
|
||||
from .... provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE
|
||||
from .... flow_version import __version__ as COMPONENT_VERSION
|
||||
|
||||
DEFINITION_VALUE = Term(type=IRI, iri=DEFINITION)
|
||||
RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL)
|
||||
SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF)
|
||||
|
||||
default_ident = "kg-extract-definitions"
|
||||
default_concurrency = 1
|
||||
default_triples_batch_size = 50
|
||||
|
|
@ -75,6 +76,10 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
# Optional flow parameters for provenance
|
||||
self.register_specification(ParameterSpec("llm-model"))
|
||||
self.register_specification(ParameterSpec("ontology"))
|
||||
|
||||
def to_uri(self, text):
|
||||
|
||||
part = text.replace(" ", "-").lower().encode("utf-8")
|
||||
|
|
@ -126,12 +131,19 @@ class Processor(FlowProcessor):
|
|||
raise e
|
||||
|
||||
triples = []
|
||||
extracted_triples = []
|
||||
entities = []
|
||||
|
||||
# FIXME: Putting metadata into triples store is duplicated in
|
||||
# relationships extractor too
|
||||
for t in v.metadata.metadata:
|
||||
triples.append(t)
|
||||
# Get chunk document ID for provenance linking
|
||||
chunk_doc_id = v.document_id if v.document_id else v.metadata.id
|
||||
chunk_uri = v.metadata.id # The URI form for the chunk
|
||||
|
||||
# Get optional provenance parameters
|
||||
llm_model = flow("llm-model")
|
||||
ontology_uri = flow("ontology")
|
||||
|
||||
# Note: Document metadata is now emitted once by librarian at processing
|
||||
# initiation, so we don't need to duplicate it here.
|
||||
|
||||
for defn in defs:
|
||||
|
||||
|
|
@ -155,28 +167,43 @@ class Processor(FlowProcessor):
|
|||
o=Term(type=LITERAL, value=s),
|
||||
))
|
||||
|
||||
triples.append(Triple(
|
||||
# The definition triple - this is the main extracted fact
|
||||
definition_triple = Triple(
|
||||
s=s_value, p=DEFINITION_VALUE, o=o_value
|
||||
))
|
||||
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=SUBJECT_OF_VALUE,
|
||||
o=Term(type=IRI, iri=v.metadata.id)
|
||||
))
|
||||
)
|
||||
triples.append(definition_triple)
|
||||
extracted_triples.append(definition_triple)
|
||||
|
||||
# Output entity name as context for direct name matching
|
||||
# Include chunk_id for embedding provenance
|
||||
entities.append(EntityContext(
|
||||
entity=s_value,
|
||||
context=s,
|
||||
chunk_id=chunk_doc_id,
|
||||
))
|
||||
|
||||
# Output definition as context for semantic matching
|
||||
# Include chunk_id for embedding provenance
|
||||
entities.append(EntityContext(
|
||||
entity=s_value,
|
||||
context=defn["definition"],
|
||||
chunk_id=chunk_doc_id,
|
||||
))
|
||||
|
||||
# Generate subgraph provenance once for all extracted triples
|
||||
if extracted_triples:
|
||||
sg_uri = subgraph_uri()
|
||||
prov_triples = subgraph_provenance_triples(
|
||||
subgraph_uri=sg_uri,
|
||||
extracted_triples=extracted_triples,
|
||||
chunk_uri=chunk_uri,
|
||||
component_name=default_ident,
|
||||
component_version=COMPONENT_VERSION,
|
||||
llm_model=llm_model,
|
||||
ontology_uri=ontology_uri,
|
||||
)
|
||||
triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
|
||||
|
||||
# Send triples in batches
|
||||
for i in range(0, len(triples), self.triples_batch_size):
|
||||
batch = triples[i:i + self.triples_batch_size]
|
||||
|
|
@ -184,7 +211,7 @@ class Processor(FlowProcessor):
|
|||
flow("triples"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
|
|
@ -198,7 +225,7 @@ class Processor(FlowProcessor):
|
|||
flow("entity-contexts"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ from .ontology_selector import OntologySelector, OntologySubset
|
|||
from .simplified_parser import parse_extraction_response
|
||||
from .triple_converter import TripleConverter
|
||||
|
||||
from .... provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE
|
||||
from .... flow_version import __version__ as COMPONENT_VERSION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "kg-extract-ontology"
|
||||
|
|
@ -148,8 +151,8 @@ class Processor(FlowProcessor):
|
|||
|
||||
# Detect embedding dimension by embedding a test string
|
||||
logger.info("Detecting embedding dimension from embeddings service...")
|
||||
test_embedding_response = await embeddings_client.embed("test")
|
||||
test_embedding = test_embedding_response[0] # Extract from [[vector]]
|
||||
test_embedding_response = await embeddings_client.embed(["test"])
|
||||
test_embedding = test_embedding_response[0] # Extract first vector
|
||||
dimension = len(test_embedding)
|
||||
logger.info(f"Detected embedding dimension: {dimension}")
|
||||
|
||||
|
|
@ -306,15 +309,25 @@ class Processor(FlowProcessor):
|
|||
flow, chunk, ontology_subset, prompt_variables
|
||||
)
|
||||
|
||||
# Add metadata triples
|
||||
for t in v.metadata.metadata:
|
||||
triples.append(t)
|
||||
# Generate subgraph provenance for extracted triples
|
||||
if triples:
|
||||
chunk_uri = v.metadata.id
|
||||
sg_uri = subgraph_uri()
|
||||
prov_triples = subgraph_provenance_triples(
|
||||
subgraph_uri=sg_uri,
|
||||
extracted_triples=triples,
|
||||
chunk_uri=chunk_uri,
|
||||
component_name=default_ident,
|
||||
component_version=COMPONENT_VERSION,
|
||||
)
|
||||
|
||||
# Generate ontology definition triples
|
||||
ontology_triples = self.build_ontology_triples(ontology_subset)
|
||||
|
||||
# Combine extracted triples with ontology triples
|
||||
# Combine extracted triples with ontology triples and provenance
|
||||
all_triples = triples + ontology_triples
|
||||
if triples:
|
||||
all_triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
|
||||
|
||||
# Build entity contexts from all triples (including ontology elements)
|
||||
entity_contexts = self.build_entity_contexts(all_triples)
|
||||
|
|
@ -558,7 +571,7 @@ class Processor(FlowProcessor):
|
|||
t = Triples(
|
||||
metadata=Metadata(
|
||||
id=metadata.id,
|
||||
metadata=[],
|
||||
root=metadata.root,
|
||||
user=metadata.user,
|
||||
collection=metadata.collection,
|
||||
),
|
||||
|
|
@ -571,7 +584,7 @@ class Processor(FlowProcessor):
|
|||
ec = EntityContexts(
|
||||
metadata=Metadata(
|
||||
id=metadata.id,
|
||||
metadata=[],
|
||||
root=metadata.root,
|
||||
user=metadata.user,
|
||||
collection=metadata.collection,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -153,16 +153,11 @@ class OntologyEmbedder:
|
|||
# Get embeddings for batch
|
||||
texts = [elem['text'] for elem in batch]
|
||||
try:
|
||||
# Call embedding service for each text
|
||||
# Note: embed() returns 2D array [[vector]], so extract first element
|
||||
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
|
||||
embeddings_responses = await asyncio.gather(*embedding_tasks)
|
||||
|
||||
# Extract vectors from responses (each is [[vector]])
|
||||
embeddings_list = [resp[0] for resp in embeddings_responses]
|
||||
# Single batch embedding call - returns list of vectors
|
||||
embeddings_response = await self.embedding_service.embed(texts)
|
||||
|
||||
# Convert to numpy array
|
||||
embeddings = np.array(embeddings_list)
|
||||
embeddings = np.array(embeddings_response)
|
||||
|
||||
# Log embedding shape for debugging
|
||||
logger.debug(f"Embeddings shape: {embeddings.shape}, expected: ({len(batch)}, {self.vector_store.dimension})")
|
||||
|
|
@ -218,8 +213,8 @@ class OntologyEmbedder:
|
|||
return None
|
||||
|
||||
try:
|
||||
# embed() returns 2D array [[vector]], extract first element
|
||||
embedding_response = await self.embedding_service.embed(text)
|
||||
# embed() with single text, extract first vector
|
||||
embedding_response = await self.embedding_service.embed([text])
|
||||
return np.array(embedding_response[0])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed text: {e}")
|
||||
|
|
@ -239,12 +234,9 @@ class OntologyEmbedder:
|
|||
return None
|
||||
|
||||
try:
|
||||
# Call embed() for each text (returns [[vector]] per call)
|
||||
embedding_tasks = [self.embedding_service.embed(text) for text in texts]
|
||||
embeddings_responses = await asyncio.gather(*embedding_tasks)
|
||||
# Extract first vector from each response
|
||||
embeddings_list = [resp[0] for resp in embeddings_responses]
|
||||
return np.array(embeddings_list)
|
||||
# Single batch embedding call - returns list of vectors
|
||||
embeddings_response = await self.embedding_service.embed(texts)
|
||||
return np.array(embeddings_response)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed texts: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -15,13 +15,15 @@ logger = logging.getLogger(__name__)
|
|||
from .... schema import Chunk, Triple, Triples
|
||||
from .... schema import Metadata, Term, IRI, LITERAL
|
||||
from .... schema import PromptRequest, PromptResponse
|
||||
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF
|
||||
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
|
||||
|
||||
from .... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from .... base import PromptClientSpec
|
||||
from .... base import PromptClientSpec, ParameterSpec
|
||||
|
||||
from .... provenance import subgraph_uri, subgraph_provenance_triples, set_graph, GRAPH_SOURCE
|
||||
from .... flow_version import __version__ as COMPONENT_VERSION
|
||||
|
||||
RDF_LABEL_VALUE = Term(type=IRI, iri=RDF_LABEL)
|
||||
SUBJECT_OF_VALUE = Term(type=IRI, iri=SUBJECT_OF)
|
||||
|
||||
default_ident = "kg-extract-relationships"
|
||||
default_concurrency = 1
|
||||
|
|
@ -65,6 +67,10 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
# Optional flow parameters for provenance
|
||||
self.register_specification(ParameterSpec("llm-model"))
|
||||
self.register_specification(ParameterSpec("ontology"))
|
||||
|
||||
def to_uri(self, text):
|
||||
|
||||
part = text.replace(" ", "-").lower().encode("utf-8")
|
||||
|
|
@ -108,11 +114,18 @@ class Processor(FlowProcessor):
|
|||
raise e
|
||||
|
||||
triples = []
|
||||
extracted_triples = []
|
||||
|
||||
# FIXME: Putting metadata into triples store is duplicated in
|
||||
# relationships extractor too
|
||||
for t in v.metadata.metadata:
|
||||
triples.append(t)
|
||||
# Get chunk document ID for provenance linking
|
||||
chunk_doc_id = v.document_id if v.document_id else v.metadata.id
|
||||
chunk_uri = v.metadata.id # The URI form for the chunk
|
||||
|
||||
# Get optional provenance parameters
|
||||
llm_model = flow("llm-model")
|
||||
ontology_uri = flow("ontology")
|
||||
|
||||
# Note: Document metadata is now emitted once by librarian at processing
|
||||
# initiation, so we don't need to duplicate it here.
|
||||
|
||||
for rel in rels:
|
||||
|
||||
|
|
@ -140,11 +153,14 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
o_value = Term(type=LITERAL, value=str(o))
|
||||
|
||||
triples.append(Triple(
|
||||
# The relationship triple - this is the main extracted fact
|
||||
relationship_triple = Triple(
|
||||
s=s_value,
|
||||
p=p_value,
|
||||
o=o_value
|
||||
))
|
||||
)
|
||||
triples.append(relationship_triple)
|
||||
extracted_triples.append(relationship_triple)
|
||||
|
||||
# Label for s
|
||||
triples.append(Triple(
|
||||
|
|
@ -168,20 +184,19 @@ class Processor(FlowProcessor):
|
|||
o=Term(type=LITERAL, value=str(o))
|
||||
))
|
||||
|
||||
# 'Subject of' for s
|
||||
triples.append(Triple(
|
||||
s=s_value,
|
||||
p=SUBJECT_OF_VALUE,
|
||||
o=Term(type=IRI, iri=v.metadata.id)
|
||||
))
|
||||
|
||||
if rel["object-entity"]:
|
||||
# 'Subject of' for o
|
||||
triples.append(Triple(
|
||||
s=o_value,
|
||||
p=SUBJECT_OF_VALUE,
|
||||
o=Term(type=IRI, iri=v.metadata.id)
|
||||
))
|
||||
# Generate subgraph provenance once for all extracted triples
|
||||
if extracted_triples:
|
||||
sg_uri = subgraph_uri()
|
||||
prov_triples = subgraph_provenance_triples(
|
||||
subgraph_uri=sg_uri,
|
||||
extracted_triples=extracted_triples,
|
||||
chunk_uri=chunk_uri,
|
||||
component_name=default_ident,
|
||||
component_version=COMPONENT_VERSION,
|
||||
llm_model=llm_model,
|
||||
ontology_uri=ontology_uri,
|
||||
)
|
||||
triples.extend(set_graph(prov_triples, GRAPH_SOURCE))
|
||||
|
||||
# Send triples in batches
|
||||
for i in range(0, len(triples), self.triples_batch_size):
|
||||
|
|
@ -190,7 +205,7 @@ class Processor(FlowProcessor):
|
|||
flow("triples"),
|
||||
Metadata(
|
||||
id=v.metadata.id,
|
||||
metadata=[],
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -272,7 +272,7 @@ class Processor(FlowProcessor):
|
|||
extracted = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id=f"{v.metadata.id}:{schema_name}",
|
||||
metadata=[],
|
||||
root=v.metadata.root,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from . librarian import LibrarianRequestor
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DocumentStreamExport:
|
||||
|
||||
def __init__(self, backend):
|
||||
self.backend = backend
|
||||
|
||||
async def process(self, data, error, ok, request):
|
||||
|
||||
user = request.query.get("user")
|
||||
document_id = request.query.get("document-id")
|
||||
chunk_size = int(request.query.get("chunk-size", 1024 * 1024))
|
||||
|
||||
if not user or not document_id:
|
||||
return await error("Missing required parameters: user, document-id")
|
||||
|
||||
response = await ok()
|
||||
|
||||
lr = LibrarianRequestor(
|
||||
backend=self.backend,
|
||||
consumer="api-gateway-doc-stream-" + str(uuid.uuid4()),
|
||||
subscriber="api-gateway-doc-stream-" + str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
await lr.start()
|
||||
|
||||
async def responder(resp, fin):
|
||||
if "content" in resp:
|
||||
content = resp["content"]
|
||||
# Content is base64 encoded, write as-is for client to decode
|
||||
# Or decode here and write raw bytes
|
||||
import base64
|
||||
chunk_data = base64.b64decode(content)
|
||||
await response.write(chunk_data)
|
||||
|
||||
await lr.process(
|
||||
{
|
||||
"operation": "stream-document",
|
||||
"user": user,
|
||||
"document-id": document_id,
|
||||
"chunk-size": chunk_size,
|
||||
},
|
||||
responder
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"Document stream exception: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
|
||||
await lr.stop()
|
||||
|
||||
await response.write_eof()
|
||||
|
||||
return response
|
||||
|
|
@ -45,6 +45,7 @@ from . rows_import import RowsImport
|
|||
|
||||
from . core_export import CoreExport
|
||||
from . core_import import CoreImport
|
||||
from . document_stream import DocumentStreamExport
|
||||
|
||||
from . mux import Mux
|
||||
|
||||
|
|
@ -135,6 +136,14 @@ class DispatcherManager:
|
|||
def dispatch_core_import(self):
|
||||
return DispatcherWrapper(self.process_core_import)
|
||||
|
||||
def dispatch_document_stream(self):
|
||||
return DispatcherWrapper(self.process_document_stream)
|
||||
|
||||
async def process_document_stream(self, data, error, ok, request):
|
||||
|
||||
ds = DocumentStreamExport(self.backend)
|
||||
return await ds.process(data, error, ok, request)
|
||||
|
||||
async def process_core_import(self, data, error, ok, request):
|
||||
|
||||
ci = CoreImport(self.backend)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,6 @@ class RowsImport:
|
|||
elt = ExtractedObject(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"].get("metadata", [])),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
|
|
|
|||
|
|
@ -37,35 +37,37 @@ def serialize_triples(message):
|
|||
return {
|
||||
"metadata": {
|
||||
"id": message.metadata.id,
|
||||
"metadata": serialize_subgraph(message.metadata.metadata),
|
||||
"root": message.metadata.root,
|
||||
"user": message.metadata.user,
|
||||
"collection": message.metadata.collection,
|
||||
},
|
||||
"triples": serialize_subgraph(message.triples),
|
||||
}
|
||||
|
||||
|
||||
|
||||
def serialize_graph_embeddings(message):
|
||||
return {
|
||||
"metadata": {
|
||||
"id": message.metadata.id,
|
||||
"metadata": serialize_subgraph(message.metadata.metadata),
|
||||
"root": message.metadata.root,
|
||||
"user": message.metadata.user,
|
||||
"collection": message.metadata.collection,
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"vectors": entity.vectors,
|
||||
"vector": entity.vector,
|
||||
"entity": serialize_value(entity.entity),
|
||||
}
|
||||
for entity in message.entities
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def serialize_entity_contexts(message):
|
||||
return {
|
||||
"metadata": {
|
||||
"id": message.metadata.id,
|
||||
"metadata": serialize_subgraph(message.metadata.metadata),
|
||||
"root": message.metadata.root,
|
||||
"user": message.metadata.user,
|
||||
"collection": message.metadata.collection,
|
||||
},
|
||||
|
|
@ -78,18 +80,19 @@ def serialize_entity_contexts(message):
|
|||
],
|
||||
}
|
||||
|
||||
|
||||
def serialize_document_embeddings(message):
|
||||
return {
|
||||
"metadata": {
|
||||
"id": message.metadata.id,
|
||||
"metadata": serialize_subgraph(message.metadata.metadata),
|
||||
"root": message.metadata.root,
|
||||
"user": message.metadata.user,
|
||||
"collection": message.metadata.collection,
|
||||
},
|
||||
"chunks": [
|
||||
{
|
||||
"vectors": chunk.vectors,
|
||||
"chunk": chunk.chunk.decode("utf-8"),
|
||||
"vector": chunk.vector,
|
||||
"chunk_id": chunk.chunk_id,
|
||||
}
|
||||
for chunk in message.chunks
|
||||
],
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ class TriplesImport:
|
|||
elt = Triples(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
root=data["metadata"].get("root", ""),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
|
|
|
|||
|
|
@ -64,6 +64,12 @@ class EndpointManager:
|
|||
method = "GET",
|
||||
dispatcher = dispatcher_manager.dispatch_core_export(),
|
||||
),
|
||||
StreamEndpoint(
|
||||
endpoint_path = "/api/v1/document-stream",
|
||||
auth = auth,
|
||||
method = "GET",
|
||||
dispatcher = dispatcher_manager.dispatch_document_stream(),
|
||||
),
|
||||
]
|
||||
|
||||
def add_routes(self, app):
|
||||
|
|
|
|||
|
|
@ -3,9 +3,12 @@ from .. knowledge import hash
|
|||
from .. exceptions import RequestError
|
||||
|
||||
from minio import Minio
|
||||
from minio.datatypes import Part
|
||||
import time
|
||||
import io
|
||||
import logging
|
||||
from typing import Iterator, List, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -78,3 +81,163 @@ class BlobStore:
|
|||
|
||||
return resp.read()
|
||||
|
||||
async def get_range(self, object_id, offset: int, length: int) -> bytes:
|
||||
"""Fetch a specific byte range from an object."""
|
||||
resp = self.client.get_object(
|
||||
bucket_name=self.bucket_name,
|
||||
object_name="doc/" + str(object_id),
|
||||
offset=offset,
|
||||
length=length,
|
||||
)
|
||||
try:
|
||||
return resp.read()
|
||||
finally:
|
||||
resp.close()
|
||||
resp.release_conn()
|
||||
|
||||
async def get_size(self, object_id) -> int:
|
||||
"""Get the size of an object without downloading it."""
|
||||
stat = self.client.stat_object(
|
||||
bucket_name=self.bucket_name,
|
||||
object_name="doc/" + str(object_id),
|
||||
)
|
||||
return stat.size
|
||||
|
||||
def get_stream(self, object_id, chunk_size: int = 1024 * 1024) -> Iterator[bytes]:
|
||||
"""
|
||||
Stream document content in chunks.
|
||||
|
||||
Yields chunks of the document, allowing processing without loading
|
||||
the entire document into memory.
|
||||
|
||||
Args:
|
||||
object_id: The UUID of the document object
|
||||
chunk_size: Size of each chunk in bytes (default 1MB)
|
||||
|
||||
Yields:
|
||||
Chunks of document content as bytes
|
||||
"""
|
||||
resp = self.client.get_object(
|
||||
bucket_name=self.bucket_name,
|
||||
object_name="doc/" + str(object_id),
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
chunk = resp.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
finally:
|
||||
resp.close()
|
||||
resp.release_conn()
|
||||
|
||||
logger.debug("Stream complete")
|
||||
|
||||
def create_multipart_upload(self, object_id: UUID, kind: str) -> str:
|
||||
"""
|
||||
Initialize a multipart upload.
|
||||
|
||||
Args:
|
||||
object_id: The UUID for the new object
|
||||
kind: MIME type of the document
|
||||
|
||||
Returns:
|
||||
The S3 upload_id for this multipart upload session
|
||||
"""
|
||||
object_name = "doc/" + str(object_id)
|
||||
|
||||
# Use minio's internal method to create multipart upload
|
||||
upload_id = self.client._create_multipart_upload(
|
||||
bucket_name=self.bucket_name,
|
||||
object_name=object_name,
|
||||
headers={"Content-Type": kind},
|
||||
)
|
||||
|
||||
logger.info(f"Created multipart upload {upload_id} for {object_id}")
|
||||
return upload_id
|
||||
|
||||
def upload_part(
|
||||
self,
|
||||
object_id: UUID,
|
||||
upload_id: str,
|
||||
part_number: int,
|
||||
data: bytes
|
||||
) -> str:
|
||||
"""
|
||||
Upload a single part of a multipart upload.
|
||||
|
||||
Args:
|
||||
object_id: The UUID of the object being uploaded
|
||||
upload_id: The S3 upload_id from create_multipart_upload
|
||||
part_number: Part number (1-indexed, as per S3 spec)
|
||||
data: The chunk data to upload
|
||||
|
||||
Returns:
|
||||
The ETag for this part (needed for complete_multipart_upload)
|
||||
"""
|
||||
object_name = "doc/" + str(object_id)
|
||||
|
||||
etag = self.client._upload_part(
|
||||
bucket_name=self.bucket_name,
|
||||
object_name=object_name,
|
||||
data=data,
|
||||
headers={"Content-Length": str(len(data))},
|
||||
upload_id=upload_id,
|
||||
part_number=part_number,
|
||||
)
|
||||
|
||||
logger.debug(f"Uploaded part {part_number} for {object_id}, etag={etag}")
|
||||
return etag
|
||||
|
||||
def complete_multipart_upload(
|
||||
self,
|
||||
object_id: UUID,
|
||||
upload_id: str,
|
||||
parts: List[Tuple[int, str]]
|
||||
) -> None:
|
||||
"""
|
||||
Complete a multipart upload, assembling all parts into the final object.
|
||||
|
||||
S3 coalesces the parts server-side - no data transfer through this client.
|
||||
|
||||
Args:
|
||||
object_id: The UUID of the object
|
||||
upload_id: The S3 upload_id from create_multipart_upload
|
||||
parts: List of (part_number, etag) tuples in order
|
||||
"""
|
||||
object_name = "doc/" + str(object_id)
|
||||
|
||||
# Convert to Part objects as expected by minio
|
||||
part_objects = [
|
||||
Part(part_number, etag)
|
||||
for part_number, etag in parts
|
||||
]
|
||||
|
||||
self.client._complete_multipart_upload(
|
||||
bucket_name=self.bucket_name,
|
||||
object_name=object_name,
|
||||
upload_id=upload_id,
|
||||
parts=part_objects,
|
||||
)
|
||||
|
||||
logger.info(f"Completed multipart upload for {object_id}")
|
||||
|
||||
def abort_multipart_upload(self, object_id: UUID, upload_id: str) -> None:
|
||||
"""
|
||||
Abort a multipart upload, cleaning up any uploaded parts.
|
||||
|
||||
Args:
|
||||
object_id: The UUID of the object
|
||||
upload_id: The S3 upload_id from create_multipart_upload
|
||||
"""
|
||||
object_name = "doc/" + str(object_id)
|
||||
|
||||
self.client._abort_multipart_upload(
|
||||
bucket_name=self.bucket_name,
|
||||
object_name=object_name,
|
||||
upload_id=upload_id,
|
||||
)
|
||||
|
||||
logger.info(f"Aborted multipart upload {upload_id} for {object_id}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,24 @@
|
|||
|
||||
from .. schema import LibrarianRequest, LibrarianResponse, Error, Triple
|
||||
from .. schema import UploadSession
|
||||
from .. knowledge import hash
|
||||
from .. exceptions import RequestError
|
||||
from .. tables.library import LibraryTableStore
|
||||
from . blob_store import BlobStore
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
|
||||
import uuid
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default chunk size for multipart uploads
|
||||
DEFAULT_CHUNK_SIZE = 2 * 1024 * 1024 # 2MB default
|
||||
|
||||
class Librarian:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -20,6 +27,7 @@ class Librarian:
|
|||
object_store_endpoint, object_store_access_key, object_store_secret_key,
|
||||
bucket_name, keyspace, load_document,
|
||||
object_store_use_ssl=False, object_store_region=None,
|
||||
min_chunk_size=1, # Default: no minimum (for Garage)
|
||||
):
|
||||
|
||||
self.blob_store = BlobStore(
|
||||
|
|
@ -32,6 +40,7 @@ class Librarian:
|
|||
)
|
||||
|
||||
self.load_document = load_document
|
||||
self.min_chunk_size = min_chunk_size
|
||||
|
||||
async def add_document(self, request):
|
||||
|
||||
|
|
@ -66,13 +75,7 @@ class Librarian:
|
|||
|
||||
logger.debug("Add complete")
|
||||
|
||||
return LibrarianResponse(
|
||||
error = None,
|
||||
document_metadata = None,
|
||||
content = None,
|
||||
document_metadatas = None,
|
||||
processing_metadatas = None,
|
||||
)
|
||||
return LibrarianResponse()
|
||||
|
||||
async def remove_document(self, request):
|
||||
|
||||
|
|
@ -84,6 +87,21 @@ class Librarian:
|
|||
):
|
||||
raise RuntimeError("Document does not exist")
|
||||
|
||||
# First, cascade delete all child documents
|
||||
children = await self.table_store.list_children(request.document_id)
|
||||
for child in children:
|
||||
logger.debug(f"Cascade deleting child document {child.id}")
|
||||
try:
|
||||
child_object_id = await self.table_store.get_document_object_id(
|
||||
child.user,
|
||||
child.id
|
||||
)
|
||||
await self.blob_store.remove(child_object_id)
|
||||
await self.table_store.remove_document(child.user, child.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete child document {child.id}: {e}")
|
||||
|
||||
# Now remove the parent document
|
||||
object_id = await self.table_store.get_document_object_id(
|
||||
request.user,
|
||||
request.document_id
|
||||
|
|
@ -100,13 +118,7 @@ class Librarian:
|
|||
|
||||
logger.debug("Remove complete")
|
||||
|
||||
return LibrarianResponse(
|
||||
error = None,
|
||||
document_metadata = None,
|
||||
content = None,
|
||||
document_metadatas = None,
|
||||
processing_metadatas = None,
|
||||
)
|
||||
return LibrarianResponse()
|
||||
|
||||
async def update_document(self, request):
|
||||
|
||||
|
|
@ -124,13 +136,7 @@ class Librarian:
|
|||
|
||||
logger.debug("Update complete")
|
||||
|
||||
return LibrarianResponse(
|
||||
error = None,
|
||||
document_metadata = None,
|
||||
content = None,
|
||||
document_metadatas = None,
|
||||
processing_metadatas = None,
|
||||
)
|
||||
return LibrarianResponse()
|
||||
|
||||
async def get_document_metadata(self, request):
|
||||
|
||||
|
|
@ -147,8 +153,6 @@ class Librarian:
|
|||
error = None,
|
||||
document_metadata = doc,
|
||||
content = None,
|
||||
document_metadatas = None,
|
||||
processing_metadatas = None,
|
||||
)
|
||||
|
||||
async def get_document_content(self, request):
|
||||
|
|
@ -170,8 +174,6 @@ class Librarian:
|
|||
error = None,
|
||||
document_metadata = None,
|
||||
content = base64.b64encode(content),
|
||||
document_metadatas = None,
|
||||
processing_metadatas = None,
|
||||
)
|
||||
|
||||
async def add_processing(self, request):
|
||||
|
|
@ -217,13 +219,7 @@ class Librarian:
|
|||
|
||||
logger.debug("Add complete")
|
||||
|
||||
return LibrarianResponse(
|
||||
error = None,
|
||||
document_metadata = None,
|
||||
content = None,
|
||||
document_metadatas = None,
|
||||
processing_metadatas = None,
|
||||
)
|
||||
return LibrarianResponse()
|
||||
|
||||
async def remove_processing(self, request):
|
||||
|
||||
|
|
@ -243,24 +239,23 @@ class Librarian:
|
|||
|
||||
logger.debug("Remove complete")
|
||||
|
||||
return LibrarianResponse(
|
||||
error = None,
|
||||
document_metadata = None,
|
||||
content = None,
|
||||
document_metadatas = None,
|
||||
processing_metadatas = None,
|
||||
)
|
||||
return LibrarianResponse()
|
||||
|
||||
async def list_documents(self, request):
|
||||
|
||||
docs = await self.table_store.list_documents(request.user)
|
||||
|
||||
# Filter out child documents and answer documents by default
|
||||
include_children = getattr(request, 'include_children', False)
|
||||
if not include_children:
|
||||
docs = [
|
||||
doc for doc in docs
|
||||
if not doc.parent_id # Only include top-level documents
|
||||
and doc.document_type != "answer" # Exclude GraphRAG answers
|
||||
]
|
||||
|
||||
return LibrarianResponse(
|
||||
error = None,
|
||||
document_metadata = None,
|
||||
content = None,
|
||||
document_metadatas = docs,
|
||||
processing_metadatas = None,
|
||||
)
|
||||
|
||||
async def list_processing(self, request):
|
||||
|
|
@ -268,10 +263,444 @@ class Librarian:
|
|||
procs = await self.table_store.list_processing(request.user)
|
||||
|
||||
return LibrarianResponse(
|
||||
error = None,
|
||||
document_metadata = None,
|
||||
content = None,
|
||||
document_metadatas = None,
|
||||
processing_metadatas = procs,
|
||||
)
|
||||
|
||||
# Chunked upload operations
|
||||
|
||||
async def begin_upload(self, request):
|
||||
"""
|
||||
Initialize a chunked upload session.
|
||||
|
||||
Creates an S3 multipart upload and stores session state in Cassandra.
|
||||
"""
|
||||
logger.info(f"Beginning chunked upload for document {request.document_metadata.id}")
|
||||
|
||||
if request.document_metadata.kind not in ("text/plain", "application/pdf"):
|
||||
raise RequestError(
|
||||
"Invalid document kind: " + request.document_metadata.kind
|
||||
)
|
||||
|
||||
if await self.table_store.document_exists(
|
||||
request.document_metadata.user,
|
||||
request.document_metadata.id
|
||||
):
|
||||
raise RequestError("Document already exists")
|
||||
|
||||
# Validate sizes
|
||||
total_size = request.total_size
|
||||
if total_size <= 0:
|
||||
raise RequestError("total_size must be positive")
|
||||
|
||||
# Use provided chunk size or default
|
||||
chunk_size = request.chunk_size if request.chunk_size > 0 else DEFAULT_CHUNK_SIZE
|
||||
if chunk_size < self.min_chunk_size:
|
||||
raise RequestError(
|
||||
f"Chunk size {chunk_size} is below minimum {self.min_chunk_size}"
|
||||
)
|
||||
|
||||
# Calculate total chunks
|
||||
total_chunks = math.ceil(total_size / chunk_size)
|
||||
|
||||
# Generate IDs
|
||||
upload_id = str(uuid.uuid4())
|
||||
object_id = uuid.uuid4()
|
||||
|
||||
# Create S3 multipart upload
|
||||
s3_upload_id = self.blob_store.create_multipart_upload(
|
||||
object_id, request.document_metadata.kind
|
||||
)
|
||||
|
||||
# Serialize document metadata for storage
|
||||
doc_meta_json = json.dumps({
|
||||
"id": request.document_metadata.id,
|
||||
"time": request.document_metadata.time,
|
||||
"kind": request.document_metadata.kind,
|
||||
"title": request.document_metadata.title,
|
||||
"comments": request.document_metadata.comments,
|
||||
"user": request.document_metadata.user,
|
||||
"tags": request.document_metadata.tags,
|
||||
})
|
||||
|
||||
# Store session in Cassandra
|
||||
await self.table_store.create_upload_session(
|
||||
upload_id=upload_id,
|
||||
user=request.document_metadata.user,
|
||||
document_id=request.document_metadata.id,
|
||||
document_metadata=doc_meta_json,
|
||||
s3_upload_id=s3_upload_id,
|
||||
object_id=object_id,
|
||||
total_size=total_size,
|
||||
chunk_size=chunk_size,
|
||||
total_chunks=total_chunks,
|
||||
)
|
||||
|
||||
logger.info(f"Created upload session {upload_id} with {total_chunks} chunks")
|
||||
|
||||
return LibrarianResponse(
|
||||
error=None,
|
||||
upload_id=upload_id,
|
||||
chunk_size=chunk_size,
|
||||
total_chunks=total_chunks,
|
||||
)
|
||||
|
||||
async def upload_chunk(self, request):
|
||||
"""
|
||||
Upload a single chunk of a document.
|
||||
|
||||
Forwards the chunk to S3 and updates session state.
|
||||
"""
|
||||
logger.debug(f"Uploading chunk {request.chunk_index} for upload {request.upload_id}")
|
||||
|
||||
# Get session
|
||||
session = await self.table_store.get_upload_session(request.upload_id)
|
||||
if session is None:
|
||||
raise RequestError("Upload session not found or expired")
|
||||
|
||||
# Validate ownership
|
||||
if session["user"] != request.user:
|
||||
raise RequestError("Not authorized to upload to this session")
|
||||
|
||||
# Validate chunk index
|
||||
if request.chunk_index < 0 or request.chunk_index >= session["total_chunks"]:
|
||||
raise RequestError(
|
||||
f"Invalid chunk index {request.chunk_index}, "
|
||||
f"must be 0-{session['total_chunks']-1}"
|
||||
)
|
||||
|
||||
# Decode content
|
||||
content = base64.b64decode(request.content)
|
||||
|
||||
# Upload to S3 (part numbers are 1-indexed in S3)
|
||||
part_number = request.chunk_index + 1
|
||||
etag = self.blob_store.upload_part(
|
||||
object_id=session["object_id"],
|
||||
upload_id=session["s3_upload_id"],
|
||||
part_number=part_number,
|
||||
data=content,
|
||||
)
|
||||
|
||||
# Update session with chunk info
|
||||
await self.table_store.update_upload_session_chunk(
|
||||
upload_id=request.upload_id,
|
||||
chunk_index=request.chunk_index,
|
||||
etag=etag,
|
||||
)
|
||||
|
||||
# Calculate progress
|
||||
chunks_received = session["chunks_received"]
|
||||
# Add this chunk if not already present
|
||||
if request.chunk_index not in chunks_received:
|
||||
chunks_received[request.chunk_index] = etag
|
||||
|
||||
num_chunks_received = len(chunks_received) + 1 # +1 for this chunk
|
||||
bytes_received = num_chunks_received * session["chunk_size"]
|
||||
# Adjust for last chunk potentially being smaller
|
||||
if bytes_received > session["total_size"]:
|
||||
bytes_received = session["total_size"]
|
||||
|
||||
logger.debug(f"Chunk {request.chunk_index} uploaded, {num_chunks_received}/{session['total_chunks']} complete")
|
||||
|
||||
return LibrarianResponse(
|
||||
error=None,
|
||||
upload_id=request.upload_id,
|
||||
chunk_index=request.chunk_index,
|
||||
chunks_received=num_chunks_received,
|
||||
total_chunks=session["total_chunks"],
|
||||
bytes_received=bytes_received,
|
||||
total_bytes=session["total_size"],
|
||||
)
|
||||
|
||||
async def complete_upload(self, request):
|
||||
"""
|
||||
Finalize a chunked upload and create the document.
|
||||
|
||||
Completes the S3 multipart upload and creates the document metadata.
|
||||
"""
|
||||
logger.info(f"Completing upload {request.upload_id}")
|
||||
|
||||
# Get session
|
||||
session = await self.table_store.get_upload_session(request.upload_id)
|
||||
if session is None:
|
||||
raise RequestError("Upload session not found or expired")
|
||||
|
||||
# Validate ownership
|
||||
if session["user"] != request.user:
|
||||
raise RequestError("Not authorized to complete this upload")
|
||||
|
||||
# Verify all chunks received
|
||||
chunks_received = session["chunks_received"]
|
||||
if len(chunks_received) != session["total_chunks"]:
|
||||
missing = [
|
||||
i for i in range(session["total_chunks"])
|
||||
if i not in chunks_received
|
||||
]
|
||||
raise RequestError(
|
||||
f"Missing chunks: {missing[:10]}{'...' if len(missing) > 10 else ''}"
|
||||
)
|
||||
|
||||
# Build parts list for S3 (sorted by part number)
|
||||
parts = [
|
||||
(chunk_index + 1, etag) # S3 part numbers are 1-indexed
|
||||
for chunk_index, etag in sorted(chunks_received.items())
|
||||
]
|
||||
|
||||
# Complete S3 multipart upload
|
||||
self.blob_store.complete_multipart_upload(
|
||||
object_id=session["object_id"],
|
||||
upload_id=session["s3_upload_id"],
|
||||
parts=parts,
|
||||
)
|
||||
|
||||
# Parse document metadata from session
|
||||
doc_meta_dict = json.loads(session["document_metadata"])
|
||||
|
||||
# Create DocumentMetadata object
|
||||
from .. schema import DocumentMetadata
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_meta_dict["id"],
|
||||
time=doc_meta_dict.get("time", int(time.time())),
|
||||
kind=doc_meta_dict["kind"],
|
||||
title=doc_meta_dict.get("title", ""),
|
||||
comments=doc_meta_dict.get("comments", ""),
|
||||
user=doc_meta_dict["user"],
|
||||
tags=doc_meta_dict.get("tags", []),
|
||||
metadata=[], # Triples not supported in chunked upload yet
|
||||
)
|
||||
|
||||
# Add document to table
|
||||
await self.table_store.add_document(doc_metadata, session["object_id"])
|
||||
|
||||
# Delete upload session
|
||||
await self.table_store.delete_upload_session(request.upload_id)
|
||||
|
||||
logger.info(f"Upload {request.upload_id} completed, document {doc_metadata.id} created")
|
||||
|
||||
return LibrarianResponse(
|
||||
error=None,
|
||||
document_id=doc_metadata.id,
|
||||
object_id=str(session["object_id"]),
|
||||
)
|
||||
|
||||
async def abort_upload(self, request):
|
||||
"""
|
||||
Cancel a chunked upload and clean up resources.
|
||||
"""
|
||||
logger.info(f"Aborting upload {request.upload_id}")
|
||||
|
||||
# Get session
|
||||
session = await self.table_store.get_upload_session(request.upload_id)
|
||||
if session is None:
|
||||
raise RequestError("Upload session not found or expired")
|
||||
|
||||
# Validate ownership
|
||||
if session["user"] != request.user:
|
||||
raise RequestError("Not authorized to abort this upload")
|
||||
|
||||
# Abort S3 multipart upload
|
||||
self.blob_store.abort_multipart_upload(
|
||||
object_id=session["object_id"],
|
||||
upload_id=session["s3_upload_id"],
|
||||
)
|
||||
|
||||
# Delete session from Cassandra
|
||||
await self.table_store.delete_upload_session(request.upload_id)
|
||||
|
||||
logger.info(f"Upload {request.upload_id} aborted")
|
||||
|
||||
return LibrarianResponse(error=None)
|
||||
|
||||
async def get_upload_status(self, request):
|
||||
"""
|
||||
Get the status of an in-progress upload.
|
||||
"""
|
||||
logger.debug(f"Getting status for upload {request.upload_id}")
|
||||
|
||||
# Get session
|
||||
session = await self.table_store.get_upload_session(request.upload_id)
|
||||
if session is None:
|
||||
return LibrarianResponse(
|
||||
error=None,
|
||||
upload_id=request.upload_id,
|
||||
upload_state="expired",
|
||||
)
|
||||
|
||||
# Validate ownership
|
||||
if session["user"] != request.user:
|
||||
raise RequestError("Not authorized to view this upload")
|
||||
|
||||
chunks_received = session["chunks_received"]
|
||||
received_list = sorted(chunks_received.keys())
|
||||
missing_list = [
|
||||
i for i in range(session["total_chunks"])
|
||||
if i not in chunks_received
|
||||
]
|
||||
|
||||
bytes_received = len(chunks_received) * session["chunk_size"]
|
||||
if bytes_received > session["total_size"]:
|
||||
bytes_received = session["total_size"]
|
||||
|
||||
return LibrarianResponse(
|
||||
error=None,
|
||||
upload_id=request.upload_id,
|
||||
upload_state="in-progress",
|
||||
received_chunks=received_list,
|
||||
missing_chunks=missing_list,
|
||||
chunks_received=len(chunks_received),
|
||||
total_chunks=session["total_chunks"],
|
||||
bytes_received=bytes_received,
|
||||
total_bytes=session["total_size"],
|
||||
)
|
||||
|
||||
async def list_uploads(self, request):
|
||||
"""
|
||||
List all in-progress uploads for a user.
|
||||
"""
|
||||
logger.debug(f"Listing uploads for user {request.user}")
|
||||
|
||||
sessions = await self.table_store.list_upload_sessions(request.user)
|
||||
|
||||
upload_sessions = [
|
||||
UploadSession(
|
||||
upload_id=s["upload_id"],
|
||||
document_id=s["document_id"],
|
||||
document_metadata_json=s.get("document_metadata", ""),
|
||||
total_size=s["total_size"],
|
||||
chunk_size=s["chunk_size"],
|
||||
total_chunks=s["total_chunks"],
|
||||
chunks_received=s["chunks_received"],
|
||||
created_at=str(s.get("created_at", "")),
|
||||
)
|
||||
for s in sessions
|
||||
]
|
||||
|
||||
return LibrarianResponse(
|
||||
error=None,
|
||||
upload_sessions=upload_sessions,
|
||||
)
|
||||
|
||||
# Child document operations
|
||||
|
||||
async def add_child_document(self, request):
|
||||
"""
|
||||
Add a child document linked to a parent document.
|
||||
|
||||
Child documents are typically extracted content (e.g., pages from a PDF).
|
||||
They have a parent_id pointing to the source document and document_type
|
||||
set to "extracted".
|
||||
"""
|
||||
logger.info(f"Adding child document {request.document_metadata.id} "
|
||||
f"for parent {request.document_metadata.parent_id}")
|
||||
|
||||
if not request.document_metadata.parent_id:
|
||||
raise RequestError("parent_id is required for child documents")
|
||||
|
||||
# Verify parent exists
|
||||
if not await self.table_store.document_exists(
|
||||
request.document_metadata.user,
|
||||
request.document_metadata.parent_id
|
||||
):
|
||||
raise RequestError(
|
||||
f"Parent document {request.document_metadata.parent_id} does not exist"
|
||||
)
|
||||
|
||||
if await self.table_store.document_exists(
|
||||
request.document_metadata.user,
|
||||
request.document_metadata.id
|
||||
):
|
||||
raise RequestError("Document already exists")
|
||||
|
||||
# Set document_type if not specified by caller
|
||||
# Valid types: "page", "chunk", or "extracted" (legacy)
|
||||
if not request.document_metadata.document_type or request.document_metadata.document_type == "source":
|
||||
request.document_metadata.document_type = "extracted"
|
||||
|
||||
# Create object ID for blob
|
||||
object_id = uuid.uuid4()
|
||||
|
||||
logger.debug("Adding blob...")
|
||||
|
||||
await self.blob_store.add(
|
||||
object_id, base64.b64decode(request.content),
|
||||
request.document_metadata.kind
|
||||
)
|
||||
|
||||
logger.debug("Adding to table...")
|
||||
|
||||
await self.table_store.add_document(
|
||||
request.document_metadata, object_id
|
||||
)
|
||||
|
||||
logger.debug("Add child document complete")
|
||||
|
||||
return LibrarianResponse(
|
||||
error=None,
|
||||
document_id=request.document_metadata.id,
|
||||
)
|
||||
|
||||
async def list_children(self, request):
|
||||
"""
|
||||
List all child documents for a given parent document.
|
||||
"""
|
||||
logger.debug(f"Listing children for parent {request.document_id}")
|
||||
|
||||
children = await self.table_store.list_children(request.document_id)
|
||||
|
||||
return LibrarianResponse(
|
||||
error=None,
|
||||
document_metadatas=children,
|
||||
)
|
||||
|
||||
async def stream_document(self, request):
|
||||
"""
|
||||
Stream document content in chunks.
|
||||
|
||||
This is an async generator that yields document content in smaller chunks,
|
||||
allowing memory-efficient processing of large documents. Each yielded
|
||||
response includes chunk_index and total_chunks for tracking progress.
|
||||
Completion is determined by chunk_index reaching total_chunks - 1.
|
||||
"""
|
||||
logger.debug(f"Streaming document {request.document_id}")
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1024 * 1024 # 1MB default
|
||||
|
||||
chunk_size = request.chunk_size if request.chunk_size > 0 else DEFAULT_CHUNK_SIZE
|
||||
if chunk_size < self.min_chunk_size:
|
||||
raise RequestError(
|
||||
f"Chunk size {chunk_size} is below minimum {self.min_chunk_size}"
|
||||
)
|
||||
|
||||
object_id = await self.table_store.get_document_object_id(
|
||||
request.user,
|
||||
request.document_id
|
||||
)
|
||||
|
||||
# Get size via stat (no content download)
|
||||
total_size = await self.blob_store.get_size(object_id)
|
||||
total_chunks = math.ceil(total_size / chunk_size)
|
||||
|
||||
# Stream all chunks
|
||||
for chunk_index in range(total_chunks):
|
||||
# Calculate byte range
|
||||
offset = chunk_index * chunk_size
|
||||
length = min(chunk_size, total_size - offset)
|
||||
|
||||
# Fetch only the requested range
|
||||
chunk_content = await self.blob_store.get_range(object_id, offset, length)
|
||||
|
||||
is_last = (chunk_index == total_chunks - 1)
|
||||
|
||||
logger.debug(f"Streaming chunk {chunk_index + 1}/{total_chunks}, "
|
||||
f"bytes {offset}-{offset + length} of {total_size}")
|
||||
|
||||
yield LibrarianResponse(
|
||||
error=None,
|
||||
content=base64.b64encode(chunk_content),
|
||||
chunk_index=chunk_index,
|
||||
chunks_received=chunk_index + 1,
|
||||
total_chunks=total_chunks,
|
||||
bytes_received=offset + length,
|
||||
total_bytes=total_size,
|
||||
is_final=is_last,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -23,9 +23,14 @@ from .. schema import config_request_queue, config_response_queue
|
|||
|
||||
from .. schema import Document, Metadata
|
||||
from .. schema import TextDocument, Metadata
|
||||
from .. schema import Triples
|
||||
|
||||
from .. exceptions import RequestError
|
||||
|
||||
from .. provenance import (
|
||||
document_uri, document_triples, get_vocabulary_triples,
|
||||
)
|
||||
|
||||
from . librarian import Librarian
|
||||
from . collection_manager import CollectionManager
|
||||
|
||||
|
|
@ -47,6 +52,7 @@ default_object_store_secret_key = "object-password"
|
|||
default_object_store_use_ssl = False
|
||||
default_object_store_region = None
|
||||
default_cassandra_host = "cassandra"
|
||||
default_min_chunk_size = 1 # No minimum by default (for Garage)
|
||||
|
||||
bucket_name = "library"
|
||||
|
||||
|
|
@ -100,6 +106,11 @@ class Processor(AsyncProcessor):
|
|||
default_object_store_region
|
||||
)
|
||||
|
||||
min_chunk_size = params.get(
|
||||
"min_chunk_size",
|
||||
default_min_chunk_size
|
||||
)
|
||||
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
cassandra_username = params.get("cassandra_username")
|
||||
cassandra_password = params.get("cassandra_password")
|
||||
|
|
@ -226,6 +237,7 @@ class Processor(AsyncProcessor):
|
|||
load_document = self.load_document,
|
||||
object_store_use_ssl = object_store_use_ssl,
|
||||
object_store_region = object_store_region,
|
||||
min_chunk_size = min_chunk_size,
|
||||
)
|
||||
|
||||
self.collection_manager = CollectionManager(
|
||||
|
|
@ -271,6 +283,70 @@ class Processor(AsyncProcessor):
|
|||
|
||||
pass
|
||||
|
||||
# Threshold for sending document_id instead of inline content (2MB)
|
||||
STREAMING_THRESHOLD = 2 * 1024 * 1024
|
||||
|
||||
async def emit_document_provenance(self, document, processing, triples_queue):
|
||||
"""
|
||||
Emit document provenance metadata to the knowledge graph.
|
||||
|
||||
This emits:
|
||||
1. Vocabulary bootstrap triples (idempotent, safe to re-emit)
|
||||
2. Document metadata as PROV-O triples
|
||||
"""
|
||||
logger.debug(f"Emitting document provenance for {document.id}")
|
||||
|
||||
# Build document URI and provenance triples
|
||||
doc_uri = document_uri(document.id)
|
||||
|
||||
# Get page count for PDFs (if available from document metadata)
|
||||
page_count = None
|
||||
if document.kind == "application/pdf":
|
||||
# Page count might be in document metadata triples
|
||||
# For now, we don't have it at this point - it gets determined during extraction
|
||||
pass
|
||||
|
||||
# Build document metadata triples
|
||||
prov_triples = document_triples(
|
||||
doc_uri=doc_uri,
|
||||
title=document.title if document.title else None,
|
||||
mime_type=document.kind,
|
||||
)
|
||||
|
||||
# Include any existing metadata triples from the document
|
||||
if document.metadata:
|
||||
prov_triples.extend(document.metadata)
|
||||
|
||||
# Get vocabulary bootstrap triples (idempotent)
|
||||
vocab_triples = get_vocabulary_triples()
|
||||
|
||||
# Combine all triples
|
||||
all_triples = vocab_triples + prov_triples
|
||||
|
||||
# Create publisher and emit
|
||||
triples_pub = Publisher(
|
||||
self.pubsub, triples_queue, schema=Triples
|
||||
)
|
||||
|
||||
try:
|
||||
await triples_pub.start()
|
||||
|
||||
triples_msg = Triples(
|
||||
metadata=Metadata(
|
||||
id=doc_uri,
|
||||
root=document.id,
|
||||
user=processing.user,
|
||||
collection=processing.collection,
|
||||
),
|
||||
triples=all_triples,
|
||||
)
|
||||
|
||||
await triples_pub.send(None, triples_msg)
|
||||
logger.debug(f"Emitted {len(all_triples)} provenance triples for {document.id}")
|
||||
|
||||
finally:
|
||||
await triples_pub.stop()
|
||||
|
||||
async def load_document(self, document, processing, content):
|
||||
|
||||
logger.debug("Ready for document processing...")
|
||||
|
|
@ -291,27 +367,64 @@ class Processor(AsyncProcessor):
|
|||
|
||||
q = flow["interfaces"][kind]
|
||||
|
||||
if kind == "text-load":
|
||||
doc = TextDocument(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
metadata = document.metadata,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
text = content,
|
||||
# Emit document provenance to knowledge graph
|
||||
if "triples-store" in flow["interfaces"]:
|
||||
await self.emit_document_provenance(
|
||||
document, processing, flow["interfaces"]["triples-store"]
|
||||
)
|
||||
|
||||
if kind == "text-load":
|
||||
# For large text documents, send document_id for streaming retrieval
|
||||
if len(content) >= self.STREAMING_THRESHOLD:
|
||||
logger.info(f"Text document {document.id} is large ({len(content)} bytes), "
|
||||
f"sending document_id for streaming retrieval")
|
||||
doc = TextDocument(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
root = document.id,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
document_id = document.id,
|
||||
text = b"", # Empty, receiver will fetch via librarian
|
||||
)
|
||||
else:
|
||||
doc = TextDocument(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
root = document.id,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
text = content,
|
||||
)
|
||||
schema = TextDocument
|
||||
else:
|
||||
doc = Document(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
metadata = document.metadata,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
data = base64.b64encode(content).decode("utf-8")
|
||||
)
|
||||
# For large PDF documents, send document_id for streaming retrieval
|
||||
# instead of embedding the entire content in the message
|
||||
if len(content) >= self.STREAMING_THRESHOLD:
|
||||
logger.info(f"Document {document.id} is large ({len(content)} bytes), "
|
||||
f"sending document_id for streaming retrieval")
|
||||
doc = Document(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
root = document.id,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
document_id = document.id,
|
||||
data = b"", # Empty data, receiver will fetch via API
|
||||
)
|
||||
else:
|
||||
doc = Document(
|
||||
metadata = Metadata(
|
||||
id = document.id,
|
||||
root = document.id,
|
||||
user = processing.user,
|
||||
collection = processing.collection
|
||||
),
|
||||
data = base64.b64encode(content).decode("utf-8")
|
||||
)
|
||||
schema = Document
|
||||
|
||||
logger.debug(f"Submitting to queue {q}...")
|
||||
|
|
@ -361,6 +474,17 @@ class Processor(AsyncProcessor):
|
|||
"remove-processing": self.librarian.remove_processing,
|
||||
"list-documents": self.librarian.list_documents,
|
||||
"list-processing": self.librarian.list_processing,
|
||||
# Chunked upload operations
|
||||
"begin-upload": self.librarian.begin_upload,
|
||||
"upload-chunk": self.librarian.upload_chunk,
|
||||
"complete-upload": self.librarian.complete_upload,
|
||||
"abort-upload": self.librarian.abort_upload,
|
||||
"get-upload-status": self.librarian.get_upload_status,
|
||||
"list-uploads": self.librarian.list_uploads,
|
||||
# Child document and streaming operations
|
||||
"add-child-document": self.librarian.add_child_document,
|
||||
"list-children": self.librarian.list_children,
|
||||
"stream-document": self.librarian.stream_document,
|
||||
}
|
||||
|
||||
if v.operation not in impls:
|
||||
|
|
@ -380,6 +504,15 @@ class Processor(AsyncProcessor):
|
|||
|
||||
try:
|
||||
|
||||
# Handle streaming operations specially
|
||||
if v.operation == "stream-document":
|
||||
async for resp in self.librarian.stream_document(v):
|
||||
await self.librarian_response_producer.send(
|
||||
resp, properties={"id": id}
|
||||
)
|
||||
return
|
||||
|
||||
# Non-streaming operations
|
||||
resp = await self.process_request(v)
|
||||
|
||||
await self.librarian_response_producer.send(
|
||||
|
|
@ -393,7 +526,7 @@ class Processor(AsyncProcessor):
|
|||
error = Error(
|
||||
type = "request-error",
|
||||
message = str(e),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
await self.librarian_response_producer.send(
|
||||
|
|
@ -406,7 +539,7 @@ class Processor(AsyncProcessor):
|
|||
error = Error(
|
||||
type = "unexpected-error",
|
||||
message = str(e),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
await self.librarian_response_producer.send(
|
||||
|
|
@ -538,6 +671,14 @@ class Processor(AsyncProcessor):
|
|||
help='Object storage region (optional)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--min-chunk-size',
|
||||
type=int,
|
||||
default=default_min_chunk_size,
|
||||
help=f'Minimum chunk size in bytes for uploads/downloads '
|
||||
f'(default: {default_min_chunk_size})',
|
||||
)
|
||||
|
||||
add_cassandra_args(parser)
|
||||
|
||||
def run():
|
||||
|
|
|
|||
|
|
@ -55,11 +55,13 @@ class Processor(LlmService):
|
|||
self.max_output = max_output
|
||||
self.default_model = model
|
||||
|
||||
def build_prompt(self, system, content, temperature=None, stream=False):
|
||||
def build_prompt(self, system, content, temperature=None, stream=False, model=None):
|
||||
# Use provided temperature or fall back to default
|
||||
effective_temperature = temperature if temperature is not None else self.temperature
|
||||
model_name = model or self.default_model
|
||||
|
||||
data = {
|
||||
"model": model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system", "content": system
|
||||
|
|
@ -100,7 +102,8 @@ class Processor(LlmService):
|
|||
raise TooManyRequests()
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError("LLM failure")
|
||||
logger.error(f"Azure API error: status={resp.status_code}, body={resp.text}")
|
||||
raise RuntimeError(f"LLM failure: HTTP {resp.status_code}")
|
||||
|
||||
result = resp.json()
|
||||
|
||||
|
|
@ -121,7 +124,8 @@ class Processor(LlmService):
|
|||
prompt = self.build_prompt(
|
||||
system,
|
||||
prompt,
|
||||
effective_temperature
|
||||
effective_temperature,
|
||||
model=model_name
|
||||
)
|
||||
|
||||
response = self.call_llm(prompt)
|
||||
|
|
@ -174,7 +178,7 @@ class Processor(LlmService):
|
|||
logger.debug(f"Using temperature: {effective_temperature}")
|
||||
|
||||
try:
|
||||
body = self.build_prompt(system, prompt, effective_temperature, stream=True)
|
||||
body = self.build_prompt(system, prompt, effective_temperature, stream=True, model=model_name)
|
||||
|
||||
url = self.endpoint
|
||||
api_key = self.token
|
||||
|
|
@ -190,7 +194,11 @@ class Processor(LlmService):
|
|||
raise TooManyRequests()
|
||||
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError("LLM failure")
|
||||
logger.error(f"Azure API error: status={response.status_code}, body={response.text}")
|
||||
raise RuntimeError(f"LLM failure: HTTP {response.status_code}")
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
||||
total_input_tokens = 0
|
||||
total_output_tokens = 0
|
||||
|
|
@ -279,6 +287,12 @@ class Processor(LlmService):
|
|||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default=default_model,
|
||||
help=f'LLM model name (default: {default_model})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks
|
||||
of chunk_ids
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from .... direct.milvus_doc_embeddings import DocVectors
|
||||
from .... schema import DocumentEmbeddingsResponse
|
||||
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
||||
from .... schema import Error
|
||||
from .... base import DocumentEmbeddingsQueryService
|
||||
|
||||
|
|
@ -35,24 +35,31 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
|
||||
try:
|
||||
|
||||
vec = msg.vector
|
||||
if not vec:
|
||||
return []
|
||||
|
||||
# Handle zero limit case
|
||||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
resp = self.vecstore.search(
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
limit=msg.limit
|
||||
)
|
||||
|
||||
chunks = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
resp = self.vecstore.search(
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
limit=msg.limit
|
||||
)
|
||||
|
||||
for r in resp:
|
||||
chunk = r["entity"]["doc"]
|
||||
chunks.append(chunk)
|
||||
for r in resp:
|
||||
chunk_id = r["entity"]["chunk_id"]
|
||||
# Milvus returns distance, convert to similarity score
|
||||
distance = r.get("distance", 0.0)
|
||||
score = 1.0 - distance if distance else 0.0
|
||||
chunks.append(ChunkMatch(
|
||||
chunk_id=chunk_id,
|
||||
score=score,
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks. Pinecone implementation.
|
||||
of chunk_ids. Pinecone implementation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -11,6 +11,7 @@ import os
|
|||
from pinecone import Pinecone, ServerlessSpec
|
||||
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
||||
|
||||
from .... schema import ChunkMatch
|
||||
from .... base import DocumentEmbeddingsQueryService
|
||||
|
||||
# Module logger
|
||||
|
|
@ -51,36 +52,41 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
|
||||
try:
|
||||
|
||||
vec = msg.vector
|
||||
if not vec:
|
||||
return []
|
||||
|
||||
# Handle zero limit case
|
||||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
dim = len(vec)
|
||||
|
||||
# Use dimension suffix in index name
|
||||
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
|
||||
|
||||
# Check if index exists - return empty if not
|
||||
if not self.pinecone.has_index(index_name):
|
||||
logger.info(f"Index {index_name} does not exist")
|
||||
return []
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
results = index.query(
|
||||
vector=vec,
|
||||
top_k=msg.limit,
|
||||
include_values=False,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
chunks = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
|
||||
# Use dimension suffix in index name
|
||||
index_name = f"d-{msg.user}-{msg.collection}-{dim}"
|
||||
|
||||
# Check if index exists - skip if not
|
||||
if not self.pinecone.has_index(index_name):
|
||||
logger.info(f"Index {index_name} does not exist, skipping this vector")
|
||||
continue
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
results = index.query(
|
||||
vector=vec,
|
||||
top_k=msg.limit,
|
||||
include_values=False,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
for r in results.matches:
|
||||
doc = r.metadata["doc"]
|
||||
chunks.append(doc)
|
||||
for r in results.matches:
|
||||
chunk_id = r.metadata["chunk_id"]
|
||||
score = r.score if hasattr(r, 'score') else 0.0
|
||||
chunks.append(ChunkMatch(
|
||||
chunk_id=chunk_id,
|
||||
score=score,
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks
|
||||
of chunk_ids
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
|
|||
from qdrant_client.models import PointStruct
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
|
||||
from .... schema import DocumentEmbeddingsResponse
|
||||
from .... schema import DocumentEmbeddingsResponse, ChunkMatch
|
||||
from .... schema import Error
|
||||
from .... base import DocumentEmbeddingsQueryService
|
||||
|
||||
|
|
@ -69,29 +69,34 @@ class Processor(DocumentEmbeddingsQueryService):
|
|||
|
||||
try:
|
||||
|
||||
vec = msg.vector
|
||||
if not vec:
|
||||
return []
|
||||
|
||||
# Use dimension suffix in collection name
|
||||
dim = len(vec)
|
||||
collection = f"d_{msg.user}_{msg.collection}_{dim}"
|
||||
|
||||
# Check if collection exists - return empty if not
|
||||
if not self.collection_exists(collection):
|
||||
logger.info(f"Collection {collection} does not exist, returning empty results")
|
||||
return []
|
||||
|
||||
search_result = self.qdrant.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
limit=msg.limit,
|
||||
with_payload=True,
|
||||
).points
|
||||
|
||||
chunks = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
|
||||
# Use dimension suffix in collection name
|
||||
dim = len(vec)
|
||||
collection = f"d_{msg.user}_{msg.collection}_{dim}"
|
||||
|
||||
# Check if collection exists - return empty if not
|
||||
if not self.collection_exists(collection):
|
||||
logger.info(f"Collection {collection} does not exist, returning empty results")
|
||||
continue
|
||||
|
||||
search_result = self.qdrant.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
limit=msg.limit,
|
||||
with_payload=True,
|
||||
).points
|
||||
|
||||
for r in search_result:
|
||||
ent = r.payload["doc"]
|
||||
chunks.append(ent)
|
||||
for r in search_result:
|
||||
chunk_id = r.payload["chunk_id"]
|
||||
score = r.score if hasattr(r, 'score') else 0.0
|
||||
chunks.append(ChunkMatch(
|
||||
chunk_id=chunk_id,
|
||||
score=score,
|
||||
))
|
||||
|
||||
return chunks
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ entities
|
|||
import logging
|
||||
|
||||
from .... direct.milvus_graph_embeddings import EntityVectors
|
||||
from .... schema import GraphEmbeddingsResponse
|
||||
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||
from .... schema import Error, Term, IRI, LITERAL
|
||||
from .... base import GraphEmbeddingsQueryService
|
||||
|
||||
|
|
@ -41,42 +41,41 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
|
||||
try:
|
||||
|
||||
entity_set = set()
|
||||
entities = []
|
||||
vec = msg.vector
|
||||
if not vec:
|
||||
return []
|
||||
|
||||
# Handle zero limit case
|
||||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
for vec in msg.vectors:
|
||||
resp = self.vecstore.search(
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
limit=msg.limit * 2
|
||||
)
|
||||
|
||||
resp = self.vecstore.search(
|
||||
vec,
|
||||
msg.user,
|
||||
msg.collection,
|
||||
limit=msg.limit * 2
|
||||
)
|
||||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
for r in resp:
|
||||
ent = r["entity"]["entity"]
|
||||
|
||||
# De-dupe entities
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(ent)
|
||||
for r in resp:
|
||||
ent = r["entity"]["entity"]
|
||||
# Milvus returns distance, convert to similarity score
|
||||
distance = r.get("distance", 0.0)
|
||||
score = 1.0 - distance if distance else 0.0
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
# De-dupe entities, keep highest score
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(EntityMatch(
|
||||
entity=self.create_value(ent),
|
||||
score=score,
|
||||
))
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
|
||||
ents2 = []
|
||||
|
||||
for ent in entities:
|
||||
ents2.append(self.create_value(ent))
|
||||
|
||||
entities = ents2
|
||||
if len(entities) >= msg.limit:
|
||||
break
|
||||
|
||||
logger.debug("Send response...")
|
||||
return entities
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import os
|
|||
from pinecone import Pinecone, ServerlessSpec
|
||||
from pinecone.grpc import PineconeGRPC, GRPCClientConfig
|
||||
|
||||
from .... schema import GraphEmbeddingsResponse
|
||||
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||
from .... schema import Error, Term, IRI, LITERAL
|
||||
from .... base import GraphEmbeddingsQueryService
|
||||
|
||||
|
|
@ -59,57 +59,53 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
|
||||
try:
|
||||
|
||||
vec = msg.vector
|
||||
if not vec:
|
||||
return []
|
||||
|
||||
# Handle zero limit case
|
||||
if msg.limit <= 0:
|
||||
return []
|
||||
|
||||
dim = len(vec)
|
||||
|
||||
# Use dimension suffix in index name
|
||||
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
|
||||
|
||||
# Check if index exists - return empty if not
|
||||
if not self.pinecone.has_index(index_name):
|
||||
logger.info(f"Index {index_name} does not exist")
|
||||
return []
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
# Heuristic hack, get (2*limit), so that we have more chance
|
||||
# of getting (limit) unique entities
|
||||
results = index.query(
|
||||
vector=vec,
|
||||
top_k=msg.limit * 2,
|
||||
include_values=False,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
for r in results.matches:
|
||||
ent = r.metadata["entity"]
|
||||
score = r.score if hasattr(r, 'score') else 0.0
|
||||
|
||||
dim = len(vec)
|
||||
|
||||
# Use dimension suffix in index name
|
||||
index_name = f"t-{msg.user}-{msg.collection}-{dim}"
|
||||
|
||||
# Check if index exists - skip if not
|
||||
if not self.pinecone.has_index(index_name):
|
||||
logger.info(f"Index {index_name} does not exist, skipping this vector")
|
||||
continue
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
# Heuristic hack, get (2*limit), so that we have more chance
|
||||
# of getting (limit) entities
|
||||
results = index.query(
|
||||
vector=vec,
|
||||
top_k=msg.limit * 2,
|
||||
include_values=False,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
for r in results.matches:
|
||||
|
||||
ent = r.metadata["entity"]
|
||||
|
||||
# De-dupe entities
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(ent)
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
# De-dupe entities, keep highest score
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(EntityMatch(
|
||||
entity=self.create_value(ent),
|
||||
score=score,
|
||||
))
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
|
||||
ents2 = []
|
||||
|
||||
for ent in entities:
|
||||
ents2.append(self.create_value(ent))
|
||||
|
||||
entities = ents2
|
||||
if len(entities) >= msg.limit:
|
||||
break
|
||||
|
||||
return entities
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from qdrant_client import QdrantClient
|
|||
from qdrant_client.models import PointStruct
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
|
||||
from .... schema import GraphEmbeddingsResponse
|
||||
from .... schema import GraphEmbeddingsResponse, EntityMatch
|
||||
from .... schema import Error, Term, IRI, LITERAL
|
||||
from .... base import GraphEmbeddingsQueryService
|
||||
|
||||
|
|
@ -75,49 +75,46 @@ class Processor(GraphEmbeddingsQueryService):
|
|||
|
||||
try:
|
||||
|
||||
vec = msg.vector
|
||||
if not vec:
|
||||
return []
|
||||
|
||||
# Use dimension suffix in collection name
|
||||
dim = len(vec)
|
||||
collection = f"t_{msg.user}_{msg.collection}_{dim}"
|
||||
|
||||
# Check if collection exists - return empty if not
|
||||
if not self.collection_exists(collection):
|
||||
logger.info(f"Collection {collection} does not exist")
|
||||
return []
|
||||
|
||||
# Heuristic hack, get (2*limit), so that we have more chance
|
||||
# of getting (limit) unique entities
|
||||
search_result = self.qdrant.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
limit=msg.limit * 2,
|
||||
with_payload=True,
|
||||
).points
|
||||
|
||||
entity_set = set()
|
||||
entities = []
|
||||
|
||||
for vec in msg.vectors:
|
||||
for r in search_result:
|
||||
ent = r.payload["entity"]
|
||||
score = r.score if hasattr(r, 'score') else 0.0
|
||||
|
||||
# Use dimension suffix in collection name
|
||||
dim = len(vec)
|
||||
collection = f"t_{msg.user}_{msg.collection}_{dim}"
|
||||
|
||||
# Check if collection exists - return empty if not
|
||||
if not self.collection_exists(collection):
|
||||
logger.info(f"Collection {collection} does not exist, skipping this vector")
|
||||
continue
|
||||
|
||||
# Heuristic hack, get (2*limit), so that we have more chance
|
||||
# of getting (limit) entities
|
||||
search_result = self.qdrant.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
limit=msg.limit * 2,
|
||||
with_payload=True,
|
||||
).points
|
||||
|
||||
for r in search_result:
|
||||
ent = r.payload["entity"]
|
||||
|
||||
# De-dupe entities
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(ent)
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
# De-dupe entities, keep highest score
|
||||
if ent not in entity_set:
|
||||
entity_set.add(ent)
|
||||
entities.append(EntityMatch(
|
||||
entity=self.create_value(ent),
|
||||
score=score,
|
||||
))
|
||||
|
||||
# Keep adding entities until limit
|
||||
if len(entity_set) >= msg.limit: break
|
||||
|
||||
ents2 = []
|
||||
|
||||
for ent in entities:
|
||||
ents2.append(self.create_value(ent))
|
||||
|
||||
entities = ents2
|
||||
if len(entities) >= msg.limit:
|
||||
break
|
||||
|
||||
logger.debug("Send response...")
|
||||
return entities
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
default_ident = "row-embeddings-query"
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
default_concurrency = 10
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
|
@ -31,6 +32,7 @@ class Processor(FlowProcessor):
|
|||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
api_key = params.get("api_key", None)
|
||||
|
|
@ -47,7 +49,8 @@ class Processor(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name="request",
|
||||
schema=RowEmbeddingsRequest,
|
||||
handler=self.on_message
|
||||
handler=self.on_message,
|
||||
concurrency=concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -93,7 +96,9 @@ class Processor(FlowProcessor):
|
|||
async def query_row_embeddings(self, request: RowEmbeddingsRequest):
|
||||
"""Execute row embeddings query"""
|
||||
|
||||
matches = []
|
||||
vec = request.vector
|
||||
if not vec:
|
||||
return []
|
||||
|
||||
# Find the collection for this user/collection/schema
|
||||
qdrant_collection = self.find_collection(
|
||||
|
|
@ -105,47 +110,47 @@ class Processor(FlowProcessor):
|
|||
f"No Qdrant collection found for "
|
||||
f"{request.user}/{request.collection}/{request.schema_name}"
|
||||
)
|
||||
return []
|
||||
|
||||
try:
|
||||
# Build optional filter for index_name
|
||||
query_filter = None
|
||||
if request.index_name:
|
||||
query_filter = Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="index_name",
|
||||
match=MatchValue(value=request.index_name)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Query Qdrant
|
||||
search_result = self.qdrant.query_points(
|
||||
collection_name=qdrant_collection,
|
||||
query=vec,
|
||||
limit=request.limit,
|
||||
with_payload=True,
|
||||
query_filter=query_filter,
|
||||
).points
|
||||
|
||||
# Convert to RowIndexMatch objects
|
||||
matches = []
|
||||
for point in search_result:
|
||||
payload = point.payload or {}
|
||||
match = RowIndexMatch(
|
||||
index_name=payload.get("index_name", ""),
|
||||
index_value=payload.get("index_value", []),
|
||||
text=payload.get("text", ""),
|
||||
score=point.score if hasattr(point, 'score') else 0.0
|
||||
)
|
||||
matches.append(match)
|
||||
|
||||
return matches
|
||||
|
||||
for vec in request.vectors:
|
||||
try:
|
||||
# Build optional filter for index_name
|
||||
query_filter = None
|
||||
if request.index_name:
|
||||
query_filter = Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="index_name",
|
||||
match=MatchValue(value=request.index_name)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Query Qdrant
|
||||
search_result = self.qdrant.query_points(
|
||||
collection_name=qdrant_collection,
|
||||
query=vec,
|
||||
limit=request.limit,
|
||||
with_payload=True,
|
||||
query_filter=query_filter,
|
||||
).points
|
||||
|
||||
# Convert to RowIndexMatch objects
|
||||
for point in search_result:
|
||||
payload = point.payload or {}
|
||||
match = RowIndexMatch(
|
||||
index_name=payload.get("index_name", ""),
|
||||
index_value=payload.get("index_value", []),
|
||||
text=payload.get("text", ""),
|
||||
score=point.score if hasattr(point, 'score') else 0.0
|
||||
)
|
||||
matches.append(match)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
return matches
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to query Qdrant: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def on_message(self, msg, consumer, flow):
|
||||
"""Handle incoming query request"""
|
||||
|
|
@ -203,6 +208,13 @@ class Processor(FlowProcessor):
|
|||
help='API key for Qdrant (default: None)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Number of concurrent requests (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
|
||||
def run():
|
||||
"""Entry point for row-embeddings-query-qdrant command"""
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from ... graphql import GraphQLSchemaBuilder, SortDirection
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "rows-query"
|
||||
default_concurrency = 10
|
||||
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
|
@ -37,6 +38,7 @@ class Processor(FlowProcessor):
|
|||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
concurrency = params.get("concurrency", default_concurrency)
|
||||
|
||||
# Get Cassandra parameters
|
||||
cassandra_host = params.get("cassandra_host")
|
||||
|
|
@ -69,7 +71,8 @@ class Processor(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name="request",
|
||||
schema=RowsQueryRequest,
|
||||
handler=self.on_message
|
||||
handler=self.on_message,
|
||||
concurrency=concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -517,6 +520,13 @@ class Processor(FlowProcessor):
|
|||
help='Configuration type prefix for schemas (default: schema)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Number of concurrent requests (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
|
||||
def run():
|
||||
"""Entry point for rows-query-cassandra command"""
|
||||
|
|
|
|||
|
|
@ -6,11 +6,14 @@ null. Output is a list of quads.
|
|||
|
||||
import logging
|
||||
|
||||
import json
|
||||
from cassandra.query import SimpleStatement
|
||||
|
||||
from .... direct.cassandra_kg import (
|
||||
EntityCentricKnowledgeGraph, GRAPH_WILDCARD, DEFAULT_GRAPH
|
||||
EntityCentricKnowledgeGraph, DEFAULT_GRAPH
|
||||
)
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Term, Triple, IRI, LITERAL
|
||||
from .... schema import Term, Triple, IRI, LITERAL, TRIPLE, BLANK
|
||||
from .... base import TriplesQueryService
|
||||
from .... base.cassandra_config import add_cassandra_args, resolve_cassandra_config
|
||||
|
||||
|
|
@ -20,6 +23,36 @@ logger = logging.getLogger(__name__)
|
|||
default_ident = "triples-query"
|
||||
|
||||
|
||||
def serialize_triple(triple):
|
||||
"""Serialize a Triple object to JSON for querying (must match storage format)."""
|
||||
if triple is None:
|
||||
return None
|
||||
|
||||
def term_to_dict(term):
|
||||
if term is None:
|
||||
return None
|
||||
result = {"type": term.type}
|
||||
if term.type == IRI:
|
||||
result["iri"] = term.iri
|
||||
elif term.type == LITERAL:
|
||||
result["value"] = term.value
|
||||
if term.datatype:
|
||||
result["datatype"] = term.datatype
|
||||
if term.language:
|
||||
result["language"] = term.language
|
||||
elif term.type == BLANK:
|
||||
result["id"] = term.id
|
||||
elif term.type == TRIPLE:
|
||||
result["triple"] = serialize_triple(term.triple)
|
||||
return result
|
||||
|
||||
return json.dumps({
|
||||
"s": term_to_dict(triple.s),
|
||||
"p": term_to_dict(triple.p),
|
||||
"o": term_to_dict(triple.o),
|
||||
})
|
||||
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
|
|
@ -28,42 +61,88 @@ def get_term_value(term):
|
|||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
elif term.type == TRIPLE:
|
||||
# Serialize nested triple to JSON (must match storage format)
|
||||
return serialize_triple(term.triple)
|
||||
else:
|
||||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
||||
|
||||
def create_term(value, otype=None, dtype=None, lang=None):
|
||||
def deserialize_term(term_dict):
|
||||
"""Deserialize a term from JSON structure."""
|
||||
if term_dict is None:
|
||||
return None
|
||||
term_type = term_dict.get("type", "")
|
||||
if term_type == IRI:
|
||||
return Term(type=IRI, iri=term_dict.get("iri", ""))
|
||||
elif term_type == LITERAL:
|
||||
return Term(
|
||||
type=LITERAL,
|
||||
value=term_dict.get("value", ""),
|
||||
datatype=term_dict.get("datatype", ""),
|
||||
language=term_dict.get("language", "")
|
||||
)
|
||||
elif term_type == TRIPLE:
|
||||
# Recursive for nested triples
|
||||
nested = term_dict.get("triple")
|
||||
if nested:
|
||||
return Term(
|
||||
type=TRIPLE,
|
||||
triple=Triple(
|
||||
s=deserialize_term(nested.get("s")),
|
||||
p=deserialize_term(nested.get("p")),
|
||||
o=deserialize_term(nested.get("o")),
|
||||
)
|
||||
)
|
||||
# Fallback
|
||||
return Term(type=LITERAL, value=str(term_dict))
|
||||
|
||||
|
||||
def create_term(value, term_type=None, datatype=None, language=None):
|
||||
"""
|
||||
Create a Term from a string value, optionally using type metadata.
|
||||
|
||||
Args:
|
||||
value: The string value
|
||||
otype: Object type - 'u' (URI), 'l' (literal), 't' (triple)
|
||||
dtype: XSD datatype (for literals)
|
||||
lang: Language tag (for literals)
|
||||
term_type: 'u' (IRI), 'l' (literal), 't' (triple)
|
||||
datatype: XSD datatype for literals
|
||||
language: Language tag for literals
|
||||
|
||||
If otype is provided, uses it to determine Term type.
|
||||
Otherwise falls back to URL detection heuristic.
|
||||
If term_type is provided, uses it to determine Term type.
|
||||
Otherwise falls back to URL detection heuristic for object values.
|
||||
"""
|
||||
if otype is not None:
|
||||
if otype == 'u':
|
||||
return Term(type=IRI, iri=value)
|
||||
elif otype == 'l':
|
||||
return Term(
|
||||
type=LITERAL,
|
||||
value=value,
|
||||
datatype=dtype or "",
|
||||
language=lang or ""
|
||||
)
|
||||
elif otype == 't':
|
||||
# Triple/reification - treat as IRI for now
|
||||
return Term(type=IRI, iri=value)
|
||||
else:
|
||||
# Unknown otype, fall back to heuristic
|
||||
pass
|
||||
if term_type == 'u':
|
||||
return Term(type=IRI, iri=value)
|
||||
elif term_type == 'l':
|
||||
return Term(
|
||||
type=LITERAL,
|
||||
value=value,
|
||||
datatype=datatype or "",
|
||||
language=language or ""
|
||||
)
|
||||
elif term_type == 't':
|
||||
# Triple/reification - parse JSON and create nested Triple
|
||||
try:
|
||||
triple_data = json.loads(value) if isinstance(value, str) else value
|
||||
if isinstance(triple_data, dict):
|
||||
return Term(
|
||||
type=TRIPLE,
|
||||
triple=Triple(
|
||||
s=deserialize_term(triple_data.get("s")),
|
||||
p=deserialize_term(triple_data.get("p")),
|
||||
o=deserialize_term(triple_data.get("o")),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.warning(f"Failed to parse triple JSON: {e}")
|
||||
# Fallback if parsing fails
|
||||
return Term(type=LITERAL, value=str(value))
|
||||
elif term_type is not None:
|
||||
# Unknown term_type, fall back to heuristic
|
||||
pass
|
||||
|
||||
# Heuristic fallback for backwards compatibility
|
||||
# Heuristic fallback for backwards compatibility (object values only)
|
||||
if value.startswith("http://") or value.startswith("https://"):
|
||||
return Term(type=IRI, iri=value)
|
||||
else:
|
||||
|
|
@ -98,28 +177,30 @@ class Processor(TriplesQueryService):
|
|||
self.cassandra_password = password
|
||||
self.table = None
|
||||
|
||||
def ensure_connection(self, user):
|
||||
"""Ensure we have a connection to the correct keyspace."""
|
||||
if user != self.table:
|
||||
KGClass = EntityCentricKnowledgeGraph
|
||||
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=user,
|
||||
username=self.cassandra_username,
|
||||
password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=user,
|
||||
)
|
||||
self.table = user
|
||||
|
||||
async def query_triples(self, query):
|
||||
|
||||
try:
|
||||
|
||||
user = query.user
|
||||
|
||||
if user != self.table:
|
||||
# Use factory function to select implementation
|
||||
KGClass = EntityCentricKnowledgeGraph
|
||||
|
||||
if self.cassandra_username and self.cassandra_password:
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=query.user,
|
||||
username=self.cassandra_username, password=self.cassandra_password
|
||||
)
|
||||
else:
|
||||
self.tg = KGClass(
|
||||
hosts=self.cassandra_host,
|
||||
keyspace=query.user,
|
||||
)
|
||||
self.table = user
|
||||
self.ensure_connection(query.user)
|
||||
|
||||
# Extract values from query
|
||||
s_val = get_term_value(query.s)
|
||||
|
|
@ -127,13 +208,13 @@ class Processor(TriplesQueryService):
|
|||
o_val = get_term_value(query.o)
|
||||
g_val = query.g # Already a string or None
|
||||
|
||||
# Helper to extract object metadata from result row
|
||||
def get_o_metadata(t):
|
||||
"""Extract otype/dtype/lang from result row if available"""
|
||||
otype = getattr(t, 'otype', None)
|
||||
dtype = getattr(t, 'dtype', None)
|
||||
lang = getattr(t, 'lang', None)
|
||||
return otype, dtype, lang
|
||||
def get_object_metadata(row):
|
||||
"""Extract term type metadata from result row"""
|
||||
return (
|
||||
getattr(row, 'otype', None),
|
||||
getattr(row, 'dtype', None),
|
||||
getattr(row, 'lang', None),
|
||||
)
|
||||
|
||||
quads = []
|
||||
|
||||
|
|
@ -148,8 +229,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((s_val, p_val, o_val, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((s_val, p_val, o_val, g, term_type, datatype, language))
|
||||
else:
|
||||
# SP specified
|
||||
resp = self.tg.get_sp(
|
||||
|
|
@ -158,8 +239,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((s_val, p_val, t.o, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((s_val, p_val, t.o, g, term_type, datatype, language))
|
||||
else:
|
||||
if o_val is not None:
|
||||
# SO specified
|
||||
|
|
@ -169,8 +250,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((s_val, t.p, o_val, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((s_val, t.p, o_val, g, term_type, datatype, language))
|
||||
else:
|
||||
# S only
|
||||
resp = self.tg.get_s(
|
||||
|
|
@ -179,8 +260,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((s_val, t.p, t.o, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((s_val, t.p, t.o, g, term_type, datatype, language))
|
||||
else:
|
||||
if p_val is not None:
|
||||
if o_val is not None:
|
||||
|
|
@ -191,8 +272,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((t.s, p_val, o_val, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((t.s, p_val, o_val, g, term_type, datatype, language))
|
||||
else:
|
||||
# P only
|
||||
resp = self.tg.get_p(
|
||||
|
|
@ -201,8 +282,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((t.s, p_val, t.o, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((t.s, p_val, t.o, g, term_type, datatype, language))
|
||||
else:
|
||||
if o_val is not None:
|
||||
# O only
|
||||
|
|
@ -212,8 +293,8 @@ class Processor(TriplesQueryService):
|
|||
)
|
||||
for t in resp:
|
||||
g = t.g if hasattr(t, 'g') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((t.s, t.p, o_val, g, otype, dtype, lang))
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((t.s, t.p, o_val, g, term_type, datatype, language))
|
||||
else:
|
||||
# Nothing specified - get all
|
||||
resp = self.tg.get_all(
|
||||
|
|
@ -223,16 +304,24 @@ class Processor(TriplesQueryService):
|
|||
for t in resp:
|
||||
# Note: quads_by_collection uses 'd' for graph field
|
||||
g = t.d if hasattr(t, 'd') else DEFAULT_GRAPH
|
||||
otype, dtype, lang = get_o_metadata(t)
|
||||
quads.append((t.s, t.p, t.o, g, otype, dtype, lang))
|
||||
# Filter by graph
|
||||
# g_val=None means all graphs (no filter)
|
||||
# g_val="" means default graph only
|
||||
# otherwise filter to specific named graph
|
||||
if g_val is not None:
|
||||
if g != g_val:
|
||||
continue
|
||||
term_type, datatype, language = get_object_metadata(t)
|
||||
quads.append((t.s, t.p, t.o, g, term_type, datatype, language))
|
||||
|
||||
# Convert to Triple objects (with g field)
|
||||
# Use otype/dtype/lang for proper Term reconstruction if available
|
||||
# s and p are always IRIs in RDF
|
||||
# Object uses term_type/datatype/language metadata from database
|
||||
triples = [
|
||||
Triple(
|
||||
s=create_term(q[0]),
|
||||
p=create_term(q[1]),
|
||||
o=create_term(q[2], otype=q[4], dtype=q[5], lang=q[6]),
|
||||
s=create_term(q[0], term_type='u'),
|
||||
p=create_term(q[1], term_type='u'),
|
||||
o=create_term(q[2], term_type=q[4], datatype=q[5], language=q[6]),
|
||||
g=q[3] if q[3] != DEFAULT_GRAPH else None
|
||||
)
|
||||
for q in quads
|
||||
|
|
@ -245,6 +334,104 @@ class Processor(TriplesQueryService):
|
|||
logger.error(f"Exception querying triples: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
async def query_triples_stream(self, query):
|
||||
"""
|
||||
Streaming query - yields (batch, is_final) tuples.
|
||||
Uses Cassandra's paging to fetch results incrementally.
|
||||
"""
|
||||
try:
|
||||
self.ensure_connection(query.user)
|
||||
|
||||
batch_size = query.batch_size if query.batch_size > 0 else 20
|
||||
limit = query.limit if query.limit > 0 else 10000
|
||||
|
||||
# Extract query pattern
|
||||
s_val = get_term_value(query.s)
|
||||
p_val = get_term_value(query.p)
|
||||
o_val = get_term_value(query.o)
|
||||
g_val = query.g
|
||||
|
||||
def get_object_metadata(row):
|
||||
"""Extract term type metadata from result row"""
|
||||
return (
|
||||
getattr(row, 'otype', None),
|
||||
getattr(row, 'dtype', None),
|
||||
getattr(row, 'lang', None),
|
||||
)
|
||||
|
||||
# For streaming, we need to execute with fetch_size
|
||||
# Use the collection table for get_all queries (most common streaming case)
|
||||
|
||||
# Determine which query to use based on pattern
|
||||
if s_val is None and p_val is None and o_val is None:
|
||||
# Get all - use collection table with paging
|
||||
cql = f"SELECT d, s, p, o, otype, dtype, lang FROM {self.tg.collection_table} WHERE collection = %s"
|
||||
params = [query.collection]
|
||||
else:
|
||||
# For specific patterns, fall back to non-streaming
|
||||
# (these typically return small result sets anyway)
|
||||
async for batch, is_final in self._fallback_stream(query, batch_size):
|
||||
yield batch, is_final
|
||||
return
|
||||
|
||||
# Create statement with fetch_size for true streaming
|
||||
statement = SimpleStatement(cql, fetch_size=batch_size)
|
||||
result_set = self.tg.session.execute(statement, params)
|
||||
|
||||
batch = []
|
||||
count = 0
|
||||
|
||||
for row in result_set:
|
||||
if count >= limit:
|
||||
break
|
||||
|
||||
g = row.d if hasattr(row, 'd') else DEFAULT_GRAPH
|
||||
|
||||
# Filter by graph
|
||||
# g_val=None means all graphs (no filter)
|
||||
# g_val="" means default graph only
|
||||
# otherwise filter to specific named graph
|
||||
if g_val is not None:
|
||||
if g != g_val:
|
||||
continue
|
||||
|
||||
term_type, datatype, language = get_object_metadata(row)
|
||||
|
||||
# s and p are always IRIs in RDF
|
||||
triple = Triple(
|
||||
s=create_term(row.s, term_type='u'),
|
||||
p=create_term(row.p, term_type='u'),
|
||||
o=create_term(row.o, term_type=term_type, datatype=datatype, language=language),
|
||||
g=g if g != DEFAULT_GRAPH else None
|
||||
)
|
||||
batch.append(triple)
|
||||
count += 1
|
||||
|
||||
# Yield batch when full (never mark as final mid-stream)
|
||||
if len(batch) >= batch_size:
|
||||
yield batch, False
|
||||
batch = []
|
||||
|
||||
# Always yield final batch to signal completion
|
||||
# This handles: remaining rows, empty result set, or exact batch boundary
|
||||
yield batch, True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in streaming query: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
async def _fallback_stream(self, query, batch_size):
|
||||
"""Fallback to non-streaming query with post-hoc batching."""
|
||||
triples = await self.query_triples(query)
|
||||
|
||||
for i in range(0, len(triples), batch_size):
|
||||
batch = triples[i:i + batch_size]
|
||||
is_final = (i + batch_size >= len(triples))
|
||||
yield batch, is_final
|
||||
|
||||
if len(triples) == 0:
|
||||
yield [], True
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,22 @@
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# Provenance imports
|
||||
from trustgraph.provenance import (
|
||||
docrag_question_uri,
|
||||
docrag_grounding_uri,
|
||||
docrag_exploration_uri,
|
||||
docrag_synthesis_uri,
|
||||
docrag_question_triples,
|
||||
grounding_triples,
|
||||
docrag_exploration_triples,
|
||||
docrag_synthesis_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL,
|
||||
)
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -19,41 +35,106 @@ class Query:
|
|||
self.verbose = verbose
|
||||
self.doc_limit = doc_limit
|
||||
|
||||
async def get_vector(self, query):
|
||||
async def extract_concepts(self, query):
|
||||
"""Extract key concepts from query for independent embedding."""
|
||||
response = await self.rag.prompt_client.prompt(
|
||||
"extract-concepts",
|
||||
variables={"query": query}
|
||||
)
|
||||
|
||||
concepts = []
|
||||
if isinstance(response, str):
|
||||
for line in response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if line:
|
||||
concepts.append(line)
|
||||
|
||||
# Fallback to raw query if no concepts extracted
|
||||
if not concepts:
|
||||
concepts = [query]
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Extracted concepts: {concepts}")
|
||||
|
||||
return concepts
|
||||
|
||||
async def get_vectors(self, concepts):
|
||||
"""Compute embeddings for a list of concepts."""
|
||||
if self.verbose:
|
||||
logger.debug("Computing embeddings...")
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
qembeds = await self.rag.embeddings_client.embed(concepts)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Embeddings computed")
|
||||
|
||||
return qembeds
|
||||
|
||||
async def get_docs(self, query):
|
||||
async def get_docs(self, concepts):
|
||||
"""
|
||||
Get documents (chunks) matching the extracted concepts.
|
||||
|
||||
vectors = await self.get_vector(query)
|
||||
Returns:
|
||||
tuple: (docs, chunk_ids) where:
|
||||
- docs: list of document content strings
|
||||
- chunk_ids: list of chunk IDs that were successfully fetched
|
||||
"""
|
||||
vectors = await self.get_vectors(concepts)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting documents...")
|
||||
logger.debug("Getting chunks from embeddings store...")
|
||||
|
||||
docs = await self.rag.doc_embeddings_client.query(
|
||||
vectors, limit=self.doc_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
# Query chunk matches for each concept concurrently
|
||||
per_concept_limit = max(
|
||||
1, self.doc_limit // len(vectors)
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Documents:")
|
||||
for doc in docs:
|
||||
logger.debug(f" {doc}")
|
||||
async def query_concept(vec):
|
||||
return await self.rag.doc_embeddings_client.query(
|
||||
vector=vec, limit=per_concept_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
|
||||
return docs
|
||||
results = await asyncio.gather(
|
||||
*[query_concept(v) for v in vectors]
|
||||
)
|
||||
|
||||
# Deduplicate chunk matches by chunk_id
|
||||
seen = set()
|
||||
chunk_matches = []
|
||||
for matches in results:
|
||||
for match in matches:
|
||||
if match.chunk_id and match.chunk_id not in seen:
|
||||
seen.add(match.chunk_id)
|
||||
chunk_matches.append(match)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Got {len(chunk_matches)} chunks, fetching content from Garage...")
|
||||
|
||||
# Fetch chunk content from Garage
|
||||
docs = []
|
||||
chunk_ids = []
|
||||
for match in chunk_matches:
|
||||
if match.chunk_id:
|
||||
try:
|
||||
content = await self.rag.fetch_chunk(match.chunk_id, self.user)
|
||||
docs.append(content)
|
||||
chunk_ids.append(match.chunk_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch chunk {match.chunk_id}: {e}")
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Documents fetched:")
|
||||
for doc in docs:
|
||||
logger.debug(f" {doc[:100]}...")
|
||||
|
||||
return docs, chunk_ids
|
||||
|
||||
class DocumentRag:
|
||||
|
||||
def __init__(
|
||||
self, prompt_client, embeddings_client, doc_embeddings_client,
|
||||
fetch_chunk,
|
||||
verbose=False,
|
||||
):
|
||||
|
||||
|
|
@ -62,6 +143,7 @@ class DocumentRag:
|
|||
self.prompt_client = prompt_client
|
||||
self.embeddings_client = embeddings_client
|
||||
self.doc_embeddings_client = doc_embeddings_client
|
||||
self.fetch_chunk = fetch_chunk
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("DocumentRag initialized")
|
||||
|
|
@ -69,17 +151,69 @@ class DocumentRag:
|
|||
async def query(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
doc_limit=20, streaming=False, chunk_callback=None,
|
||||
explain_callback=None, save_answer_callback=None,
|
||||
):
|
||||
"""
|
||||
Execute a Document RAG query with optional explainability tracking.
|
||||
|
||||
Args:
|
||||
query: The query string
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
doc_limit: Max chunks to retrieve
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for explainability
|
||||
save_answer_callback: async def callback(doc_id, answer_text) to save answer to librarian
|
||||
|
||||
Returns:
|
||||
str: The synthesized answer text
|
||||
"""
|
||||
if self.verbose:
|
||||
logger.debug("Constructing prompt...")
|
||||
|
||||
# Generate explainability URIs upfront
|
||||
session_id = str(uuid.uuid4())
|
||||
q_uri = docrag_question_uri(session_id)
|
||||
gnd_uri = docrag_grounding_uri(session_id)
|
||||
exp_uri = docrag_exploration_uri(session_id)
|
||||
syn_uri = docrag_synthesis_uri(session_id)
|
||||
|
||||
timestamp = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
# Emit question explainability immediately
|
||||
if explain_callback:
|
||||
q_triples = set_graph(
|
||||
docrag_question_triples(q_uri, query, timestamp),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(q_triples, q_uri)
|
||||
|
||||
q = Query(
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose,
|
||||
doc_limit=doc_limit
|
||||
)
|
||||
|
||||
docs = await q.get_docs(query)
|
||||
# Extract concepts from query (grounding step)
|
||||
concepts = await q.extract_concepts(query)
|
||||
|
||||
# Emit grounding explainability after concept extraction
|
||||
if explain_callback:
|
||||
gnd_triples = set_graph(
|
||||
grounding_triples(gnd_uri, q_uri, concepts),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(gnd_triples, gnd_uri)
|
||||
|
||||
docs, chunk_ids = await q.get_docs(concepts)
|
||||
|
||||
# Emit exploration explainability after chunks retrieved
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
docrag_exploration_triples(exp_uri, gnd_uri, len(chunk_ids), chunk_ids),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(exp_triples, exp_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
|
|
@ -87,12 +221,21 @@ class DocumentRag:
|
|||
logger.debug(f"Query: {query}")
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Accumulate chunks for answer storage while forwarding to callback
|
||||
accumulated_chunks = []
|
||||
|
||||
async def accumulating_callback(chunk, end_of_stream):
|
||||
accumulated_chunks.append(chunk)
|
||||
await chunk_callback(chunk, end_of_stream)
|
||||
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
query=query,
|
||||
documents=docs,
|
||||
streaming=True,
|
||||
chunk_callback=chunk_callback
|
||||
chunk_callback=accumulating_callback
|
||||
)
|
||||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
resp = await self.prompt_client.document_prompt(
|
||||
query=query,
|
||||
|
|
@ -102,5 +245,33 @@ class DocumentRag:
|
|||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
||||
# Emit synthesis explainability after answer generated
|
||||
if explain_callback:
|
||||
synthesis_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian
|
||||
if save_answer_callback and answer_text:
|
||||
synthesis_doc_id = f"urn:trustgraph:docrag:{session_id}/answer"
|
||||
try:
|
||||
await save_answer_callback(synthesis_doc_id, answer_text)
|
||||
if self.verbose:
|
||||
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
synthesis_doc_id = None
|
||||
|
||||
syn_triples = set_graph(
|
||||
docrag_synthesis_triples(
|
||||
syn_uri, exp_uri,
|
||||
document_id=synthesis_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(syn_triples, syn_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Emitted explain for session {session_id}")
|
||||
|
||||
return resp
|
||||
|
||||
|
|
|
|||
|
|
@ -4,17 +4,30 @@ Simple RAG service, performs query using document RAG an LLM.
|
|||
Input is query, output is response.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
|
||||
import uuid
|
||||
|
||||
from ... schema import DocumentRagQuery, DocumentRagResponse, Error
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from ... schema import Triples, Metadata
|
||||
from ... provenance import GRAPH_RETRIEVAL
|
||||
from . document_rag import DocumentRag
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import DocumentEmbeddingsClientSpec
|
||||
from ... base import Consumer, Producer
|
||||
from ... base import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "document-rag"
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
|
|
@ -69,6 +82,161 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "explainability",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
# Librarian client for fetching chunk content from Garage
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend=self.pubsub,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
backend=self.pubsub,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self.on_librarian_response,
|
||||
metrics=librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_requests = {}
|
||||
|
||||
async def start(self):
|
||||
await super(Processor, self).start()
|
||||
await self.librarian_request_producer.start()
|
||||
await self.librarian_response_consumer.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
|
||||
"""Handle responses from the librarian service."""
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if request_id in self.pending_requests:
|
||||
future = self.pending_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def fetch_chunk_content(self, chunk_id, user, timeout=120):
|
||||
"""Fetch chunk content from librarian/Garage."""
|
||||
import uuid
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="get-document-content",
|
||||
document_id=chunk_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
# Content is base64 encoded
|
||||
content = response.content
|
||||
if isinstance(content, str):
|
||||
content = content.encode('utf-8')
|
||||
return base64.b64decode(content).decode("utf-8")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout fetching chunk {chunk_id}")
|
||||
|
||||
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
|
||||
"""
|
||||
Save answer content to the librarian.
|
||||
|
||||
Args:
|
||||
doc_id: ID for the answer document
|
||||
user: User ID
|
||||
content: Answer text content
|
||||
title: Optional title
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
The document ID on success
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind="text/plain",
|
||||
title=title or "DocumentRAG Answer",
|
||||
document_type="answer",
|
||||
)
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="add-document",
|
||||
document_id=doc_id,
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return doc_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout saving answer document {doc_id}")
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
try:
|
||||
|
|
@ -77,6 +245,7 @@ class Processor(FlowProcessor):
|
|||
embeddings_client = flow("embeddings-request"),
|
||||
doc_embeddings_client = flow("document-embeddings-request"),
|
||||
prompt_client = flow("prompt-request"),
|
||||
fetch_chunk = self.fetch_chunk_content,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
|
@ -92,6 +261,39 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
doc_limit = self.doc_limit
|
||||
|
||||
# Real-time explainability callback - emits triples and IDs as they're generated
|
||||
# Triples are stored in the user's collection with a named graph (urn:graph:retrieval)
|
||||
async def send_explainability(triples, explain_id):
|
||||
# Send triples to explainability queue - stores in same collection with named graph
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=explain_id,
|
||||
user=v.user,
|
||||
collection=v.collection, # Store in user's collection
|
||||
),
|
||||
triples=triples,
|
||||
))
|
||||
|
||||
# Send explain ID and graph to response queue
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response=None,
|
||||
explain_id=explain_id,
|
||||
explain_graph=GRAPH_RETRIEVAL,
|
||||
message_type="explain",
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
# Callback to save answer content to librarian
|
||||
async def save_answer(doc_id, answer_text):
|
||||
await self.save_answer_content(
|
||||
doc_id=doc_id,
|
||||
user=v.user,
|
||||
content=answer_text,
|
||||
title=f"DocumentRAG Answer: {v.query[:50]}...",
|
||||
)
|
||||
|
||||
# Check if streaming is requested
|
||||
if v.streaming:
|
||||
# Define async callback for streaming chunks
|
||||
|
|
@ -101,6 +303,7 @@ class Processor(FlowProcessor):
|
|||
DocumentRagResponse(
|
||||
response=chunk,
|
||||
end_of_stream=end_of_stream,
|
||||
message_type="chunk",
|
||||
error=None
|
||||
),
|
||||
properties={"id": id}
|
||||
|
|
@ -115,6 +318,18 @@ class Processor(FlowProcessor):
|
|||
doc_limit=doc_limit,
|
||||
streaming=True,
|
||||
chunk_callback=send_chunk,
|
||||
explain_callback=send_explainability,
|
||||
save_answer_callback=save_answer,
|
||||
)
|
||||
|
||||
# Send end_of_session to signal entire session is complete
|
||||
await flow("response").send(
|
||||
DocumentRagResponse(
|
||||
response=None,
|
||||
end_of_session=True,
|
||||
message_type="end",
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
else:
|
||||
# Non-streaming path (existing behavior)
|
||||
|
|
@ -122,7 +337,9 @@ class Processor(FlowProcessor):
|
|||
v.query,
|
||||
user=v.user,
|
||||
collection=v.collection,
|
||||
doc_limit=doc_limit
|
||||
doc_limit=doc_limit,
|
||||
explain_callback=send_explainability,
|
||||
save_answer_callback=save_answer,
|
||||
)
|
||||
|
||||
await flow("response").send(
|
||||
|
|
|
|||
|
|
@ -1,14 +1,57 @@
|
|||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
|
||||
from ... schema import Term, Triple as SchemaTriple, IRI, LITERAL, TRIPLE
|
||||
|
||||
# Provenance imports
|
||||
from trustgraph.provenance import (
|
||||
question_uri,
|
||||
grounding_uri as make_grounding_uri,
|
||||
exploration_uri as make_exploration_uri,
|
||||
focus_uri as make_focus_uri,
|
||||
synthesis_uri as make_synthesis_uri,
|
||||
question_triples,
|
||||
grounding_triples,
|
||||
exploration_triples,
|
||||
focus_triples,
|
||||
synthesis_triples,
|
||||
set_graph,
|
||||
GRAPH_RETRIEVAL, GRAPH_SOURCE,
|
||||
TG_CONTAINS, PROV_WAS_DERIVED_FROM,
|
||||
)
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
|
||||
|
||||
def term_to_string(term):
|
||||
"""Extract string value from a Term object."""
|
||||
if term is None:
|
||||
return None
|
||||
if term.type == IRI:
|
||||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
# Fallback
|
||||
return term.iri or term.value or str(term)
|
||||
|
||||
|
||||
def edge_id(s, p, o):
|
||||
"""Generate an 8-character hash ID for an edge (s, p, o)."""
|
||||
edge_str = f"{s}|{p}|{o}"
|
||||
return hashlib.sha256(edge_str.encode()).hexdigest()[:8]
|
||||
|
||||
|
||||
|
||||
class LRUCacheWithTTL:
|
||||
"""LRU cache with TTL for label caching
|
||||
|
||||
|
|
@ -67,12 +110,32 @@ class Query:
|
|||
self.max_subgraph_size = max_subgraph_size
|
||||
self.max_path_length = max_path_length
|
||||
|
||||
async def get_vector(self, query):
|
||||
async def extract_concepts(self, query):
|
||||
"""Extract key concepts from query for independent embedding."""
|
||||
response = await self.rag.prompt_client.prompt(
|
||||
"extract-concepts",
|
||||
variables={"query": query}
|
||||
)
|
||||
|
||||
concepts = []
|
||||
if isinstance(response, str):
|
||||
for line in response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if line:
|
||||
concepts.append(line)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Extracted concepts: {concepts}")
|
||||
|
||||
# Fall back to raw query if extraction returns nothing
|
||||
return concepts if concepts else [query]
|
||||
|
||||
async def get_vectors(self, concepts):
|
||||
"""Embed multiple concepts concurrently."""
|
||||
if self.verbose:
|
||||
logger.debug("Computing embeddings...")
|
||||
|
||||
qembeds = await self.rag.embeddings_client.embed(query)
|
||||
qembeds = await self.rag.embeddings_client.embed(concepts)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
|
|
@ -80,28 +143,55 @@ class Query:
|
|||
return qembeds
|
||||
|
||||
async def get_entities(self, query):
|
||||
"""
|
||||
Extract concepts from query, embed them, and retrieve matching entities.
|
||||
|
||||
vectors = await self.get_vector(query)
|
||||
Returns:
|
||||
tuple: (entities, concepts) where entities is a list of entity URI
|
||||
strings and concepts is the list of concept strings extracted
|
||||
from the query.
|
||||
"""
|
||||
|
||||
concepts = await self.extract_concepts(query)
|
||||
|
||||
vectors = await self.get_vectors(concepts)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting entities...")
|
||||
|
||||
entities = await self.rag.graph_embeddings_client.query(
|
||||
vectors=vectors, limit=self.entity_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
# Query entity matches for each concept concurrently
|
||||
per_concept_limit = max(
|
||||
1, self.entity_limit // len(vectors)
|
||||
)
|
||||
|
||||
entities = [
|
||||
str(e)
|
||||
for e in entities
|
||||
entity_tasks = [
|
||||
self.rag.graph_embeddings_client.query(
|
||||
vector=v, limit=per_concept_limit,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
for v in vectors
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*entity_tasks, return_exceptions=True)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen = set()
|
||||
entities = []
|
||||
for result in results:
|
||||
if isinstance(result, Exception) or not result:
|
||||
continue
|
||||
for e in result:
|
||||
entity = term_to_string(e.entity)
|
||||
if entity not in seen:
|
||||
seen.add(entity)
|
||||
entities.append(entity)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Entities:")
|
||||
for ent in entities:
|
||||
logger.debug(f" {ent}")
|
||||
|
||||
return entities
|
||||
return entities, concepts
|
||||
|
||||
async def maybe_label(self, e):
|
||||
|
||||
|
|
@ -117,6 +207,7 @@ class Query:
|
|||
res = await self.rag.triples_client.query(
|
||||
s=e, p=LABEL, o=None, limit=1,
|
||||
user=self.user, collection=self.collection,
|
||||
g="",
|
||||
)
|
||||
|
||||
if len(res) == 0:
|
||||
|
|
@ -128,26 +219,29 @@ class Query:
|
|||
return label
|
||||
|
||||
async def execute_batch_triple_queries(self, entities, limit_per_entity):
|
||||
"""Execute triple queries for multiple entities concurrently"""
|
||||
"""Execute triple queries for multiple entities concurrently using streaming"""
|
||||
tasks = []
|
||||
|
||||
for entity in entities:
|
||||
# Create concurrent tasks for all 3 query types per entity
|
||||
# Create concurrent streaming tasks for all 3 query types per entity
|
||||
tasks.extend([
|
||||
self.rag.triples_client.query(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=entity, p=None, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20, g="",
|
||||
),
|
||||
self.rag.triples_client.query(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=entity, o=None,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20, g="",
|
||||
),
|
||||
self.rag.triples_client.query(
|
||||
self.rag.triples_client.query_stream(
|
||||
s=None, p=None, o=entity,
|
||||
limit=limit_per_entity,
|
||||
user=self.user, collection=self.collection
|
||||
user=self.user, collection=self.collection,
|
||||
batch_size=20, g="",
|
||||
)
|
||||
])
|
||||
|
||||
|
|
@ -157,7 +251,7 @@ class Query:
|
|||
# Combine all results
|
||||
all_triples = []
|
||||
for result in results:
|
||||
if not isinstance(result, Exception):
|
||||
if not isinstance(result, Exception) and result is not None:
|
||||
all_triples.extend(result)
|
||||
|
||||
return all_triples
|
||||
|
|
@ -220,8 +314,16 @@ class Query:
|
|||
subgraph.update(batch_result)
|
||||
|
||||
async def get_subgraph(self, query):
|
||||
"""
|
||||
Get subgraph by extracting concepts, finding entities, and traversing.
|
||||
|
||||
entities = await self.get_entities(query)
|
||||
Returns:
|
||||
tuple: (subgraph, entities, concepts) where subgraph is a list of
|
||||
(s, p, o) tuples, entities is the seed entity list, and concepts
|
||||
is the extracted concept list.
|
||||
"""
|
||||
|
||||
entities, concepts = await self.get_entities(query)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Getting subgraph...")
|
||||
|
|
@ -229,7 +331,7 @@ class Query:
|
|||
# Use optimized batch traversal instead of sequential processing
|
||||
subgraph = await self.follow_edges_batch(entities, self.max_path_length)
|
||||
|
||||
return list(subgraph)
|
||||
return list(subgraph), entities, concepts
|
||||
|
||||
async def resolve_labels_batch(self, entities):
|
||||
"""Resolve labels for multiple entities in parallel"""
|
||||
|
|
@ -240,8 +342,17 @@ class Query:
|
|||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def get_labelgraph(self, query):
|
||||
"""
|
||||
Get subgraph with labels resolved for display.
|
||||
|
||||
subgraph = await self.get_subgraph(query)
|
||||
Returns:
|
||||
tuple: (labeled_edges, uri_map, entities, concepts) where:
|
||||
- labeled_edges: list of (label_s, label_p, label_o) tuples
|
||||
- uri_map: dict mapping edge_id(label_s, label_p, label_o) -> (uri_s, uri_p, uri_o)
|
||||
- entities: list of seed entity URI strings
|
||||
- concepts: list of concept strings extracted from query
|
||||
"""
|
||||
subgraph, entities, concepts = await self.get_subgraph(query)
|
||||
|
||||
# Filter out label triples
|
||||
filtered_subgraph = [edge for edge in subgraph if edge[1] != LABEL]
|
||||
|
|
@ -263,28 +374,151 @@ class Query:
|
|||
else:
|
||||
label_map[entity] = entity # Fallback to entity itself
|
||||
|
||||
# Apply labels to subgraph
|
||||
sg2 = []
|
||||
# Apply labels to subgraph and build URI mapping
|
||||
labeled_edges = []
|
||||
uri_map = {} # Maps edge_id of labeled edge -> original URI triple
|
||||
|
||||
for s, p, o in filtered_subgraph:
|
||||
labeled_triple = (
|
||||
label_map.get(s, s),
|
||||
label_map.get(p, p),
|
||||
label_map.get(o, o)
|
||||
)
|
||||
sg2.append(labeled_triple)
|
||||
labeled_edges.append(labeled_triple)
|
||||
|
||||
sg2 = sg2[0:self.max_subgraph_size]
|
||||
# Map from labeled edge ID to original URIs
|
||||
labeled_eid = edge_id(labeled_triple[0], labeled_triple[1], labeled_triple[2])
|
||||
uri_map[labeled_eid] = (s, p, o)
|
||||
|
||||
labeled_edges = labeled_edges[0:self.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Subgraph:")
|
||||
for edge in sg2:
|
||||
for edge in labeled_edges:
|
||||
logger.debug(f" {str(edge)}")
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Done.")
|
||||
|
||||
return sg2
|
||||
|
||||
return labeled_edges, uri_map, entities, concepts
|
||||
|
||||
async def trace_source_documents(self, edge_uris):
|
||||
"""
|
||||
Trace selected edges back to their source documents via provenance.
|
||||
|
||||
Follows the chain: edge → subgraph (via tg:contains) → chunk →
|
||||
page → document (via prov:wasDerivedFrom), all in urn:graph:source.
|
||||
|
||||
Args:
|
||||
edge_uris: List of (s, p, o) URI string tuples
|
||||
|
||||
Returns:
|
||||
List of unique document titles
|
||||
"""
|
||||
# Step 1: Find subgraphs containing these edges via tg:contains
|
||||
subgraph_tasks = []
|
||||
for s, p, o in edge_uris:
|
||||
quoted = Term(
|
||||
type=TRIPLE,
|
||||
triple=SchemaTriple(
|
||||
s=Term(type=IRI, iri=s),
|
||||
p=Term(type=IRI, iri=p),
|
||||
o=Term(type=IRI, iri=o),
|
||||
)
|
||||
)
|
||||
subgraph_tasks.append(
|
||||
self.rag.triples_client.query(
|
||||
s=None, p=TG_CONTAINS, o=quoted, limit=1,
|
||||
user=self.user, collection=self.collection,
|
||||
g=GRAPH_SOURCE,
|
||||
)
|
||||
)
|
||||
|
||||
subgraph_results = await asyncio.gather(
|
||||
*subgraph_tasks, return_exceptions=True
|
||||
)
|
||||
|
||||
# Collect unique subgraph URIs
|
||||
subgraph_uris = set()
|
||||
for result in subgraph_results:
|
||||
if isinstance(result, Exception) or not result:
|
||||
continue
|
||||
for triple in result:
|
||||
subgraph_uris.add(str(triple.s))
|
||||
|
||||
if not subgraph_uris:
|
||||
return []
|
||||
|
||||
# Step 2: Walk prov:wasDerivedFrom chain to find documents
|
||||
# Each level: query ?entity prov:wasDerivedFrom ?parent
|
||||
# Stop when we find entities typed tg:Document
|
||||
current_uris = subgraph_uris
|
||||
doc_uris = set()
|
||||
|
||||
for depth in range(4): # Max depth: subgraph → chunk → page → doc
|
||||
if not current_uris:
|
||||
break
|
||||
|
||||
derivation_tasks = [
|
||||
self.rag.triples_client.query(
|
||||
s=uri, p=PROV_WAS_DERIVED_FROM, o=None, limit=5,
|
||||
user=self.user, collection=self.collection,
|
||||
g=GRAPH_SOURCE,
|
||||
)
|
||||
for uri in current_uris
|
||||
]
|
||||
|
||||
derivation_results = await asyncio.gather(
|
||||
*derivation_tasks, return_exceptions=True
|
||||
)
|
||||
|
||||
# URIs with no parent are root documents
|
||||
next_uris = set()
|
||||
for uri, result in zip(current_uris, derivation_results):
|
||||
if isinstance(result, Exception) or not result:
|
||||
doc_uris.add(uri)
|
||||
continue
|
||||
for triple in result:
|
||||
next_uris.add(str(triple.o))
|
||||
|
||||
current_uris = next_uris - doc_uris
|
||||
|
||||
if not doc_uris:
|
||||
return []
|
||||
|
||||
# Step 3: Get all document metadata properties
|
||||
# Skip structural predicates that aren't useful context
|
||||
SKIP_PREDICATES = {
|
||||
PROV_WAS_DERIVED_FROM,
|
||||
"http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
|
||||
}
|
||||
|
||||
metadata_tasks = [
|
||||
self.rag.triples_client.query(
|
||||
s=uri, p=None, o=None, limit=50,
|
||||
user=self.user, collection=self.collection,
|
||||
)
|
||||
for uri in doc_uris
|
||||
]
|
||||
|
||||
metadata_results = await asyncio.gather(
|
||||
*metadata_tasks, return_exceptions=True
|
||||
)
|
||||
|
||||
doc_edges = []
|
||||
for result in metadata_results:
|
||||
if isinstance(result, Exception) or not result:
|
||||
continue
|
||||
for triple in result:
|
||||
p = str(triple.p)
|
||||
if p in SKIP_PREDICATES:
|
||||
continue
|
||||
doc_edges.append((
|
||||
str(triple.s), p, str(triple.o)
|
||||
))
|
||||
|
||||
return doc_edges
|
||||
|
||||
class GraphRag:
|
||||
"""
|
||||
CRITICAL SECURITY:
|
||||
|
|
@ -316,12 +550,50 @@ class GraphRag:
|
|||
async def query(
|
||||
self, query, user = "trustgraph", collection = "default",
|
||||
entity_limit = 50, triple_limit = 30, max_subgraph_size = 1000,
|
||||
max_path_length = 2, streaming = False, chunk_callback = None,
|
||||
max_path_length = 2, edge_limit = 25, streaming = False,
|
||||
chunk_callback = None,
|
||||
explain_callback = None, save_answer_callback = None,
|
||||
):
|
||||
"""
|
||||
Execute a GraphRAG query with real-time explainability tracking.
|
||||
|
||||
Args:
|
||||
query: The query string
|
||||
user: User identifier
|
||||
collection: Collection identifier
|
||||
entity_limit: Max entities to retrieve
|
||||
triple_limit: Max triples per entity
|
||||
max_subgraph_size: Max edges in subgraph
|
||||
max_path_length: Max hops from seed entities
|
||||
streaming: Enable streaming LLM response
|
||||
chunk_callback: async def callback(chunk, end_of_stream) for streaming
|
||||
explain_callback: async def callback(triples, explain_id) for real-time explainability
|
||||
save_answer_callback: async def callback(doc_id, answer_text) -> doc_id to save answer to librarian
|
||||
|
||||
Returns:
|
||||
str: The synthesized answer text
|
||||
"""
|
||||
if self.verbose:
|
||||
logger.debug("Constructing prompt...")
|
||||
|
||||
# Generate explainability URIs upfront
|
||||
session_id = str(uuid.uuid4())
|
||||
q_uri = question_uri(session_id)
|
||||
gnd_uri = make_grounding_uri(session_id)
|
||||
exp_uri = make_exploration_uri(session_id)
|
||||
foc_uri = make_focus_uri(session_id)
|
||||
syn_uri = make_synthesis_uri(session_id)
|
||||
|
||||
timestamp = datetime.utcnow().isoformat() + "Z"
|
||||
|
||||
# Emit question explainability immediately
|
||||
if explain_callback:
|
||||
q_triples = set_graph(
|
||||
question_triples(q_uri, query, timestamp),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(q_triples, q_uri)
|
||||
|
||||
q = Query(
|
||||
rag = self, user = user, collection = collection,
|
||||
verbose = self.verbose, entity_limit = entity_limit,
|
||||
|
|
@ -330,24 +602,262 @@ class GraphRag:
|
|||
max_path_length = max_path_length,
|
||||
)
|
||||
|
||||
kg = await q.get_labelgraph(query)
|
||||
kg, uri_map, seed_entities, concepts = await q.get_labelgraph(query)
|
||||
|
||||
# Emit grounding explain after concept extraction
|
||||
if explain_callback:
|
||||
gnd_triples = set_graph(
|
||||
grounding_triples(gnd_uri, q_uri, concepts),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(gnd_triples, gnd_uri)
|
||||
|
||||
# Emit exploration explain after graph retrieval completes
|
||||
if explain_callback:
|
||||
exp_triples = set_graph(
|
||||
exploration_triples(
|
||||
exp_uri, gnd_uri, len(kg),
|
||||
entities=seed_entities,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(exp_triples, exp_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Invoking LLM...")
|
||||
logger.debug(f"Knowledge graph: {kg}")
|
||||
logger.debug(f"Query: {query}")
|
||||
|
||||
if streaming and chunk_callback:
|
||||
resp = await self.prompt_client.kg_prompt(
|
||||
query, kg,
|
||||
streaming=True,
|
||||
chunk_callback=chunk_callback
|
||||
# Build edge map: {hash_id: (labeled_s, labeled_p, labeled_o)}
|
||||
# uri_map already maps edge_id -> (uri_s, uri_p, uri_o)
|
||||
edge_map = {}
|
||||
edges_with_ids = []
|
||||
for s, p, o in kg:
|
||||
eid = edge_id(s, p, o)
|
||||
edge_map[eid] = (s, p, o)
|
||||
edges_with_ids.append({
|
||||
"id": eid,
|
||||
"s": s,
|
||||
"p": p,
|
||||
"o": o
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Built edge map with {len(edge_map)} edges")
|
||||
|
||||
# Step 1a: Edge Scoring - LLM scores edges for relevance
|
||||
scoring_response = await self.prompt_client.prompt(
|
||||
"kg-edge-scoring",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": edges_with_ids
|
||||
}
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge scoring response: {scoring_response}")
|
||||
|
||||
# Parse scoring response to get edge IDs with scores
|
||||
scored_edges = []
|
||||
|
||||
def parse_scored_edge(obj):
|
||||
if isinstance(obj, dict) and "id" in obj and "score" in obj:
|
||||
try:
|
||||
score = int(obj["score"])
|
||||
except (ValueError, TypeError):
|
||||
score = 0
|
||||
scored_edges.append({"id": obj["id"], "score": score})
|
||||
|
||||
if isinstance(scoring_response, list):
|
||||
for obj in scoring_response:
|
||||
parse_scored_edge(obj)
|
||||
elif isinstance(scoring_response, str):
|
||||
for line in scoring_response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
parse_scored_edge(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge scoring line: {line}"
|
||||
)
|
||||
|
||||
# Select top N edges by score
|
||||
scored_edges.sort(key=lambda x: x["score"], reverse=True)
|
||||
top_edges = scored_edges[:edge_limit]
|
||||
selected_ids = {e["id"] for e in top_edges}
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(
|
||||
f"Scored {len(scored_edges)} edges, "
|
||||
f"selected top {len(selected_ids)}"
|
||||
)
|
||||
|
||||
# Filter to selected edges
|
||||
selected_edges = []
|
||||
for eid in selected_ids:
|
||||
if eid in edge_map:
|
||||
selected_edges.append(edge_map[eid])
|
||||
|
||||
# Step 1b: Edge Reasoning + Document Tracing (concurrent)
|
||||
selected_edges_with_ids = [
|
||||
{"id": eid, "s": s, "p": p, "o": o}
|
||||
for eid in selected_ids
|
||||
if eid in edge_map
|
||||
for s, p, o in [edge_map[eid]]
|
||||
]
|
||||
|
||||
# Collect selected edge URIs for document tracing
|
||||
selected_edge_uris = [
|
||||
uri_map[eid]
|
||||
for eid in selected_ids
|
||||
if eid in uri_map
|
||||
]
|
||||
|
||||
# Run reasoning and document tracing concurrently
|
||||
reasoning_task = self.prompt_client.prompt(
|
||||
"kg-edge-reasoning",
|
||||
variables={
|
||||
"query": query,
|
||||
"knowledge": selected_edges_with_ids
|
||||
}
|
||||
)
|
||||
doc_trace_task = q.trace_source_documents(selected_edge_uris)
|
||||
|
||||
reasoning_response, source_documents = await asyncio.gather(
|
||||
reasoning_task, doc_trace_task, return_exceptions=True
|
||||
)
|
||||
|
||||
# Handle exceptions from gather
|
||||
if isinstance(reasoning_response, Exception):
|
||||
logger.warning(
|
||||
f"Edge reasoning failed: {reasoning_response}"
|
||||
)
|
||||
reasoning_response = ""
|
||||
if isinstance(source_documents, Exception):
|
||||
logger.warning(
|
||||
f"Document tracing failed: {source_documents}"
|
||||
)
|
||||
source_documents = []
|
||||
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Edge reasoning response: {reasoning_response}")
|
||||
|
||||
# Parse reasoning response and build explainability data
|
||||
reasoning_map = {}
|
||||
|
||||
def parse_reasoning(obj):
|
||||
if isinstance(obj, dict) and "id" in obj:
|
||||
reasoning_map[obj["id"]] = obj.get("reasoning", "")
|
||||
|
||||
if isinstance(reasoning_response, list):
|
||||
for obj in reasoning_response:
|
||||
parse_reasoning(obj)
|
||||
elif isinstance(reasoning_response, str):
|
||||
for line in reasoning_response.strip().split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
parse_reasoning(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Failed to parse edge reasoning line: {line}"
|
||||
)
|
||||
|
||||
selected_edges_with_reasoning = []
|
||||
for eid in selected_ids:
|
||||
if eid in uri_map:
|
||||
uri_s, uri_p, uri_o = uri_map[eid]
|
||||
selected_edges_with_reasoning.append({
|
||||
"edge": (uri_s, uri_p, uri_o),
|
||||
"reasoning": reasoning_map.get(eid, ""),
|
||||
})
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Filtered to {len(selected_edges)} edges")
|
||||
|
||||
# Emit focus explain after edge selection completes
|
||||
if explain_callback:
|
||||
foc_triples = set_graph(
|
||||
focus_triples(
|
||||
foc_uri, exp_uri, selected_edges_with_reasoning, session_id
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(foc_triples, foc_uri)
|
||||
|
||||
# Step 2: Synthesis - LLM generates answer from selected edges only
|
||||
selected_edge_dicts = [
|
||||
{"s": s, "p": p, "o": o}
|
||||
for s, p, o in selected_edges
|
||||
]
|
||||
|
||||
# Add source document metadata as knowledge edges
|
||||
for s, p, o in source_documents:
|
||||
selected_edge_dicts.append({
|
||||
"s": s, "p": p, "o": o,
|
||||
})
|
||||
|
||||
synthesis_variables = {
|
||||
"query": query,
|
||||
"knowledge": selected_edge_dicts,
|
||||
}
|
||||
|
||||
if streaming and chunk_callback:
|
||||
# Accumulate chunks for answer storage while forwarding to callback
|
||||
accumulated_chunks = []
|
||||
|
||||
async def accumulating_callback(chunk, end_of_stream):
|
||||
accumulated_chunks.append(chunk)
|
||||
await chunk_callback(chunk, end_of_stream)
|
||||
|
||||
await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables=synthesis_variables,
|
||||
streaming=True,
|
||||
chunk_callback=accumulating_callback
|
||||
)
|
||||
# Combine all chunks into full response
|
||||
resp = "".join(accumulated_chunks)
|
||||
else:
|
||||
resp = await self.prompt_client.kg_prompt(query, kg)
|
||||
resp = await self.prompt_client.prompt(
|
||||
"kg-synthesis",
|
||||
variables=synthesis_variables,
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug("Query processing complete")
|
||||
|
||||
# Emit synthesis explain after synthesis completes
|
||||
if explain_callback:
|
||||
synthesis_doc_id = None
|
||||
answer_text = resp if resp else ""
|
||||
|
||||
# Save answer to librarian
|
||||
if save_answer_callback and answer_text:
|
||||
synthesis_doc_id = f"urn:trustgraph:synthesis:{session_id}"
|
||||
try:
|
||||
await save_answer_callback(synthesis_doc_id, answer_text)
|
||||
if self.verbose:
|
||||
logger.debug(f"Saved answer to librarian: {synthesis_doc_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save answer to librarian: {e}")
|
||||
synthesis_doc_id = None
|
||||
|
||||
syn_triples = set_graph(
|
||||
synthesis_triples(
|
||||
syn_uri, foc_uri,
|
||||
document_id=synthesis_doc_id,
|
||||
),
|
||||
GRAPH_RETRIEVAL
|
||||
)
|
||||
await explain_callback(syn_triples, syn_uri)
|
||||
|
||||
if self.verbose:
|
||||
logger.debug(f"Emitted explain for session {session_id}")
|
||||
|
||||
return resp
|
||||
|
||||
|
|
|
|||
|
|
@ -4,18 +4,29 @@ Simple RAG service, performs query using graph RAG an LLM.
|
|||
Input is query, output is response.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from ... schema import GraphRagQuery, GraphRagResponse, Error
|
||||
from ... schema import Triples, Metadata
|
||||
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from ... provenance import GRAPH_RETRIEVAL
|
||||
from . graph_rag import GraphRag
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
|
||||
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "graph-rag"
|
||||
default_concurrency = 1
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
|
|
@ -28,6 +39,7 @@ class Processor(FlowProcessor):
|
|||
triple_limit = params.get("triple_limit", 30)
|
||||
max_subgraph_size = params.get("max_subgraph_size", 150)
|
||||
max_path_length = params.get("max_path_length", 2)
|
||||
edge_limit = params.get("edge_limit", 25)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
|
|
@ -37,6 +49,7 @@ class Processor(FlowProcessor):
|
|||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
"max_path_length": max_path_length,
|
||||
"edge_limit": edge_limit,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -44,6 +57,7 @@ class Processor(FlowProcessor):
|
|||
self.default_triple_limit = triple_limit
|
||||
self.default_max_subgraph_size = max_subgraph_size
|
||||
self.default_max_path_length = max_path_length
|
||||
self.default_edge_limit = edge_limit
|
||||
|
||||
# CRITICAL SECURITY: NEVER share data between users or collections
|
||||
# Each user/collection combination MUST have isolated data access
|
||||
|
|
@ -93,10 +107,163 @@ class Processor(FlowProcessor):
|
|||
)
|
||||
)
|
||||
|
||||
self.register_specification(
|
||||
ProducerSpec(
|
||||
name = "explainability",
|
||||
schema = Triples,
|
||||
)
|
||||
)
|
||||
|
||||
# Librarian client for storing answer content
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
backend=self.pubsub,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
backend=self.pubsub,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self.on_librarian_response,
|
||||
metrics=librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_librarian_requests = {}
|
||||
|
||||
logger.info("Graph RAG service initialized")
|
||||
|
||||
async def start(self):
|
||||
await super(Processor, self).start()
|
||||
await self.librarian_request_producer.start()
|
||||
await self.librarian_response_consumer.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
|
||||
"""Handle responses from the librarian service."""
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if request_id and request_id in self.pending_librarian_requests:
|
||||
future = self.pending_librarian_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
else:
|
||||
logger.warning(f"Received unexpected librarian response: {request_id}")
|
||||
|
||||
async def save_answer_content(self, doc_id, user, content, title=None, timeout=120):
|
||||
"""
|
||||
Save answer content to the librarian.
|
||||
|
||||
Args:
|
||||
doc_id: ID for the answer document
|
||||
user: User ID
|
||||
content: Answer text content
|
||||
title: Optional title
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
The document ID on success
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind="text/plain",
|
||||
title=title or "GraphRAG Answer",
|
||||
document_type="answer",
|
||||
)
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="add-document",
|
||||
document_id=doc_id,
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_librarian_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return doc_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_librarian_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout saving answer document {doc_id}")
|
||||
|
||||
async def on_request(self, msg, consumer, flow):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
logger.info(f"Handling input {id}...")
|
||||
|
||||
# Track explainability refs for end_of_session signaling
|
||||
explainability_refs_emitted = []
|
||||
|
||||
# Real-time explainability callback - emits triples and IDs as they're generated
|
||||
# Triples are stored in the user's collection with a named graph (urn:graph:retrieval)
|
||||
async def send_explainability(triples, explain_id):
|
||||
# Send triples to explainability queue - stores in same collection with named graph
|
||||
await flow("explainability").send(Triples(
|
||||
metadata=Metadata(
|
||||
id=explain_id,
|
||||
user=v.user,
|
||||
collection=v.collection, # Store in user's collection, not separate explainability collection
|
||||
),
|
||||
triples=triples,
|
||||
))
|
||||
|
||||
# Send explain ID and graph to response queue
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
message_type="explain",
|
||||
explain_id=explain_id,
|
||||
explain_graph=GRAPH_RETRIEVAL,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
explainability_refs_emitted.append(explain_id)
|
||||
|
||||
# CRITICAL SECURITY: Create new GraphRag instance per request
|
||||
# This ensures proper isolation between users and collections
|
||||
# Flow clients are request-scoped and must not be shared
|
||||
|
|
@ -108,13 +275,6 @@ class Processor(FlowProcessor):
|
|||
verbose=True,
|
||||
)
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
logger.info(f"Handling input {id}...")
|
||||
|
||||
if v.entity_limit:
|
||||
entity_limit = v.entity_limit
|
||||
else:
|
||||
|
|
@ -135,6 +295,20 @@ class Processor(FlowProcessor):
|
|||
else:
|
||||
max_path_length = self.default_max_path_length
|
||||
|
||||
if v.edge_limit:
|
||||
edge_limit = v.edge_limit
|
||||
else:
|
||||
edge_limit = self.default_edge_limit
|
||||
|
||||
# Callback to save answer content to librarian
|
||||
async def save_answer(doc_id, answer_text):
|
||||
await self.save_answer_content(
|
||||
doc_id=doc_id,
|
||||
user=v.user,
|
||||
content=answer_text,
|
||||
title=f"GraphRAG Answer: {v.query[:50]}...",
|
||||
)
|
||||
|
||||
# Check if streaming is requested
|
||||
if v.streaming:
|
||||
# Define async callback for streaming chunks
|
||||
|
|
@ -142,6 +316,7 @@ class Processor(FlowProcessor):
|
|||
async def send_chunk(chunk, end_of_stream):
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response=chunk,
|
||||
end_of_stream=end_of_stream,
|
||||
error=None
|
||||
|
|
@ -149,34 +324,52 @@ class Processor(FlowProcessor):
|
|||
properties={"id": id}
|
||||
)
|
||||
|
||||
# Query with streaming enabled
|
||||
# All chunks (including final one with end_of_stream=True) are sent via callback
|
||||
await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
)
|
||||
else:
|
||||
# Non-streaming path (existing behavior)
|
||||
# Query with streaming and real-time explain
|
||||
response = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_limit = edge_limit,
|
||||
streaming = True,
|
||||
chunk_callback = send_chunk,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
)
|
||||
|
||||
else:
|
||||
# Non-streaming path with real-time explain
|
||||
response = await rag.query(
|
||||
query = v.query, user = v.user, collection = v.collection,
|
||||
entity_limit = entity_limit, triple_limit = triple_limit,
|
||||
max_subgraph_size = max_subgraph_size,
|
||||
max_path_length = max_path_length,
|
||||
edge_limit = edge_limit,
|
||||
explain_callback = send_explainability,
|
||||
save_answer_callback = save_answer,
|
||||
)
|
||||
|
||||
# Send chunk with response
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
response = response,
|
||||
end_of_stream = True,
|
||||
error = None
|
||||
message_type="chunk",
|
||||
response=response,
|
||||
end_of_stream=True,
|
||||
error=None,
|
||||
),
|
||||
properties = {"id": id}
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
# Send final message to close session
|
||||
await flow("response").send(
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
response="",
|
||||
end_of_session=True,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
logger.info("Request processing complete")
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -185,22 +378,18 @@ class Processor(FlowProcessor):
|
|||
|
||||
logger.debug("Sending error response...")
|
||||
|
||||
# Send error response with end_of_stream flag if streaming was requested
|
||||
error_response = GraphRagResponse(
|
||||
response = None,
|
||||
error = Error(
|
||||
type = "graph-rag-error",
|
||||
message = str(e),
|
||||
),
|
||||
)
|
||||
|
||||
# If streaming was requested, indicate stream end
|
||||
if v.streaming:
|
||||
error_response.end_of_stream = True
|
||||
|
||||
# Send error response and close session
|
||||
await flow("response").send(
|
||||
error_response,
|
||||
properties = {"id": id}
|
||||
GraphRagResponse(
|
||||
message_type="chunk",
|
||||
error=Error(
|
||||
type="graph-rag-error",
|
||||
message=str(e),
|
||||
),
|
||||
end_of_stream=True,
|
||||
end_of_session=True,
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -243,6 +432,9 @@ class Processor(FlowProcessor):
|
|||
help=f'Default max path length (default: 2)'
|
||||
)
|
||||
|
||||
# Note: Explainability triples are now stored in the user's collection
|
||||
# with the named graph urn:graph:retrieval (no separate collection needed)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
|
|||
|
|
@ -37,14 +37,14 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
chunk_id = emb.chunk_id
|
||||
if chunk_id == "":
|
||||
continue
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
vec = emb.vector
|
||||
if vec:
|
||||
self.vecstore.insert(
|
||||
vec, chunk,
|
||||
vec, chunk_id,
|
||||
message.metadata.user,
|
||||
message.metadata.collection
|
||||
)
|
||||
|
|
|
|||
|
|
@ -101,40 +101,41 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
|
||||
for emb in message.chunks:
|
||||
|
||||
if emb.chunk is None or emb.chunk == b"": continue
|
||||
chunk_id = emb.chunk_id
|
||||
if chunk_id == "":
|
||||
continue
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": continue
|
||||
vec = emb.vector
|
||||
if not vec:
|
||||
continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
# Create index name with dimension suffix for lazy creation
|
||||
dim = len(vec)
|
||||
index_name = (
|
||||
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
||||
)
|
||||
|
||||
# Create index name with dimension suffix for lazy creation
|
||||
dim = len(vec)
|
||||
index_name = (
|
||||
f"d-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
||||
)
|
||||
# Lazily create index if it doesn't exist (but only if authorized in config)
|
||||
if not self.pinecone.has_index(index_name):
|
||||
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
||||
self.create_index(index_name, dim)
|
||||
|
||||
# Lazily create index if it doesn't exist (but only if authorized in config)
|
||||
if not self.pinecone.has_index(index_name):
|
||||
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
||||
self.create_index(index_name, dim)
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
# Generate unique ID for each vector
|
||||
vector_id = str(uuid.uuid4())
|
||||
|
||||
# Generate unique ID for each vector
|
||||
vector_id = str(uuid.uuid4())
|
||||
records = [
|
||||
{
|
||||
"id": vector_id,
|
||||
"values": vec,
|
||||
"metadata": { "chunk_id": chunk_id },
|
||||
}
|
||||
]
|
||||
|
||||
records = [
|
||||
{
|
||||
"id": vector_id,
|
||||
"values": vec,
|
||||
"metadata": { "doc": chunk },
|
||||
}
|
||||
]
|
||||
|
||||
index.upsert(
|
||||
vectors = records,
|
||||
)
|
||||
index.upsert(
|
||||
vectors = records,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
|
|||
|
|
@ -52,41 +52,44 @@ class Processor(CollectionConfigHandler, DocumentEmbeddingsStoreService):
|
|||
|
||||
for emb in message.chunks:
|
||||
|
||||
chunk = emb.chunk.decode("utf-8")
|
||||
if chunk == "": return
|
||||
chunk_id = emb.chunk_id
|
||||
if chunk_id == "":
|
||||
continue
|
||||
|
||||
for vec in emb.vectors:
|
||||
vec = emb.vector
|
||||
if not vec:
|
||||
continue
|
||||
|
||||
# Create collection name with dimension suffix for lazy creation
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
||||
)
|
||||
# Create collection name with dimension suffix for lazy creation
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
f"d_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
||||
)
|
||||
|
||||
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
||||
if not self.qdrant.collection_exists(collection):
|
||||
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
||||
self.qdrant.create_collection(
|
||||
collection_name=collection,
|
||||
vectors_config=VectorParams(
|
||||
size=dim,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
|
||||
self.qdrant.upsert(
|
||||
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
||||
if not self.qdrant.collection_exists(collection):
|
||||
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
||||
self.qdrant.create_collection(
|
||||
collection_name=collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=vec,
|
||||
payload={
|
||||
"doc": chunk,
|
||||
}
|
||||
)
|
||||
]
|
||||
vectors_config=VectorParams(
|
||||
size=dim,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
|
||||
self.qdrant.upsert(
|
||||
collection_name=collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=vec,
|
||||
payload={
|
||||
"chunk_id": chunk_id,
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
|
|||
|
|
@ -53,11 +53,13 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
|||
entity_value = get_term_value(entity.entity)
|
||||
|
||||
if entity_value != "" and entity_value is not None:
|
||||
for vec in entity.vectors:
|
||||
vec = entity.vector
|
||||
if vec:
|
||||
self.vecstore.insert(
|
||||
vec, entity_value,
|
||||
message.metadata.user,
|
||||
message.metadata.collection
|
||||
message.metadata.collection,
|
||||
chunk_id=entity.chunk_id or "",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -119,35 +119,41 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
|||
if entity_value == "" or entity_value is None:
|
||||
continue
|
||||
|
||||
for vec in entity.vectors:
|
||||
vec = entity.vector
|
||||
if not vec:
|
||||
continue
|
||||
|
||||
# Create index name with dimension suffix for lazy creation
|
||||
dim = len(vec)
|
||||
index_name = (
|
||||
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
||||
)
|
||||
# Create index name with dimension suffix for lazy creation
|
||||
dim = len(vec)
|
||||
index_name = (
|
||||
f"t-{message.metadata.user}-{message.metadata.collection}-{dim}"
|
||||
)
|
||||
|
||||
# Lazily create index if it doesn't exist (but only if authorized in config)
|
||||
if not self.pinecone.has_index(index_name):
|
||||
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
||||
self.create_index(index_name, dim)
|
||||
# Lazily create index if it doesn't exist (but only if authorized in config)
|
||||
if not self.pinecone.has_index(index_name):
|
||||
logger.info(f"Lazily creating Pinecone index {index_name} with dimension {dim}")
|
||||
self.create_index(index_name, dim)
|
||||
|
||||
index = self.pinecone.Index(index_name)
|
||||
index = self.pinecone.Index(index_name)
|
||||
|
||||
# Generate unique ID for each vector
|
||||
vector_id = str(uuid.uuid4())
|
||||
# Generate unique ID for each vector
|
||||
vector_id = str(uuid.uuid4())
|
||||
|
||||
records = [
|
||||
{
|
||||
"id": vector_id,
|
||||
"values": vec,
|
||||
"metadata": { "entity": entity_value },
|
||||
}
|
||||
]
|
||||
metadata = {"entity": entity_value}
|
||||
if entity.chunk_id:
|
||||
metadata["chunk_id"] = entity.chunk_id
|
||||
|
||||
index.upsert(
|
||||
vectors = records,
|
||||
)
|
||||
records = [
|
||||
{
|
||||
"id": vector_id,
|
||||
"values": vec,
|
||||
"metadata": metadata,
|
||||
}
|
||||
]
|
||||
|
||||
index.upsert(
|
||||
vectors = records,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
|
|
|||
|
|
@ -71,38 +71,44 @@ class Processor(CollectionConfigHandler, GraphEmbeddingsStoreService):
|
|||
if entity_value == "" or entity_value is None:
|
||||
continue
|
||||
|
||||
for vec in entity.vectors:
|
||||
vec = entity.vector
|
||||
if not vec:
|
||||
continue
|
||||
|
||||
# Create collection name with dimension suffix for lazy creation
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
||||
)
|
||||
# Create collection name with dimension suffix for lazy creation
|
||||
dim = len(vec)
|
||||
collection = (
|
||||
f"t_{message.metadata.user}_{message.metadata.collection}_{dim}"
|
||||
)
|
||||
|
||||
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
||||
if not self.qdrant.collection_exists(collection):
|
||||
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
||||
self.qdrant.create_collection(
|
||||
collection_name=collection,
|
||||
vectors_config=VectorParams(
|
||||
size=dim,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
|
||||
self.qdrant.upsert(
|
||||
# Lazily create collection if it doesn't exist (but only if authorized in config)
|
||||
if not self.qdrant.collection_exists(collection):
|
||||
logger.info(f"Lazily creating Qdrant collection {collection} with dimension {dim}")
|
||||
self.qdrant.create_collection(
|
||||
collection_name=collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=vec,
|
||||
payload={
|
||||
"entity": entity_value,
|
||||
}
|
||||
)
|
||||
]
|
||||
vectors_config=VectorParams(
|
||||
size=dim,
|
||||
distance=Distance.COSINE
|
||||
)
|
||||
)
|
||||
|
||||
payload = {
|
||||
"entity": entity_value,
|
||||
}
|
||||
if entity.chunk_id:
|
||||
payload["chunk_id"] = entity.chunk_id
|
||||
|
||||
self.qdrant.upsert(
|
||||
collection_name=collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=vec,
|
||||
payload=payload,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
|
|
|
|||
|
|
@ -133,39 +133,38 @@ class Processor(CollectionConfigHandler, FlowProcessor):
|
|||
qdrant_collection = None
|
||||
|
||||
for row_emb in embeddings.embeddings:
|
||||
if not row_emb.vectors:
|
||||
vector = row_emb.vector
|
||||
if not vector:
|
||||
logger.warning(
|
||||
f"No vectors for index {row_emb.index_name} - skipping"
|
||||
f"No vector for index {row_emb.index_name} - skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Use first vector (there may be multiple from different models)
|
||||
for vector in row_emb.vectors:
|
||||
dimension = len(vector)
|
||||
dimension = len(vector)
|
||||
|
||||
# Create/get collection name (lazily on first vector)
|
||||
if qdrant_collection is None:
|
||||
qdrant_collection = self.get_collection_name(
|
||||
user, collection, schema_name, dimension
|
||||
)
|
||||
self.ensure_collection(qdrant_collection, dimension)
|
||||
|
||||
# Write to Qdrant
|
||||
self.qdrant.upsert(
|
||||
collection_name=qdrant_collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=vector,
|
||||
payload={
|
||||
"index_name": row_emb.index_name,
|
||||
"index_value": row_emb.index_value,
|
||||
"text": row_emb.text
|
||||
}
|
||||
)
|
||||
]
|
||||
# Create/get collection name (lazily on first vector)
|
||||
if qdrant_collection is None:
|
||||
qdrant_collection = self.get_collection_name(
|
||||
user, collection, schema_name, dimension
|
||||
)
|
||||
embeddings_written += 1
|
||||
self.ensure_collection(qdrant_collection, dimension)
|
||||
|
||||
# Write to Qdrant
|
||||
self.qdrant.upsert(
|
||||
collection_name=qdrant_collection,
|
||||
points=[
|
||||
PointStruct(
|
||||
id=str(uuid.uuid4()),
|
||||
vector=vector,
|
||||
payload={
|
||||
"index_name": row_emb.index_name,
|
||||
"index_value": row_emb.index_value,
|
||||
"text": row_emb.text
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
embeddings_written += 1
|
||||
|
||||
logger.info(f"Wrote {embeddings_written} embeddings to Qdrant")
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import os
|
|||
import argparse
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
|
||||
from .... direct.cassandra_kg import (
|
||||
EntityCentricKnowledgeGraph, DEFAULT_GRAPH
|
||||
|
|
@ -25,6 +26,37 @@ logger = logging.getLogger(__name__)
|
|||
default_ident = "triples-write"
|
||||
|
||||
|
||||
def serialize_triple(triple):
|
||||
"""Serialize a Triple object to JSON for storage."""
|
||||
if triple is None:
|
||||
return None
|
||||
|
||||
def term_to_dict(term):
|
||||
if term is None:
|
||||
return None
|
||||
|
||||
result = {"type": term.type}
|
||||
if term.type == IRI:
|
||||
result["iri"] = term.iri
|
||||
elif term.type == LITERAL:
|
||||
result["value"] = term.value
|
||||
if term.datatype:
|
||||
result["datatype"] = term.datatype
|
||||
if term.language:
|
||||
result["language"] = term.language
|
||||
elif term.type == BLANK:
|
||||
result["id"] = term.id
|
||||
elif term.type == TRIPLE:
|
||||
result["triple"] = serialize_triple(term.triple)
|
||||
return result
|
||||
|
||||
return json.dumps({
|
||||
"s": term_to_dict(triple.s),
|
||||
"p": term_to_dict(triple.p),
|
||||
"o": term_to_dict(triple.o),
|
||||
})
|
||||
|
||||
|
||||
def get_term_value(term):
|
||||
"""Extract the string value from a Term"""
|
||||
if term is None:
|
||||
|
|
@ -33,6 +65,9 @@ def get_term_value(term):
|
|||
return term.iri
|
||||
elif term.type == LITERAL:
|
||||
return term.value
|
||||
elif term.type == TRIPLE:
|
||||
# Serialize nested triple as JSON
|
||||
return serialize_triple(term.triple)
|
||||
else:
|
||||
# For blank nodes or other types, use id or value
|
||||
return term.id or term.value
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ class KnowledgeTableStore:
|
|||
entity_embeddings list<
|
||||
tuple<
|
||||
tuple<text, boolean>,
|
||||
list<list<double>>
|
||||
list<double>
|
||||
>
|
||||
>,
|
||||
PRIMARY KEY ((user, document_id), id)
|
||||
|
|
@ -140,7 +140,7 @@ class KnowledgeTableStore:
|
|||
chunks list<
|
||||
tuple<
|
||||
blob,
|
||||
list<list<double>>
|
||||
list<double>
|
||||
>
|
||||
>,
|
||||
PRIMARY KEY ((user, document_id), id)
|
||||
|
|
@ -218,16 +218,6 @@ class KnowledgeTableStore:
|
|||
|
||||
when = int(time.time() * 1000)
|
||||
|
||||
if m.metadata.metadata:
|
||||
metadata = [
|
||||
(
|
||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
||||
)
|
||||
for v in m.metadata.metadata
|
||||
]
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
triples = [
|
||||
(
|
||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
||||
|
|
@ -243,8 +233,8 @@ class KnowledgeTableStore:
|
|||
self.insert_triples_stmt,
|
||||
(
|
||||
uuid.uuid4(), m.metadata.user,
|
||||
m.metadata.id, when,
|
||||
metadata, triples,
|
||||
m.metadata.root or m.metadata.id, when,
|
||||
[], triples,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -259,20 +249,10 @@ class KnowledgeTableStore:
|
|||
|
||||
when = int(time.time() * 1000)
|
||||
|
||||
if m.metadata.metadata:
|
||||
metadata = [
|
||||
(
|
||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
||||
)
|
||||
for v in m.metadata.metadata
|
||||
]
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
entities = [
|
||||
(
|
||||
term_to_tuple(v.entity),
|
||||
v.vectors
|
||||
v.vector
|
||||
)
|
||||
for v in m.entities
|
||||
]
|
||||
|
|
@ -285,8 +265,8 @@ class KnowledgeTableStore:
|
|||
self.insert_graph_embeddings_stmt,
|
||||
(
|
||||
uuid.uuid4(), m.metadata.user,
|
||||
m.metadata.id, when,
|
||||
metadata, entities,
|
||||
m.metadata.root or m.metadata.id, when,
|
||||
[], entities,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -301,20 +281,10 @@ class KnowledgeTableStore:
|
|||
|
||||
when = int(time.time() * 1000)
|
||||
|
||||
if m.metadata.metadata:
|
||||
metadata = [
|
||||
(
|
||||
*term_to_tuple(v.s), *term_to_tuple(v.p), *term_to_tuple(v.o)
|
||||
)
|
||||
for v in m.metadata.metadata
|
||||
]
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
chunks = [
|
||||
(
|
||||
v.chunk,
|
||||
v.vectors,
|
||||
v.chunk_id,
|
||||
v.vector,
|
||||
)
|
||||
for v in m.chunks
|
||||
]
|
||||
|
|
@ -327,8 +297,8 @@ class KnowledgeTableStore:
|
|||
self.insert_document_embeddings_stmt,
|
||||
(
|
||||
uuid.uuid4(), m.metadata.user,
|
||||
m.metadata.id, when,
|
||||
metadata, chunks,
|
||||
m.metadata.root or m.metadata.id, when,
|
||||
[], chunks,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -423,18 +393,6 @@ class KnowledgeTableStore:
|
|||
|
||||
for row in resp:
|
||||
|
||||
if row[2]:
|
||||
metadata = [
|
||||
Triple(
|
||||
s = tuple_to_term(elt[0], elt[1]),
|
||||
p = tuple_to_term(elt[2], elt[3]),
|
||||
o = tuple_to_term(elt[4], elt[5]),
|
||||
)
|
||||
for elt in row[2]
|
||||
]
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
if row[3]:
|
||||
triples = [
|
||||
Triple(
|
||||
|
|
@ -453,7 +411,6 @@ class KnowledgeTableStore:
|
|||
id = document_id,
|
||||
user = user,
|
||||
collection = "default", # FIXME: What to put here?
|
||||
metadata = metadata,
|
||||
),
|
||||
triples = triples
|
||||
)
|
||||
|
|
@ -482,18 +439,6 @@ class KnowledgeTableStore:
|
|||
|
||||
for row in resp:
|
||||
|
||||
if row[2]:
|
||||
metadata = [
|
||||
Triple(
|
||||
s = tuple_to_term(elt[0], elt[1]),
|
||||
p = tuple_to_term(elt[2], elt[3]),
|
||||
o = tuple_to_term(elt[4], elt[5]),
|
||||
)
|
||||
for elt in row[2]
|
||||
]
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
if row[3]:
|
||||
entities = [
|
||||
EntityEmbeddings(
|
||||
|
|
@ -511,7 +456,6 @@ class KnowledgeTableStore:
|
|||
id = document_id,
|
||||
user = user,
|
||||
collection = "default", # FIXME: What to put here?
|
||||
metadata = metadata,
|
||||
),
|
||||
entities = entities
|
||||
)
|
||||
|
|
|
|||
|
|
@ -112,6 +112,34 @@ class LibraryTableStore:
|
|||
ON document (object_id)
|
||||
""");
|
||||
|
||||
# Add parent_id and document_type columns for child document support
|
||||
logger.debug("document table parent_id column...")
|
||||
|
||||
try:
|
||||
self.cassandra.execute("""
|
||||
ALTER TABLE document ADD parent_id text
|
||||
""");
|
||||
except Exception as e:
|
||||
# Column may already exist
|
||||
if "already exists" not in str(e).lower() and "Invalid column name" not in str(e):
|
||||
logger.debug(f"parent_id column may already exist: {e}")
|
||||
|
||||
try:
|
||||
self.cassandra.execute("""
|
||||
ALTER TABLE document ADD document_type text
|
||||
""");
|
||||
except Exception as e:
|
||||
# Column may already exist
|
||||
if "already exists" not in str(e).lower() and "Invalid column name" not in str(e):
|
||||
logger.debug(f"document_type column may already exist: {e}")
|
||||
|
||||
logger.debug("document parent index...")
|
||||
|
||||
self.cassandra.execute("""
|
||||
CREATE INDEX IF NOT EXISTS document_parent
|
||||
ON document (parent_id)
|
||||
""");
|
||||
|
||||
logger.debug("processing table...")
|
||||
|
||||
self.cassandra.execute("""
|
||||
|
|
@ -127,6 +155,32 @@ class LibraryTableStore:
|
|||
);
|
||||
""");
|
||||
|
||||
logger.debug("upload_session table...")
|
||||
|
||||
self.cassandra.execute("""
|
||||
CREATE TABLE IF NOT EXISTS upload_session (
|
||||
upload_id text PRIMARY KEY,
|
||||
user text,
|
||||
document_id text,
|
||||
document_metadata text,
|
||||
s3_upload_id text,
|
||||
object_id uuid,
|
||||
total_size bigint,
|
||||
chunk_size int,
|
||||
total_chunks int,
|
||||
chunks_received map<int, text>,
|
||||
created_at timestamp,
|
||||
updated_at timestamp
|
||||
) WITH default_time_to_live = 86400;
|
||||
""");
|
||||
|
||||
logger.debug("upload_session user index...")
|
||||
|
||||
self.cassandra.execute("""
|
||||
CREATE INDEX IF NOT EXISTS upload_session_user
|
||||
ON upload_session (user)
|
||||
""");
|
||||
|
||||
logger.info("Cassandra schema OK.")
|
||||
|
||||
def prepare_statements(self):
|
||||
|
|
@ -136,9 +190,10 @@ class LibraryTableStore:
|
|||
(
|
||||
id, user, time,
|
||||
kind, title, comments,
|
||||
metadata, tags, object_id
|
||||
metadata, tags, object_id,
|
||||
parent_id, document_type
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""")
|
||||
|
||||
self.update_document_stmt = self.cassandra.prepare("""
|
||||
|
|
@ -149,7 +204,8 @@ class LibraryTableStore:
|
|||
""")
|
||||
|
||||
self.get_document_stmt = self.cassandra.prepare("""
|
||||
SELECT time, kind, title, comments, metadata, tags, object_id
|
||||
SELECT time, kind, title, comments, metadata, tags, object_id,
|
||||
parent_id, document_type
|
||||
FROM document
|
||||
WHERE user = ? AND id = ?
|
||||
""")
|
||||
|
|
@ -168,14 +224,16 @@ class LibraryTableStore:
|
|||
|
||||
self.list_document_stmt = self.cassandra.prepare("""
|
||||
SELECT
|
||||
id, time, kind, title, comments, metadata, tags, object_id
|
||||
id, time, kind, title, comments, metadata, tags, object_id,
|
||||
parent_id, document_type
|
||||
FROM document
|
||||
WHERE user = ?
|
||||
""")
|
||||
|
||||
self.list_document_by_tag_stmt = self.cassandra.prepare("""
|
||||
SELECT
|
||||
id, time, kind, title, comments, metadata, tags, object_id
|
||||
id, time, kind, title, comments, metadata, tags, object_id,
|
||||
parent_id, document_type
|
||||
FROM document
|
||||
WHERE user = ? AND tags CONTAINS ?
|
||||
ALLOW FILTERING
|
||||
|
|
@ -210,6 +268,57 @@ class LibraryTableStore:
|
|||
WHERE user = ?
|
||||
""")
|
||||
|
||||
# Upload session prepared statements
|
||||
self.insert_upload_session_stmt = self.cassandra.prepare("""
|
||||
INSERT INTO upload_session
|
||||
(
|
||||
upload_id, user, document_id, document_metadata,
|
||||
s3_upload_id, object_id, total_size, chunk_size,
|
||||
total_chunks, chunks_received, created_at, updated_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""")
|
||||
|
||||
self.get_upload_session_stmt = self.cassandra.prepare("""
|
||||
SELECT
|
||||
upload_id, user, document_id, document_metadata,
|
||||
s3_upload_id, object_id, total_size, chunk_size,
|
||||
total_chunks, chunks_received, created_at, updated_at
|
||||
FROM upload_session
|
||||
WHERE upload_id = ?
|
||||
""")
|
||||
|
||||
self.update_upload_session_chunk_stmt = self.cassandra.prepare("""
|
||||
UPDATE upload_session
|
||||
SET chunks_received = chunks_received + ?,
|
||||
updated_at = ?
|
||||
WHERE upload_id = ?
|
||||
""")
|
||||
|
||||
self.delete_upload_session_stmt = self.cassandra.prepare("""
|
||||
DELETE FROM upload_session
|
||||
WHERE upload_id = ?
|
||||
""")
|
||||
|
||||
self.list_upload_sessions_stmt = self.cassandra.prepare("""
|
||||
SELECT
|
||||
upload_id, document_id, document_metadata,
|
||||
total_size, chunk_size, total_chunks,
|
||||
chunks_received, created_at, updated_at
|
||||
FROM upload_session
|
||||
WHERE user = ?
|
||||
""")
|
||||
|
||||
# Child document queries
|
||||
self.list_children_stmt = self.cassandra.prepare("""
|
||||
SELECT
|
||||
id, user, time, kind, title, comments, metadata, tags,
|
||||
object_id, parent_id, document_type
|
||||
FROM document
|
||||
WHERE parent_id = ?
|
||||
ALLOW FILTERING
|
||||
""")
|
||||
|
||||
async def document_exists(self, user, id):
|
||||
|
||||
resp = self.cassandra.execute(
|
||||
|
|
@ -236,6 +345,10 @@ class LibraryTableStore:
|
|||
for v in document.metadata
|
||||
]
|
||||
|
||||
# Get parent_id and document_type from document, defaulting if not set
|
||||
parent_id = getattr(document, 'parent_id', '') or ''
|
||||
document_type = getattr(document, 'document_type', 'source') or 'source'
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
|
|
@ -245,7 +358,8 @@ class LibraryTableStore:
|
|||
(
|
||||
document.id, document.user, int(document.time * 1000),
|
||||
document.kind, document.title, document.comments,
|
||||
metadata, document.tags, object_id
|
||||
metadata, document.tags, object_id,
|
||||
parent_id, document_type
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -349,9 +463,58 @@ class LibraryTableStore:
|
|||
p=tuple_to_term(m[2], m[3]),
|
||||
o=tuple_to_term(m[4], m[5])
|
||||
)
|
||||
for m in row[5]
|
||||
for m in (row[5] or [])
|
||||
],
|
||||
tags = row[6] if row[6] else [],
|
||||
parent_id = row[8] if row[8] else "",
|
||||
document_type = row[9] if row[9] else "source",
|
||||
)
|
||||
for row in resp
|
||||
]
|
||||
|
||||
logger.debug("Done")
|
||||
|
||||
return lst
|
||||
|
||||
async def list_children(self, parent_id):
|
||||
"""List all child documents for a given parent document ID."""
|
||||
|
||||
logger.debug(f"List children for parent {parent_id}")
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
|
||||
resp = self.cassandra.execute(
|
||||
self.list_children_stmt,
|
||||
(parent_id,)
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Exception occurred", exc_info=True)
|
||||
raise e
|
||||
|
||||
lst = [
|
||||
DocumentMetadata(
|
||||
id = row[0],
|
||||
user = row[1],
|
||||
time = int(time.mktime(row[2].timetuple())),
|
||||
kind = row[3],
|
||||
title = row[4],
|
||||
comments = row[5],
|
||||
metadata = [
|
||||
Triple(
|
||||
s=tuple_to_term(m[0], m[1]),
|
||||
p=tuple_to_term(m[2], m[3]),
|
||||
o=tuple_to_term(m[4], m[5])
|
||||
)
|
||||
for m in (row[6] or [])
|
||||
],
|
||||
tags = row[7] if row[7] else [],
|
||||
parent_id = row[9] if row[9] else "",
|
||||
document_type = row[10] if row[10] else "source",
|
||||
)
|
||||
for row in resp
|
||||
]
|
||||
|
|
@ -394,9 +557,11 @@ class LibraryTableStore:
|
|||
p=tuple_to_term(m[2], m[3]),
|
||||
o=tuple_to_term(m[4], m[5])
|
||||
)
|
||||
for m in row[4]
|
||||
for m in (row[4] or [])
|
||||
],
|
||||
tags = row[5] if row[5] else [],
|
||||
parent_id = row[7] if row[7] else "",
|
||||
document_type = row[8] if row[8] else "source",
|
||||
)
|
||||
|
||||
logger.debug("Done")
|
||||
|
|
@ -532,3 +697,152 @@ class LibraryTableStore:
|
|||
logger.debug("Done")
|
||||
|
||||
return lst
|
||||
|
||||
# Upload session methods
|
||||
|
||||
async def create_upload_session(
|
||||
self,
|
||||
upload_id,
|
||||
user,
|
||||
document_id,
|
||||
document_metadata,
|
||||
s3_upload_id,
|
||||
object_id,
|
||||
total_size,
|
||||
chunk_size,
|
||||
total_chunks,
|
||||
):
|
||||
"""Create a new upload session for chunked upload."""
|
||||
|
||||
logger.info(f"Creating upload session {upload_id}")
|
||||
|
||||
now = int(time.time() * 1000)
|
||||
|
||||
while True:
|
||||
try:
|
||||
self.cassandra.execute(
|
||||
self.insert_upload_session_stmt,
|
||||
(
|
||||
upload_id, user, document_id, document_metadata,
|
||||
s3_upload_id, object_id, total_size, chunk_size,
|
||||
total_chunks, {}, now, now
|
||||
)
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Exception occurred", exc_info=True)
|
||||
raise e
|
||||
|
||||
logger.debug("Upload session created")
|
||||
|
||||
async def get_upload_session(self, upload_id):
|
||||
"""Get an upload session by ID."""
|
||||
|
||||
logger.debug(f"Get upload session {upload_id}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
resp = self.cassandra.execute(
|
||||
self.get_upload_session_stmt,
|
||||
(upload_id,)
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Exception occurred", exc_info=True)
|
||||
raise e
|
||||
|
||||
for row in resp:
|
||||
session = {
|
||||
"upload_id": row[0],
|
||||
"user": row[1],
|
||||
"document_id": row[2],
|
||||
"document_metadata": row[3],
|
||||
"s3_upload_id": row[4],
|
||||
"object_id": row[5],
|
||||
"total_size": row[6],
|
||||
"chunk_size": row[7],
|
||||
"total_chunks": row[8],
|
||||
"chunks_received": row[9] if row[9] else {},
|
||||
"created_at": row[10],
|
||||
"updated_at": row[11],
|
||||
}
|
||||
logger.debug("Done")
|
||||
return session
|
||||
|
||||
return None
|
||||
|
||||
async def update_upload_session_chunk(self, upload_id, chunk_index, etag):
|
||||
"""Record a successfully uploaded chunk."""
|
||||
|
||||
logger.debug(f"Update upload session {upload_id} chunk {chunk_index}")
|
||||
|
||||
now = int(time.time() * 1000)
|
||||
|
||||
while True:
|
||||
try:
|
||||
self.cassandra.execute(
|
||||
self.update_upload_session_chunk_stmt,
|
||||
(
|
||||
{chunk_index: etag},
|
||||
now,
|
||||
upload_id
|
||||
)
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Exception occurred", exc_info=True)
|
||||
raise e
|
||||
|
||||
logger.debug("Chunk recorded")
|
||||
|
||||
async def delete_upload_session(self, upload_id):
|
||||
"""Delete an upload session."""
|
||||
|
||||
logger.info(f"Deleting upload session {upload_id}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
self.cassandra.execute(
|
||||
self.delete_upload_session_stmt,
|
||||
(upload_id,)
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Exception occurred", exc_info=True)
|
||||
raise e
|
||||
|
||||
logger.debug("Upload session deleted")
|
||||
|
||||
async def list_upload_sessions(self, user):
|
||||
"""List all upload sessions for a user."""
|
||||
|
||||
logger.debug(f"List upload sessions for {user}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
resp = self.cassandra.execute(
|
||||
self.list_upload_sessions_stmt,
|
||||
(user,)
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Exception occurred", exc_info=True)
|
||||
raise e
|
||||
|
||||
sessions = []
|
||||
for row in resp:
|
||||
chunks_received = row[6] if row[6] else {}
|
||||
sessions.append({
|
||||
"upload_id": row[0],
|
||||
"document_id": row[1],
|
||||
"document_metadata": row[2],
|
||||
"total_size": row[3],
|
||||
"chunk_size": row[4],
|
||||
"total_chunks": row[5],
|
||||
"chunks_received": len(chunks_received),
|
||||
"created_at": row[7],
|
||||
"updated_at": row[8],
|
||||
})
|
||||
|
||||
logger.debug("Done")
|
||||
return sessions
|
||||
|
|
|
|||
1
trustgraph-flow/trustgraph/tool_service/__init__.py
Normal file
1
trustgraph-flow/trustgraph/tool_service/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Tool service implementations
|
||||
2
trustgraph-flow/trustgraph/tool_service/joke/__init__.py
Normal file
2
trustgraph-flow/trustgraph/tool_service/joke/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
# Joke tool service
|
||||
from .service import run
|
||||
204
trustgraph-flow/trustgraph/tool_service/joke/service.py
Normal file
204
trustgraph-flow/trustgraph/tool_service/joke/service.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
"""
|
||||
Joke Tool Service - An example dynamic tool service.
|
||||
|
||||
This service demonstrates the tool service integration by:
|
||||
- Using the 'user' field to personalize responses
|
||||
- Using config params (style) to customize joke style
|
||||
- Using arguments (topic) to generate topic-specific jokes
|
||||
|
||||
Example tool-service config:
|
||||
{
|
||||
"id": "joke-service",
|
||||
"topic": "joke",
|
||||
"config-params": [
|
||||
{"name": "style", "required": false}
|
||||
]
|
||||
}
|
||||
|
||||
Example tool config:
|
||||
{
|
||||
"type": "tool-service",
|
||||
"name": "tell-joke",
|
||||
"description": "Tell a joke on a given topic",
|
||||
"service": "joke-service",
|
||||
"style": "pun",
|
||||
"arguments": [
|
||||
{
|
||||
"name": "topic",
|
||||
"type": "string",
|
||||
"description": "The topic for the joke (e.g., programming, animals, food)"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
import random
|
||||
import logging
|
||||
|
||||
from ... base import DynamicToolService
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_ident = "joke-service"
|
||||
default_topic = "joke"
|
||||
|
||||
# Joke database organized by topic and style
|
||||
JOKES = {
|
||||
"programming": {
|
||||
"pun": [
|
||||
"Why do programmers prefer dark mode? Because light attracts bugs!",
|
||||
"Why do Java developers wear glasses? Because they can't C#!",
|
||||
"A SQL query walks into a bar, walks up to two tables and asks... 'Can I join you?'",
|
||||
"Why was the JavaScript developer sad? Because he didn't Node how to Express himself!",
|
||||
],
|
||||
"dad-joke": [
|
||||
"I told my computer I needed a break, and now it won't stop sending me Kit-Kat ads.",
|
||||
"My son asked me to explain what a linked list is. I said 'I'll tell you, and then I'll tell you again, and again...'",
|
||||
"I asked my computer for a joke about UDP. I'm not sure if it got it.",
|
||||
],
|
||||
"one-liner": [
|
||||
"There are only 10 types of people: those who understand binary and those who don't.",
|
||||
"A programmer's wife tells him: 'Go to the store and get a loaf of bread. If they have eggs, get a dozen.' He returns with 12 loaves.",
|
||||
"99 little bugs in the code, 99 little bugs. Take one down, patch it around, 127 little bugs in the code.",
|
||||
],
|
||||
},
|
||||
"llama": {
|
||||
"pun": [
|
||||
"Why did the llama get a ticket? Because he was caught spitting in a no-spitting zone!",
|
||||
"What do you call a llama who's a great musician? A llama del Rey!",
|
||||
"Why did the llama cross the road? To prove he wasn't a chicken!",
|
||||
],
|
||||
"dad-joke": [
|
||||
"What did the llama say when he got kicked out of the zoo? 'Alpaca my bags!'",
|
||||
"Why don't llamas ever get lost? Because they always know the way to the Andes!",
|
||||
"What do you call a llama with no legs? A woolly rug!",
|
||||
],
|
||||
"one-liner": [
|
||||
"Llamas are great at meditation. They're always saying 'Dalai Llama.'",
|
||||
"I asked a llama for directions. He said 'No probllama!'",
|
||||
"Never trust a llama. They're always up to something woolly.",
|
||||
],
|
||||
},
|
||||
"animals": {
|
||||
"pun": [
|
||||
"What do you call a fish without eyes? A fsh!",
|
||||
"Why don't scientists trust atoms? Because they make up everything... just like that cat who blamed the dog!",
|
||||
"What do you call a bear with no teeth? A gummy bear!",
|
||||
],
|
||||
"dad-joke": [
|
||||
"I tried to catch some fog earlier. I mist. My dog wasn't impressed either.",
|
||||
"What do you call a dog that does magic tricks? A Labracadabrador!",
|
||||
"Why do cows wear bells? Because their horns don't work!",
|
||||
],
|
||||
"one-liner": [
|
||||
"I'm reading a book about anti-gravity. It's impossible to put down, unlike my cat.",
|
||||
"A horse walks into a bar. The bartender asks 'Why the long face?'",
|
||||
"What's orange and sounds like a parrot? A carrot!",
|
||||
],
|
||||
},
|
||||
"food": {
|
||||
"pun": [
|
||||
"I'm on a seafood diet. I see food and I eat it!",
|
||||
"Why did the tomato turn red? Because it saw the salad dressing!",
|
||||
"What do you call cheese that isn't yours? Nacho cheese!",
|
||||
],
|
||||
"dad-joke": [
|
||||
"I used to hate facial hair, but then it grew on me. Speaking of growing, have you tried my garden salad?",
|
||||
"Why don't eggs tell jokes? They'd crack each other up!",
|
||||
"I told my wife she was drawing her eyebrows too high. She looked surprised, then made me a sandwich.",
|
||||
],
|
||||
"one-liner": [
|
||||
"I'm reading a book about submarines and sandwiches. It's a sub-genre.",
|
||||
"Broken puppets for sale. No strings attached. Also, free spaghetti!",
|
||||
"I ordered a chicken and an egg online. I'll let you know which comes first.",
|
||||
],
|
||||
},
|
||||
"default": {
|
||||
"pun": [
|
||||
"Time flies like an arrow. Fruit flies like a banana!",
|
||||
"I used to be a banker, but I lost interest.",
|
||||
"I'm reading a book on the history of glue. I can't put it down!",
|
||||
],
|
||||
"dad-joke": [
|
||||
"I'm afraid for the calendar. Its days are numbered.",
|
||||
"I only know 25 letters of the alphabet. I don't know y.",
|
||||
"Did you hear about the claustrophobic astronaut? He just needed a little space.",
|
||||
],
|
||||
"one-liner": [
|
||||
"I told my wife she was drawing her eyebrows too high. She looked surprised.",
|
||||
"I'm not lazy, I'm on energy-saving mode.",
|
||||
"Parallel lines have so much in common. It's a shame they'll never meet.",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class Processor(DynamicToolService):
|
||||
"""
|
||||
Joke tool service that demonstrates the tool service integration.
|
||||
"""
|
||||
|
||||
def __init__(self, **params):
|
||||
super(Processor, self).__init__(**params)
|
||||
logger.info("Joke service initialized")
|
||||
|
||||
async def invoke(self, user, config, arguments):
|
||||
"""
|
||||
Generate a joke based on the topic and style.
|
||||
|
||||
Args:
|
||||
user: The user requesting the joke
|
||||
config: Config values including 'style' (pun, dad-joke, one-liner)
|
||||
arguments: Arguments including 'topic' (programming, animals, food)
|
||||
|
||||
Returns:
|
||||
A personalized joke string
|
||||
"""
|
||||
# Get style from config (default: random)
|
||||
style = config.get("style", random.choice(["pun", "dad-joke", "one-liner"]))
|
||||
|
||||
# Get topic from arguments (default: random)
|
||||
topic = arguments.get("topic", "").lower()
|
||||
|
||||
# Map topic to our categories
|
||||
if "program" in topic or "code" in topic or "computer" in topic or "software" in topic:
|
||||
category = "programming"
|
||||
elif "llama" in topic:
|
||||
category = "llama"
|
||||
elif "animal" in topic or "dog" in topic or "cat" in topic or "bird" in topic:
|
||||
category = "animals"
|
||||
elif "food" in topic or "eat" in topic or "cook" in topic or "drink" in topic:
|
||||
category = "food"
|
||||
else:
|
||||
category = "default"
|
||||
|
||||
# Normalize style
|
||||
if style not in ["pun", "dad-joke", "one-liner"]:
|
||||
style = random.choice(["pun", "dad-joke", "one-liner"])
|
||||
|
||||
# Get jokes for this category and style
|
||||
jokes = JOKES.get(category, JOKES["default"]).get(style, JOKES["default"]["pun"])
|
||||
|
||||
# Pick a random joke
|
||||
joke = random.choice(jokes)
|
||||
|
||||
# Personalize the response
|
||||
response = f"Hey {user}! Here's a {style} for you:\n\n{joke}"
|
||||
|
||||
logger.debug(f"Generated joke for user={user}, style={style}, topic={topic}")
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
DynamicToolService.add_args(parser)
|
||||
# Override the topic default for this service
|
||||
for action in parser._actions:
|
||||
if '--topic' in action.option_strings:
|
||||
action.default = default_topic
|
||||
break
|
||||
|
||||
|
||||
def run():
|
||||
Processor.launch(default_ident, __doc__)
|
||||
Loading…
Add table
Add a link
Reference in a new issue