Per-flow librarian clients and per-workspace response queues (#865)

Replace singleton LibrarianClient with per-flow instances via the new
LibrarianSpec, giving each flow its own librarian tied to the
workspace-scoped request/response queues from the blueprint.
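Concretely, the pattern change in each processor looks like this (a sketch
assembled from the diffs below, not new code):

    # Before (removed below): one client shared by every flow
    #     self.librarian = LibrarianClient(
    #         id=id, backend=self.pubsub, taskgroup=self.taskgroup,
    #     )
    # After: the processor declares the dependency once, in __init__...
    self.register_specification(
        LibrarianSpec()
    )
    # ...and each handler uses the client belonging to its own flow:
    async def on_message(self, msg, consumer, flow):
        v = msg.value()
        text = await flow.librarian.fetch_document_text(
            document_id=v.document_id,
        )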

Move all workspace-scoped services (config, flow, librarian, knowledge)
from a single base-queue response producer to per-workspace response
producers created alongside the existing per-workspace request
consumers.  Update the gateway dispatcher and bootstrapper flow client
to subscribe to the matching workspace-scoped response queues.
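The pairing convention, as a sketch: workspace_queue is assumed to compose
"<base>:<workspace>" names, matching the gateway's
response_topic=f"{flow_response_queue}:{workspace}" change below. The
queue-base attribute names here are illustrative stand-ins:

    # Per-workspace wiring now used by config-svc, flow-svc and
    # knowledge-svc (see their diffs below):
    req_queue = workspace_queue(self.request_queue_base, workspace)    # requests in
    resp_queue = workspace_queue(self.response_queue_base, workspace)  # responses out
    self.workspace_consumers[workspace] = {
        "consumer": consumer,            # subscribed to req_queue
        "response": response_producer,   # publishes to resp_queue
    }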

Fix WorkspaceInit to register workspaces through the IAM
create-workspace API so they appear in __workspaces__ and are visible
to the gateway.  Simplify the bootstrapper gate to only check
config-svc reachability.
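The registration call WorkspaceInit now issues looks like this (sketch taken
from the _create_workspace diff below; a "duplicate" error is treated as
already-registered):

    resp = await iam.request(
        IamRequest(
            operation="create-workspace",
            workspace_record=WorkspaceInput(
                id=self.workspace,
                name=self.workspace.title(),
                enabled=True,
            ),
        ),
        timeout=10,
    )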

Updated tests accordingly.
cybermaggedon 2026-05-06 12:01:01 +01:00, committed by GitHub
parent 01bf1d89d5, commit 03cc5ac80f
30 changed files with 405 additions and 735 deletions


@@ -177,8 +177,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
         processor = Processor(**config)
-        # Mock save_child_document to avoid waiting for librarian response
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
+        # Mock save_child_document on flow to avoid waiting for librarian response
         # Mock message with TextDocument
         mock_message = MagicMock()

@@ -204,6 +203,7 @@ class TestRecursiveChunkerSimple(IsolatedAsyncioTestCase):
             "output": mock_producer,
             "triples": mock_triples_producer,
         }.get(key)
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
         # Act
         await processor.on_message(mock_message, mock_consumer, mock_flow)


@@ -177,8 +177,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
         processor = Processor(**config)
-        # Mock save_child_document to avoid librarian producer interactions
-        processor.librarian.save_child_document = AsyncMock(return_value="chunk-id")
+        # Mock save_child_document on flow to avoid librarian producer interactions
         # Mock message with TextDocument
         mock_message = MagicMock()

@@ -204,6 +203,7 @@ class TestTokenChunkerSimple(IsolatedAsyncioTestCase):
             "output": mock_producer,
             "triples": mock_triples_producer,
         }.get(key)
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="chunk-id")
         # Act
         await processor.on_message(mock_message, mock_consumer, mock_flow)


@@ -156,6 +156,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
             "output": mock_output_flow,
             "triples": mock_triples_flow,
         }.get(name))
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
         config = {
             'id': 'test-mistral-ocr',

@@ -171,9 +172,6 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
             ("# Page 2\nMore content", 2),
         ]
-        # Mock save_child_document
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
         with patch.object(processor, 'ocr', return_value=ocr_result):
             await processor.on_message(mock_msg, None, mock_flow)

@@ -227,8 +225,7 @@ class TestMistralOcrProcessor(IsolatedAsyncioTestCase):
         Processor.add_args(mock_parser)
         mock_parent_add_args.assert_called_once_with(mock_parser)
-        assert mock_parser.add_argument.call_count == 3
-        # Check the API key arg is among them
+        assert mock_parser.add_argument.call_count == 1
         call_args_list = [c[0] for c in mock_parser.add_argument.call_args_list]
         assert ('-k', '--api-key') in call_args_list


@@ -72,6 +72,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
             "output": mock_output_flow,
             "triples": mock_triples_flow,
         }.get(name))
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
         config = {
             'id': 'test-pdf-decoder',

@@ -80,9 +81,6 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
         processor = Processor(**config)
-        # Mock save_child_document to avoid waiting for librarian response
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
         await processor.on_message(mock_msg, None, mock_flow)
         # Verify output was sent for each page

@@ -148,6 +146,7 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
             "output": mock_output_flow,
             "triples": mock_triples_flow,
         }.get(name))
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
         config = {
             'id': 'test-pdf-decoder',

@@ -156,9 +155,6 @@ class TestPdfDecoderProcessor(IsolatedAsyncioTestCase):
         processor = Processor(**config)
-        # Mock save_child_document to avoid waiting for librarian response
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-doc-id")
         await processor.on_message(mock_msg, None, mock_flow)
         mock_output_flow.send.assert_called_once()


@@ -254,8 +254,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
             "triples": mock_triples_flow,
         }.get(name))
-        # Mock save_child_document and magic
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-id")
         with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
             mock_magic.from_buffer.return_value = "text/markdown"

@@ -310,7 +309,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
             "triples": mock_triples_flow,
         }.get(name))
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-id")
         with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
             mock_magic.from_buffer.return_value = "application/pdf"

@@ -361,7 +360,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
             "triples": mock_triples_flow,
         }.get(name))
-        processor.librarian.save_child_document = AsyncMock(return_value="mock-id")
+        mock_flow.librarian.save_child_document = AsyncMock(return_value="mock-id")
         with patch('trustgraph.decoding.universal.processor.magic') as mock_magic:
             mock_magic.from_buffer.return_value = "application/pdf"

@@ -374,7 +373,7 @@ class TestUniversalProcessor(IsolatedAsyncioTestCase):
         assert mock_triples_flow.send.call_count == 2
         # save_child_document called twice (page + image)
-        assert processor.librarian.save_child_document.call_count == 2
+        assert mock_flow.librarian.save_child_document.call_count == 2

     @patch('trustgraph.base.flow_processor.FlowProcessor.add_args')
     def test_add_args(self, mock_parent_add_args):


@@ -16,6 +16,7 @@ from . subscriber_spec import SubscriberSpec
 from . request_response_spec import RequestResponseSpec
 from . llm_service import LlmService, LlmResult, LlmChunk
 from . librarian_client import LibrarianClient
+from . librarian_spec import LibrarianSpec
 from . chunking_service import ChunkingService
 from . embeddings_service import EmbeddingsService
 from . embeddings_client import EmbeddingsClientSpec


@@ -4,13 +4,11 @@ for chunk-size and chunk-overlap parameters, and librarian client for
 fetching large document content.
 """
-import asyncio
-import base64
 import logging
 from .flow_processor import FlowProcessor
 from .parameter_spec import ParameterSpec
-from .librarian_client import LibrarianClient
+from .librarian_spec import LibrarianSpec

 # Module logger
 logger = logging.getLogger(__name__)

@@ -35,35 +33,27 @@ class ChunkingService(FlowProcessor):
             ParameterSpec(name="chunk-overlap")
         )

-        # Librarian client
-        self.librarian = LibrarianClient(
-            id=id,
-            backend=self.pubsub,
-            taskgroup=self.taskgroup,
+        self.register_specification(
+            LibrarianSpec()
         )

         logger.debug("ChunkingService initialized with parameter specifications")

-    async def start(self):
-        await super(ChunkingService, self).start()
-        await self.librarian.start()
-
-    async def get_document_text(self, doc, workspace):
+    async def get_document_text(self, doc, flow):
         """
         Get text content from a TextDocument, fetching from librarian if needed.
         Args:
             doc: TextDocument with either inline text or document_id
-            workspace: Workspace for librarian lookup (from flow.workspace)
+            flow: Flow object with librarian client
         Returns:
             str: The document text content
         """
         if doc.document_id and not doc.text:
             logger.info(f"Fetching document {doc.document_id} from librarian...")
-            text = await self.librarian.fetch_document_text(
+            text = await flow.librarian.fetch_document_text(
                 document_id=doc.document_id,
-                workspace=workspace,
             )
             logger.info(f"Fetched {len(text)} characters from librarian")
             return text


@@ -1,6 +1,4 @@
-import asyncio
-
 class Flow:
     """
     Runtime representation of a deployed flow process.

@@ -22,16 +20,22 @@ class Flow:
         self.parameter = {}

+        self.librarian = None
+
         for spec in processor.specifications:
             spec.add(self, processor, defn)

     async def start(self):
+        if self.librarian:
+            await self.librarian.start()
         for c in self.consumer.values():
             await c.start()

     async def stop(self):
         for c in self.consumer.values():
             await c.stop()
+        if self.librarian:
+            await self.librarian.stop()

     def __call__(self, key):
         if key in self.producer: return self.producer[key]


@@ -10,7 +10,7 @@ Usage:
         id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
     )
     await self.librarian.start()
-    content = await self.librarian.fetch_document_content(doc_id, workspace)
+    content = await self.librarian.fetch_document_content(doc_id)
 """

 import asyncio

@@ -39,9 +39,14 @@ class LibrarianClient:
         librarian_response_q = params.get(
             "librarian_response_queue", librarian_response_queue,
         )
+        subscriber = params.get(
+            "librarian_subscriber", f"{id}-librarian",
+        )
+        flow_name = params.get("flow_name")

         librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request",
+            processor=id, flow=flow_name, name="librarian-request",
         )

         self._producer = Producer(

@@ -52,7 +57,7 @@ class LibrarianClient:
         )

         librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response",
+            processor=id, flow=flow_name, name="librarian-response",
         )

         self._consumer = Consumer(

@@ -60,7 +65,7 @@ class LibrarianClient:
             backend=backend,
             flow=None,
             topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
+            subscriber=subscriber,
             schema=LibrarianResponse,
             handler=self._on_response,
             metrics=librarian_response_metrics,

@@ -76,6 +81,11 @@ class LibrarianClient:
         await self._producer.start()
         await self._consumer.start()

+    async def stop(self):
+        """Stop the librarian producer and consumer."""
+        await self._consumer.stop()
+        await self._producer.stop()
+
     async def _on_response(self, msg, consumer, flow):
         """Route librarian responses to the right waiter."""
         response = msg.value()

@@ -150,7 +160,7 @@ class LibrarianClient:
         finally:
             self._streams.pop(request_id, None)

-    async def fetch_document_content(self, document_id, workspace, timeout=120):
+    async def fetch_document_content(self, document_id, timeout=120):
         """Fetch document content using streaming.

         Returns base64-encoded content. Caller is responsible for decoding.

@@ -158,7 +168,6 @@ class LibrarianClient:
         req = LibrarianRequest(
             operation="stream-document",
             document_id=document_id,
-            workspace=workspace,
         )
         chunks = await self.stream(req, timeout=timeout)

@@ -176,24 +185,23 @@ class LibrarianClient:
         return base64.b64encode(raw)

-    async def fetch_document_text(self, document_id, workspace, timeout=120):
+    async def fetch_document_text(self, document_id, timeout=120):
         """Fetch document content and decode as UTF-8 text."""
         content = await self.fetch_document_content(
-            document_id, workspace, timeout=timeout,
+            document_id, timeout=timeout,
         )
         return base64.b64decode(content).decode("utf-8")

-    async def fetch_document_metadata(self, document_id, workspace, timeout=120):
+    async def fetch_document_metadata(self, document_id, timeout=120):
         """Fetch document metadata from the librarian."""
         req = LibrarianRequest(
             operation="get-document-metadata",
             document_id=document_id,
-            workspace=workspace,
         )
         response = await self.request(req, timeout=timeout)
         return response.document_metadata

-    async def save_child_document(self, doc_id, parent_id, workspace, content,
+    async def save_child_document(self, doc_id, parent_id, content,
                                   document_type="chunk", title=None,
                                   kind="text/plain", timeout=120):
         """Save a child document to the librarian."""

@@ -217,7 +225,7 @@ class LibrarianClient:
         await self.request(req, timeout=timeout)
         return doc_id

-    async def save_document(self, doc_id, workspace, content, title=None,
+    async def save_document(self, doc_id, content, title=None,
                             document_type="answer", kind="text/plain",
                             timeout=120):
         """Save a document to the librarian."""

@@ -236,7 +244,6 @@ class LibrarianClient:
             document_id=doc_id,
             document_metadata=doc_metadata,
             content=base64.b64encode(content).decode("utf-8"),
-            workspace=workspace,
         )
         await self.request(req, timeout=timeout)


@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+import uuid
+from typing import Any
+
+from . spec import Spec
+from . librarian_client import LibrarianClient
+
+
+class LibrarianSpec(Spec):
+
+    def __init__(self, request_name="librarian-request",
+                 response_name="librarian-response"):
+        self.request_name = request_name
+        self.response_name = response_name
+
+    def add(self, flow: Any, processor: Any, definition: dict[str, Any]) -> None:
+        client = LibrarianClient(
+            id=flow.id,
+            backend=processor.pubsub,
+            taskgroup=processor.taskgroup,
+            librarian_request_queue=definition["topics"][self.request_name],
+            librarian_response_queue=definition["topics"][self.response_name],
+            librarian_subscriber=(
+                processor.id + "--" + flow.workspace + "--" +
+                flow.name + "--librarian--" + str(uuid.uuid4())
+            ),
+            flow_name=flow.name,
+        )
+
+        flow.librarian = client


@@ -61,6 +61,10 @@ class FlowContext:
     def __call__(self, service_name):
         return self._flow(service_name)

+    @property
+    def librarian(self):
+        return self._flow.librarian
+
 class UsageTracker:
     """Accumulates token usage across multiple prompt calls."""

@@ -320,9 +324,9 @@ class PatternBase:
            f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
         )
         try:
-            await self.processor.save_answer_content(
+            await flow.librarian.save_document(
                 doc_id=thought_doc_id,
-                workspace=flow.workspace,
                 content=act.thought,
                 title=f"Agent Thought: {act.name}",
             )

@@ -389,9 +393,9 @@ class PatternBase:
            f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
         )
         try:
-            await self.processor.save_answer_content(
+            await flow.librarian.save_document(
                 doc_id=observation_doc_id,
-                workspace=flow.workspace,
                 content=observation_text,
                 title=f"Agent Observation",
             )

@@ -445,9 +449,9 @@ class PatternBase:
         if answer_text:
             answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
             try:
-                await self.processor.save_answer_content(
+                await flow.librarian.save_document(
                     doc_id=answer_doc_id,
-                    workspace=flow.workspace,
                     content=answer_text,
                     title=f"Agent Answer: {request.question[:50]}...",
                 )

@@ -521,8 +525,8 @@ class PatternBase:
         doc_id = f"urn:trustgraph:agent:{session_id}/finding/{index}/doc"
         try:
-            await self.processor.save_answer_content(
-                doc_id=doc_id, workspace=flow.workspace,
+            await flow.librarian.save_document(
+                doc_id=doc_id,
                 content=answer_text,
                 title=f"Finding: {goal[:60]}",
             )

@@ -574,8 +578,8 @@ class PatternBase:
         doc_id = f"urn:trustgraph:agent:{session_id}/step/{index}/doc"
         try:
-            await self.processor.save_answer_content(
-                doc_id=doc_id, workspace=flow.workspace,
+            await flow.librarian.save_document(
+                doc_id=doc_id,
                 content=answer_text,
                 title=f"Step result: {goal[:60]}",
             )

@@ -606,8 +610,8 @@ class PatternBase:
         doc_id = f"urn:trustgraph:agent:{session_id}/synthesis/doc"
         try:
-            await self.processor.save_answer_content(
-                doc_id=doc_id, workspace=flow.workspace,
+            await flow.librarian.save_document(
+                doc_id=doc_id,
                 content=answer_text,
                 title="Synthesis",
             )


@@ -7,26 +7,17 @@ to select between ReactPattern, PlanThenExecutePattern, and
 SupervisorPattern at runtime.
 """

-import asyncio
-import base64
 import json
 import functools
 import logging
-import uuid
-from datetime import datetime

 from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
 from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
 from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
-from ... base import ProducerSpec
-from ... base import Consumer, Producer
-from ... base import ConsumerMetrics, ProducerMetrics
+from ... base import ProducerSpec, LibrarianSpec
 from ... schema import AgentRequest, AgentResponse, AgentStep, Error
 from ..orchestrator.pattern_base import UsageTracker, PatternBase
 from ... schema import Triples, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
-from ... schema import librarian_request_queue, librarian_response_queue

 from trustgraph.provenance import (
     agent_session_uri,

@@ -52,8 +43,6 @@ logger = logging.getLogger(__name__)

 default_ident = "agent-manager"
 default_max_iterations = 10
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue

 class Processor(AgentService):

@@ -151,94 +140,9 @@ class Processor(AgentService):
             )
         )

-        # Librarian client
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend=self.pubsub,
-            topic=librarian_request_q,
-            schema=LibrarianRequest,
-            metrics=librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup=self.taskgroup,
-            backend=self.pubsub,
-            flow=None,
-            topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
-            schema=LibrarianResponse,
-            handler=self.on_librarian_response,
-            metrics=librarian_response_metrics,
-        )
-
-        self.pending_librarian_requests = {}
-
-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
-
-    async def on_librarian_response(self, msg, consumer, flow):
-        response = msg.value()
-        request_id = msg.properties().get("id")
-        if request_id in self.pending_librarian_requests:
-            future = self.pending_librarian_requests.pop(request_id)
-            future.set_result(response)
-
-    async def save_answer_content(self, doc_id, workspace, content, title=None,
-                                  timeout=120):
-        request_id = str(uuid.uuid4())
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            workspace=workspace,
-            kind="text/plain",
-            title=title or "Agent Answer",
-            document_type="answer",
-        )
-        request = LibrarianRequest(
-            operation="add-document",
-            document_id=doc_id,
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
-            workspace=workspace,
-        )
-        future = asyncio.get_event_loop().create_future()
-        self.pending_librarian_requests[request_id] = future
-        try:
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-            response = await asyncio.wait_for(future, timeout=timeout)
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving answer: "
-                    f"{response.error.type}: {response.error.message}"
-                )
-            return doc_id
-        except asyncio.TimeoutError:
-            self.pending_librarian_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving answer document {doc_id}")
+        self.register_specification(
+            LibrarianSpec()
+        )

     def provenance_session_uri(self, session_id):
         return agent_session_uri(session_id)


@@ -3,7 +3,6 @@ Simple agent infrastructure broadly implements the ReAct flow.
 """

 import asyncio
-import base64
 import json
 import re
 import sys

@@ -19,14 +18,10 @@ logger = logging.getLogger(__name__)

 from ... base import AgentService, TextCompletionClientSpec, PromptClientSpec
 from ... base import GraphRagClientSpec, ToolClientSpec, StructuredQueryClientSpec
 from ... base import RowEmbeddingsQueryClientSpec, EmbeddingsClientSpec
-from ... base import ProducerSpec
-from ... base import Consumer, Producer
-from ... base import ConsumerMetrics, ProducerMetrics
+from ... base import ProducerSpec, LibrarianSpec
 from ... schema import AgentRequest, AgentResponse, AgentStep, Error
 from ... schema import Triples, Metadata
-from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
-from ... schema import librarian_request_queue, librarian_response_queue

 # Provenance imports for agent explainability
 from trustgraph.provenance import (

@@ -51,8 +46,6 @@ from . types import Final, Action, Tool, Argument

 default_ident = "agent-manager"
 default_max_iterations = 10
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue

 class Processor(AgentService):

@@ -141,112 +134,9 @@ class Processor(AgentService):
             )
         )

-        # Librarian client for storing answer content
-        librarian_request_q = params.get(
-            "librarian_request_queue", default_librarian_request_queue
-        )
-        librarian_response_q = params.get(
-            "librarian_response_queue", default_librarian_response_queue
-        )
-
-        librarian_request_metrics = ProducerMetrics(
-            processor=id, flow=None, name="librarian-request"
-        )
-
-        self.librarian_request_producer = Producer(
-            backend=self.pubsub,
-            topic=librarian_request_q,
-            schema=LibrarianRequest,
-            metrics=librarian_request_metrics,
-        )
-
-        librarian_response_metrics = ConsumerMetrics(
-            processor=id, flow=None, name="librarian-response"
-        )
-
-        self.librarian_response_consumer = Consumer(
-            taskgroup=self.taskgroup,
-            backend=self.pubsub,
-            flow=None,
-            topic=librarian_response_q,
-            subscriber=f"{id}-librarian",
-            schema=LibrarianResponse,
-            handler=self.on_librarian_response,
-            metrics=librarian_response_metrics,
-        )
-
-        # Pending librarian requests: request_id -> asyncio.Future
-        self.pending_librarian_requests = {}
-
-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian_request_producer.start()
-        await self.librarian_response_consumer.start()
-
-    async def on_librarian_response(self, msg, consumer, flow):
-        """Handle responses from the librarian service."""
-        response = msg.value()
-        request_id = msg.properties().get("id")
-        if request_id in self.pending_librarian_requests:
-            future = self.pending_librarian_requests.pop(request_id)
-            future.set_result(response)
-
-    async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
-        """
-        Save answer content to the librarian.
-
-        Args:
-            doc_id: ID for the answer document
-            workspace: Workspace for isolation
-            content: Answer text content
-            title: Optional title
-            timeout: Request timeout in seconds
-
-        Returns:
-            The document ID on success
-        """
-        request_id = str(uuid.uuid4())
-        doc_metadata = DocumentMetadata(
-            id=doc_id,
-            workspace=workspace,
-            kind="text/plain",
-            title=title or "Agent Answer",
-            document_type="answer",
-        )
-        request = LibrarianRequest(
-            operation="add-document",
-            document_id=doc_id,
-            document_metadata=doc_metadata,
-            content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
-            workspace=workspace,
-        )
-
-        # Create future for response
-        future = asyncio.get_event_loop().create_future()
-        self.pending_librarian_requests[request_id] = future
-
-        try:
-            # Send request
-            await self.librarian_request_producer.send(
-                request, properties={"id": request_id}
-            )
-            # Wait for response
-            response = await asyncio.wait_for(future, timeout=timeout)
-            if response.error:
-                raise RuntimeError(
-                    f"Librarian error saving answer: {response.error.type}: {response.error.message}"
-                )
-            return doc_id
-        except asyncio.TimeoutError:
-            self.pending_librarian_requests.pop(request_id, None)
-            raise RuntimeError(f"Timeout saving answer document {doc_id}")
+        self.register_specification(
+            LibrarianSpec()
+        )

     async def on_tools_config(self, workspace, config, version):

@@ -611,9 +501,9 @@ class Processor(AgentService):
         if act_decision.thought:
             t_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/thought"
             try:
-                await self.save_answer_content(
+                await flow.librarian.save_document(
                     doc_id=t_doc_id,
-                    workspace=flow.workspace,
                     content=act_decision.thought,
                     title=f"Agent Thought: {act_decision.name}",
                 )

@@ -691,9 +581,9 @@ class Processor(AgentService):
         if f:
             answer_doc_id = f"urn:trustgraph:agent:{session_id}/answer"
             try:
-                await self.save_answer_content(
+                await flow.librarian.save_document(
                     doc_id=answer_doc_id,
-                    workspace=flow.workspace,
                     content=f,
                     title=f"Agent Answer: {request.question[:50]}...",
                 )

@@ -768,9 +658,8 @@ class Processor(AgentService):
         if act.observation:
             observation_doc_id = f"urn:trustgraph:agent:{session_id}/i{iteration_num}/observation"
             try:
-                await self.save_answer_content(
+                await flow.librarian.save_document(
                     doc_id=observation_doc_id,
-                    workspace=flow.workspace,
                     content=act.observation,
                     title=f"Agent Observation",
                 )


@@ -22,6 +22,7 @@ class InitContext:
     logger: logging.Logger
     config: Any  # ConfigClient
     make_flow_client: Any  # callable(workspace) -> RequestResponse
+    make_iam_client: Any  # callable() -> RequestResponse

 class Initialiser:

@@ -35,7 +36,7 @@ class Initialiser:
     * ``wait_for_services`` (bool, default ``True``): when ``True`` the
       initialiser only runs after the bootstrapper's service gate has
-      passed (config-svc and flow-svc reachable). Set ``False`` for
+      passed (config-svc reachable). Set ``False`` for
       initialisers that bring up infrastructure the gate itself
       depends on: principally Pulsar topology, without which
       config-svc cannot come online.


@@ -28,6 +28,10 @@ from trustgraph.schema import (
     FlowRequest, FlowResponse,
     flow_request_queue, flow_response_queue,
 )
+from trustgraph.schema import (
+    IamRequest, IamResponse,
+    iam_request_queue, iam_response_queue,
+)

 from .. base import Initialiser, InitContext

@@ -189,13 +193,31 @@ class Processor(AsyncProcessor):
             request_metrics=ProducerMetrics(
                 processor=self.id, flow=None, name="flow-request",
             ),
-            response_topic=flow_response_queue,
+            response_topic=f"{flow_response_queue}:{workspace}",
             response_schema=FlowResponse,
             response_metrics=SubscriberMetrics(
                 processor=self.id, flow=None, name="flow-response",
             ),
         )

+    def _make_iam_client(self):
+        rr_id = str(uuid.uuid4())
+        return RequestResponse(
+            backend=self.pubsub_backend,
+            subscription=f"{self.id}--iam--{rr_id}",
+            consumer_name=self.id,
+            request_topic=iam_request_queue,
+            request_schema=IamRequest,
+            request_metrics=ProducerMetrics(
+                processor=self.id, flow=None, name="iam-request",
+            ),
+            response_topic=iam_response_queue,
+            response_schema=IamResponse,
+            response_metrics=SubscriberMetrics(
+                processor=self.id, flow=None, name="iam-response",
+            ),
+        )
+
     async def _open_clients(self):
         config = self._make_config_client()
         await config.start()

@@ -211,13 +233,6 @@ class Processor(AsyncProcessor):
     # Service gate.
     # ------------------------------------------------------------------

-    def _gate_workspace(self):
-        for spec in self.specs:
-            ws = getattr(spec.instance, "workspace", None)
-            if ws and not ws.startswith("_"):
-                return ws
-        return None
-
     async def _gate_ready(self, config):
         try:
             await config.keys(SYSTEM_WORKSPACE, INIT_STATE_TYPE)

@@ -227,33 +242,6 @@ class Processor(AsyncProcessor):
             )
             return False

-        workspace = self._gate_workspace()
-        if workspace is None:
-            return True
-
-        flow = self._make_flow_client(workspace)
-        try:
-            await flow.start()
-            resp = await flow.request(
-                FlowRequest(
-                    operation="list-blueprints",
-                ),
-                timeout=5,
-            )
-            if resp.error:
-                logger.info(
-                    f"Gate: flow-svc error: "
-                    f"{resp.error.type}: {resp.error.message}"
-                )
-                return False
-        except Exception as e:
-            logger.info(
-                f"Gate: flow-svc not ready ({type(e).__name__}: {e})"
-            )
-            return False
-        finally:
-            await self._safe_stop(flow)
-
         return True

     # ------------------------------------------------------------------

@@ -307,6 +295,7 @@ class Processor(AsyncProcessor):
             logger=child_logger,
             config=config,
             make_flow_client=self._make_flow_client,
+            make_iam_client=self._make_iam_client,
         )

         child_logger.info(


@@ -39,8 +39,6 @@ TEMPLATE_WORKSPACE = "__template__"

 class TemplateSeed(Initialiser):

-    wait_for_services = False
-
     def __init__(self, config_file, overwrite=False, **kwargs):
         super().__init__(**kwargs)
         if not config_file:


@@ -26,6 +26,8 @@ the next cycle once the prerequisite is satisfied.

 import json

+from trustgraph.schema import IamRequest, WorkspaceInput
+
 from .. base import Initialiser

 TEMPLATE_WORKSPACE = "__template__"

@@ -33,8 +35,6 @@ TEMPLATE_WORKSPACE = "__template__"

 class WorkspaceInit(Initialiser):

-    wait_for_services = False
-
     def __init__(
         self,
         workspace="default",

@@ -61,6 +61,8 @@ class WorkspaceInit(Initialiser):
         self.overwrite = overwrite

     async def run(self, ctx, old_flag, new_flag):
+        await self._create_workspace(ctx)
+
         if self.source == "seed-file":
             tree = self._load_seed_file()
         else:

@@ -107,6 +109,39 @@ class WorkspaceInit(Initialiser):
         )
         return tree

+    async def _create_workspace(self, ctx):
+        """Register the workspace via the IAM create-workspace API."""
+        iam = ctx.make_iam_client()
+        await iam.start()
+        try:
+            resp = await iam.request(
+                IamRequest(
+                    operation="create-workspace",
+                    workspace_record=WorkspaceInput(
+                        id=self.workspace,
+                        name=self.workspace.title(),
+                        enabled=True,
+                    ),
+                ),
+                timeout=10,
+            )
+            if resp.error:
+                if resp.error.type == "duplicate":
+                    ctx.logger.info(
+                        f"Workspace {self.workspace!r} already exists in IAM"
+                    )
+                else:
+                    raise RuntimeError(
+                        f"IAM create-workspace failed: "
+                        f"{resp.error.type}: {resp.error.message}"
+                    )
+            else:
+                ctx.logger.info(
+                    f"Workspace {self.workspace!r} created via IAM"
+                )
+        finally:
+            await iam.stop()
+
     async def _write_all(self, ctx, tree):
         values = []
         for type_name, entries in tree.items():

@@ -114,6 +149,7 @@ class WorkspaceInit(Initialiser):
                 values.append((type_name, key, json.dumps(value)))
         if values:
             await ctx.config.put_many(self.workspace, values)
+
         ctx.logger.info(
             f"Workspace {self.workspace!r} populated with "
             f"{len(values)} entries"

@@ -134,6 +170,7 @@ class WorkspaceInit(Initialiser):
             if values:
                 await ctx.config.put_many(self.workspace, values)
                 written += len(values)
+
         ctx.logger.info(
             f"Workspace {self.workspace!r} upsert-missing: "
             f"{written} new entries"


@@ -95,7 +95,7 @@ class Processor(ChunkingService):
         logger.info(f"Chunking document {v.metadata.id}...")

         # Get text content (fetches from librarian if needed)
-        text = await self.get_document_text(v, flow.workspace)
+        text = await self.get_document_text(v, flow)

         # Extract chunk parameters from flow (allows runtime override)
         chunk_size, chunk_overlap = await self.chunk_document(

@@ -141,10 +141,9 @@ class Processor(ChunkingService):
                 chunk_length = len(chunk.page_content)

                 # Save chunk to librarian as child document
-                await self.librarian.save_child_document(
+                await flow.librarian.save_child_document(
                     doc_id=chunk_doc_id,
                     parent_id=parent_doc_id,
-                    workspace=flow.workspace,
                     content=chunk_content,
                     document_type="chunk",
                     title=f"Chunk {chunk_index}",


@@ -92,7 +92,7 @@ class Processor(ChunkingService):
         logger.info(f"Chunking document {v.metadata.id}...")

         # Get text content (fetches from librarian if needed)
-        text = await self.get_document_text(v, flow.workspace)
+        text = await self.get_document_text(v, flow)

         # Extract chunk parameters from flow (allows runtime override)
         chunk_size, chunk_overlap = await self.chunk_document(

@@ -137,10 +137,9 @@ class Processor(ChunkingService):
                 chunk_length = len(chunk.page_content)

                 # Save chunk to librarian as child document
-                await self.librarian.save_child_document(
+                await flow.librarian.save_child_document(
                     doc_id=chunk_doc_id,
                     parent_id=parent_doc_id,
-                    workspace=flow.workspace,
                     content=chunk_content,
                     document_type="chunk",
                     title=f"Chunk {chunk_index}",


@@ -67,7 +67,7 @@ class Processor(AsyncProcessor):
         config_request_queue = params.get(
             "config_request_queue", default_config_request_queue
         )
-        config_response_queue = params.get(
+        self.config_response_queue_base = params.get(
             "config_response_queue", default_config_response_queue
         )
         config_push_queue = params.get(

@@ -130,7 +130,7 @@ class Processor(AsyncProcessor):
         self.config_response_producer = Producer(
             backend = self.pubsub,
-            topic = config_response_queue,
+            topic = self.config_response_queue_base,
             schema = ConfigResponse,
             metrics = config_response_metrics,
         )

@@ -208,17 +208,31 @@ class Processor(AsyncProcessor):
         )

     async def _add_workspace_consumer(self, workspace_id):
-        queue = workspace_queue(
+        req_queue = workspace_queue(
             self.config_request_queue_base, workspace_id,
         )
+        resp_queue = workspace_queue(
+            self.config_response_queue_base, workspace_id,
+        )

-        await self.pubsub.ensure_topic(queue)
+        await self.pubsub.ensure_topic(req_queue)
+        await self.pubsub.ensure_topic(resp_queue)

+        response_producer = Producer(
+            backend=self.pubsub,
+            topic=resp_queue,
+            schema=ConfigResponse,
+            metrics=ProducerMetrics(
+                processor=self.id, flow=None,
+                name=f"config-response-{workspace_id}",
+            ),
+        )

         consumer = Consumer(
             taskgroup=self.taskgroup,
             backend=self.pubsub,
             flow=None,
-            topic=queue,
+            topic=req_queue,
             subscriber=self.id,
             schema=ConfigRequest,
             handler=partial(

@@ -231,17 +245,23 @@ class Processor(AsyncProcessor):
             ),
         )

+        await response_producer.start()
         await consumer.start()

-        self.workspace_consumers[workspace_id] = consumer
+        self.workspace_consumers[workspace_id] = {
+            "consumer": consumer,
+            "response": response_producer,
+        }

         logger.info(
             f"Subscribed to workspace config queue: {workspace_id}"
         )

     async def _remove_workspace_consumer(self, workspace_id):
-        consumer = self.workspace_consumers.pop(workspace_id, None)
-        if consumer:
-            await consumer.stop()
+        clients = self.workspace_consumers.pop(workspace_id, None)
+        if clients:
+            for client in clients.values():
+                await client.stop()
         logger.info(
             f"Unsubscribed from workspace config queue: {workspace_id}"
         )

@@ -249,6 +269,7 @@ class Processor(AsyncProcessor):
     async def start(self):
         await self.pubsub.ensure_topic(self.config_request_queue_base)
+        await self.config_response_producer.start()
         await self.push()  # Startup poke: empty types = everything
         await self.system_consumer.start()

@@ -307,9 +328,11 @@ class Processor(AsyncProcessor):
             f"workspace={workspace}..."
         )

+        producer = self.workspace_consumers[workspace]["response"]
+
         resp = await self.config.handle_workspace(v, workspace)
-        await self.config_response_producer.send(
+        await producer.send(
             resp, properties={"id": id}
         )

@@ -322,7 +345,7 @@ class Processor(AsyncProcessor):
             ),
         )
-        await self.config_response_producer.send(
+        await producer.send(
             resp, properties={"id": id}
         )


@@ -48,7 +48,7 @@ class Processor(WorkspaceProcessor):
             "knowledge_request_queue", default_knowledge_request_queue
         )
-        knowledge_response_queue = params.get(
+        self.knowledge_response_queue_base = params.get(
             "knowledge_response_queue", default_knowledge_response_queue
         )

@@ -70,24 +70,13 @@ class Processor(WorkspaceProcessor):
         super(Processor, self).__init__(
             **params | {
                 "knowledge_request_queue": self.knowledge_request_queue_base,
-                "knowledge_response_queue": knowledge_response_queue,
+                "knowledge_response_queue": self.knowledge_response_queue_base,
                 "cassandra_host": self.cassandra_host,
                 "cassandra_username": self.cassandra_username,
                 "cassandra_password": self.cassandra_password,
             }
         )

-        knowledge_response_metrics = ProducerMetrics(
-            processor = self.id, flow = None, name = "knowledge-response"
-        )
-
-        self.knowledge_response_producer = Producer(
-            backend = self.pubsub,
-            topic = knowledge_response_queue,
-            schema = KnowledgeResponse,
-            metrics = knowledge_response_metrics,
-        )
-
         self.knowledge = KnowledgeManager(
             cassandra_host = self.cassandra_host,
             cassandra_username = self.cassandra_username,

@@ -109,17 +98,31 @@ class Processor(WorkspaceProcessor):
         if workspace in self.workspace_consumers:
             return

-        queue = workspace_queue(
+        req_queue = workspace_queue(
             self.knowledge_request_queue_base, workspace,
         )
+        resp_queue = workspace_queue(
+            self.knowledge_response_queue_base, workspace,
+        )

-        await self.pubsub.ensure_topic(queue)
+        await self.pubsub.ensure_topic(req_queue)
+        await self.pubsub.ensure_topic(resp_queue)

+        response_producer = Producer(
+            backend=self.pubsub,
+            topic=resp_queue,
+            schema=KnowledgeResponse,
+            metrics=ProducerMetrics(
+                processor=self.id, flow=None,
+                name=f"knowledge-response-{workspace}",
+            ),
+        )

         consumer = Consumer(
             taskgroup=self.taskgroup,
             backend=self.pubsub,
             flow=None,
-            topic=queue,
+            topic=req_queue,
             subscriber=self.id,
             schema=KnowledgeRequest,
             handler=partial(

@@ -131,22 +134,27 @@ class Processor(WorkspaceProcessor):
             ),
         )

+        await response_producer.start()
         await consumer.start()

-        self.workspace_consumers[workspace] = consumer
+        self.workspace_consumers[workspace] = {
+            "consumer": consumer,
+            "response": response_producer,
+        }

         logger.info(f"Subscribed to workspace queue: {workspace}")

     async def on_workspace_deleted(self, workspace):
-        consumer = self.workspace_consumers.pop(workspace, None)
-        if consumer:
-            await consumer.stop()
+        clients = self.workspace_consumers.pop(workspace, None)
+        if clients:
+            for client in clients.values():
+                await client.stop()
         logger.info(f"Unsubscribed from workspace queue: {workspace}")

     async def start(self):
         await super(Processor, self).start()
-        await self.knowledge_response_producer.start()

     async def on_knowledge_config(self, workspace, config, version):

@@ -164,7 +172,7 @@ class Processor(WorkspaceProcessor):
         logger.debug(f"Flows for {workspace}: {self.flows[workspace]}")

-    async def process_request(self, v, id, workspace):
+    async def process_request(self, v, id, workspace, producer):

         if v.operation is None:
             raise RequestError("Null operation")

@@ -184,7 +192,7 @@ class Processor(WorkspaceProcessor):
             raise RequestError(f"Invalid operation: {v.operation}")

         async def respond(x):
-            await self.knowledge_response_producer.send(
+            await producer.send(
                 x, { "id": id }
             )

         return await impls[v.operation](v, respond, workspace)

@@ -199,11 +207,13 @@ class Processor(WorkspaceProcessor):
         logger.info(f"Handling knowledge input {id}...")

+        producer = self.workspace_consumers[workspace]["response"]
+
         try:

             # We don't send a response back here, the processing
             # implementation sends whatever it needs to send.
-            await self.process_request(v, id, workspace)
+            await self.process_request(v, id, workspace, producer)

             return

@@ -215,7 +225,7 @@ class Processor(WorkspaceProcessor):
             )
         )

-        await self.knowledge_response_producer.send(
+        await producer.send(
             resp, properties={"id": id}
         )

@@ -228,7 +238,7 @@ class Processor(WorkspaceProcessor):
             )
         )

-        await self.knowledge_response_producer.send(
+        await producer.send(
             resp, properties={"id": id}
         )


@@ -16,9 +16,8 @@ import os
 from mistralai import Mistral

 from ... schema import Document, TextDocument, Metadata
-from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec

 from ... provenance import (
     document_uri, page_uri as make_page_uri, derived_entity_triples,

@@ -36,9 +35,6 @@ COMPONENT_VERSION = "1.0.0"

 default_ident = "document-decoder"
 default_api_key = os.getenv("MISTRAL_TOKEN")
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue

 pages_per_chunk = 5

 def chunks(lst, n):

@@ -98,9 +94,8 @@ class Processor(FlowProcessor):
             )
         )

-        # Librarian client
-        self.librarian = LibrarianClient(
-            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+        self.register_specification(
+            LibrarianSpec()
         )

         if api_key is None:

@@ -113,10 +108,6 @@ class Processor(FlowProcessor):
         logger.info("Mistral OCR processor initialized")

-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian.start()
-
     def ocr(self, blob):
         """
         Run Mistral OCR on a PDF blob, returning per-page markdown strings.

@@ -198,9 +189,9 @@ class Processor(FlowProcessor):
         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.librarian.fetch_document_metadata(
+            doc_meta = await flow.librarian.fetch_document_metadata(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )
             if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
                 logger.error(

@@ -213,9 +204,9 @@ class Processor(FlowProcessor):
         # Get PDF content - fetch from librarian or use inline data
         if v.document_id:
             logger.info(f"Fetching document {v.document_id} from librarian...")
-            content = await self.librarian.fetch_document_content(
+            content = await flow.librarian.fetch_document_content(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )
             if isinstance(content, str):
                 content = content.encode('utf-8')

@@ -240,10 +231,10 @@ class Processor(FlowProcessor):
             page_content = markdown.encode("utf-8")

             # Save page as child document in librarian
-            await self.librarian.save_child_document(
+            await flow.librarian.save_child_document(
                 doc_id=page_doc_id,
                 parent_id=source_doc_id,
-                workspace=flow.workspace,
                 content=page_content,
                 document_type="page",
                 title=f"Page {page_num}",

@@ -297,18 +288,6 @@ class Processor(FlowProcessor):
             help=f'Mistral API Key'
         )

-        parser.add_argument(
-            '--librarian-request-queue',
-            default=default_librarian_request_queue,
-            help=f'Librarian request queue (default: {default_librarian_request_queue})',
-        )
-
-        parser.add_argument(
-            '--librarian-response-queue',
-            default=default_librarian_response_queue,
-            help=f'Librarian response queue (default: {default_librarian_response_queue})',
-        )
-
 def run():
     Processor.launch(default_ident, __doc__)


@@ -12,9 +12,8 @@ import tempfile
 import base64
 import logging

 from ... schema import Document, TextDocument, Metadata
-from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
-from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
+from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec

 PyPDFLoader = None

@@ -32,9 +31,6 @@ logger = logging.getLogger(__name__)

 default_ident = "document-decoder"
-default_librarian_request_queue = librarian_request_queue
-default_librarian_response_queue = librarian_response_queue

 class Processor(FlowProcessor):

@@ -70,17 +66,12 @@ class Processor(FlowProcessor):
             )
         )

-        # Librarian client
-        self.librarian = LibrarianClient(
-            id=id, backend=self.pubsub, taskgroup=self.taskgroup,
+        self.register_specification(
+            LibrarianSpec()
         )

         logger.info("PDF decoder initialized")

-    async def start(self):
-        await super(Processor, self).start()
-        await self.librarian.start()
-
     async def on_message(self, msg, consumer, flow):

         logger.debug("PDF message received")

@@ -91,9 +82,9 @@ class Processor(FlowProcessor):
         # Check MIME type if fetching from librarian
         if v.document_id:
-            doc_meta = await self.librarian.fetch_document_metadata(
+            doc_meta = await flow.librarian.fetch_document_metadata(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )
             if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
                 logger.error(

@@ -112,9 +103,9 @@ class Processor(FlowProcessor):
             logger.info(f"Fetching document {v.document_id} from librarian...")
             fp.close()

-            content = await self.librarian.fetch_document_content(
+            content = await flow.librarian.fetch_document_content(
                 document_id=v.document_id,
-                workspace=flow.workspace,
             )

             # Content is base64 encoded

@@ -154,10 +145,10 @@ class Processor(FlowProcessor):
             page_content = page.page_content.encode("utf-8")

             # Save page as child document in librarian
-            await self.librarian.save_child_document(
+            await flow.librarian.save_child_document(
                 doc_id=page_doc_id,
                 parent_id=source_doc_id,
-                workspace=flow.workspace,
                 content=page_content,
                 document_type="page",
                 title=f"Page {page_num}",

@@ -210,18 +201,6 @@ class Processor(FlowProcessor):
     def add_args(parser):
         FlowProcessor.add_args(parser)

-        parser.add_argument(
-            '--librarian-request-queue',
-            default=default_librarian_request_queue,
-            help=f'Librarian request queue (default: {default_librarian_request_queue})',
-        )
-
-        parser.add_argument(
-            '--librarian-response-queue',
-            default=default_librarian_response_queue,
-            help=f'Librarian response queue (default: {default_librarian_response_queue})',
-        )
-
 def run():
     Processor.launch(default_ident, __doc__)

View file

@ -41,7 +41,7 @@ class Processor(WorkspaceProcessor):
self.flow_request_queue_base = params.get( self.flow_request_queue_base = params.get(
"flow_request_queue", default_flow_request_queue "flow_request_queue", default_flow_request_queue
) )
flow_response_queue = params.get( self.flow_response_queue_base = params.get(
"flow_response_queue", default_flow_response_queue "flow_response_queue", default_flow_response_queue
) )
@ -54,17 +54,6 @@ class Processor(WorkspaceProcessor):
} }
) )
flow_response_metrics = ProducerMetrics(
processor = self.id, flow = None, name = "flow-response"
)
self.flow_response_producer = Producer(
backend = self.pubsub,
topic = flow_response_queue,
schema = FlowResponse,
metrics = flow_response_metrics,
)
config_req_metrics = ProducerMetrics( config_req_metrics = ProducerMetrics(
processor=self.id, flow=None, name="config-request", processor=self.id, flow=None, name="config-request",
) )
@ -96,17 +85,31 @@ class Processor(WorkspaceProcessor):
if workspace in self.workspace_consumers: if workspace in self.workspace_consumers:
return return
queue = workspace_queue( req_queue = workspace_queue(
self.flow_request_queue_base, workspace, self.flow_request_queue_base, workspace,
) )
resp_queue = workspace_queue(
self.flow_response_queue_base, workspace,
)
await self.pubsub.ensure_topic(queue) await self.pubsub.ensure_topic(req_queue)
await self.pubsub.ensure_topic(resp_queue)
response_producer = Producer(
backend=self.pubsub,
topic=resp_queue,
schema=FlowResponse,
metrics=ProducerMetrics(
processor=self.id, flow=None,
name=f"flow-response-{workspace}",
),
)
consumer = Consumer( consumer = Consumer(
taskgroup=self.taskgroup, taskgroup=self.taskgroup,
backend=self.pubsub, backend=self.pubsub,
flow=None, flow=None,
topic=queue, topic=req_queue,
subscriber=self.id, subscriber=self.id,
schema=FlowRequest, schema=FlowRequest,
handler=partial( handler=partial(
@ -118,16 +121,22 @@ class Processor(WorkspaceProcessor):
), ),
) )
await response_producer.start()
await consumer.start() await consumer.start()
self.workspace_consumers[workspace] = consumer
self.workspace_consumers[workspace] = {
"consumer": consumer,
"response": response_producer,
}
logger.info(f"Subscribed to workspace queue: {workspace}") logger.info(f"Subscribed to workspace queue: {workspace}")
async def on_workspace_deleted(self, workspace): async def on_workspace_deleted(self, workspace):
consumer = self.workspace_consumers.pop(workspace, None) clients = self.workspace_consumers.pop(workspace, None)
if consumer: if clients:
await consumer.stop() for client in clients.values():
await client.stop()
logger.info(f"Unsubscribed from workspace queue: {workspace}") logger.info(f"Unsubscribed from workspace queue: {workspace}")
async def start(self): async def start(self):
@ -149,9 +158,11 @@ class Processor(WorkspaceProcessor):
logger.debug(f"Handling flow request {id}...") logger.debug(f"Handling flow request {id}...")
producer = self.workspace_consumers[workspace]["response"]
resp = await self.flow.handle(v, workspace) resp = await self.flow.handle(v, workspace)
await self.flow_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -166,7 +177,7 @@ class Processor(WorkspaceProcessor):
), ),
) )
await self.flow_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
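
The flow service now derives a request and a response queue per workspace and keeps the response producer next to the consumer in workspace_consumers, so creation and deletion stay symmetric: one loop stops both. The workspace_queue helper itself is not shown in this commit; a sketch assuming it joins base and workspace the same way the gateway dispatcher does below (the base names here are hypothetical stand-ins for the configured defaults):

    def workspace_queue(base, workspace):
        return f"{base}:{workspace}"

    workspace_queue("flow-request", "acme")    # -> "flow-request:acme"
    workspace_queue("flow-response", "acme")   # -> "flow-response:acme"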

View file

@ -7,11 +7,11 @@ import logging
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from ... schema import flow_request_queue from ... schema import flow_request_queue, flow_response_queue
from ... schema import librarian_request_queue from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import knowledge_request_queue from ... schema import knowledge_request_queue, knowledge_response_queue
from ... schema import collection_request_queue from ... schema import collection_request_queue, collection_response_queue
from ... schema import config_request_queue from ... schema import config_request_queue, config_response_queue
from . config import ConfigRequestor from . config import ConfigRequestor
from . flow import FlowRequestor from . flow import FlowRequestor
@ -96,6 +96,14 @@ workspace_default_request_queues = {
"collection-management": collection_request_queue, "collection-management": collection_request_queue,
} }
workspace_default_response_queues = {
"config": config_response_queue,
"flow": flow_response_queue,
"librarian": librarian_response_queue,
"knowledge": knowledge_response_queue,
"collection-management": collection_response_queue,
}
global_dispatchers = {**system_dispatchers, **workspace_dispatchers} global_dispatchers = {**system_dispatchers, **workspace_dispatchers}
sender_dispatchers = { sender_dispatchers = {
@ -267,11 +275,16 @@ class DispatcherManager:
response_queue = self.queue_overrides[kind].get("response") response_queue = self.queue_overrides[kind].get("response")
if kind in workspace_dispatchers and workspace: if kind in workspace_dispatchers and workspace:
base_queue = ( base_req_queue = (
request_queue request_queue
or workspace_default_request_queues[kind] or workspace_default_request_queues[kind]
) )
request_queue = f"{base_queue}:{workspace}" request_queue = f"{base_req_queue}:{workspace}"
base_resp_queue = (
response_queue
or workspace_default_response_queues[kind]
)
response_queue = f"{base_resp_queue}:{workspace}"
consumer_name = f"{self.prefix}-{kind}-{workspace}" consumer_name = f"{self.prefix}-{kind}-{workspace}"
else: else:
consumer_name = f"{self.prefix}-{kind}-request" consumer_name = f"{self.prefix}-{kind}-request"
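
Condensed, the dispatcher's derivation reduces to: explicit overrides win, otherwise the per-kind default applies, and for workspace-scoped kinds both names gain the workspace suffix. A hedged sketch using the tables defined above (the .get() defaulting is a simplification of the override lookup in the diff):

    def derive_queues(self, kind, workspace):
        # Overrides take precedence over the per-kind defaults
        req = (self.queue_overrides.get(kind, {}).get("request")
               or workspace_default_request_queues[kind])
        resp = (self.queue_overrides.get(kind, {}).get("response")
                or workspace_default_response_queues[kind])
        # Workspace-scoped kinds get the ":workspace" suffix on both
        if kind in workspace_dispatchers and workspace:
            req = f"{req}:{workspace}"
            resp = f"{resp}:{workspace}"
        return req, resp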

View file

@ -69,7 +69,7 @@ class Processor(WorkspaceProcessor):
"librarian_request_queue", default_librarian_request_queue "librarian_request_queue", default_librarian_request_queue
) )
librarian_response_queue = params.get( self.librarian_response_queue_base = params.get(
"librarian_response_queue", default_librarian_response_queue "librarian_response_queue", default_librarian_response_queue
) )
@ -77,7 +77,7 @@ class Processor(WorkspaceProcessor):
"collection_request_queue", default_collection_request_queue "collection_request_queue", default_collection_request_queue
) )
collection_response_queue = params.get( self.collection_response_queue_base = params.get(
"collection_response_queue", default_collection_response_queue "collection_response_queue", default_collection_response_queue
) )
@ -132,9 +132,9 @@ class Processor(WorkspaceProcessor):
super(Processor, self).__init__( super(Processor, self).__init__(
**params | { **params | {
"librarian_request_queue": self.librarian_request_queue_base, "librarian_request_queue": self.librarian_request_queue_base,
"librarian_response_queue": librarian_response_queue, "librarian_response_queue": self.librarian_response_queue_base,
"collection_request_queue": self.collection_request_queue_base, "collection_request_queue": self.collection_request_queue_base,
"collection_response_queue": collection_response_queue, "collection_response_queue": self.collection_response_queue_base,
"object_store_endpoint": object_store_endpoint, "object_store_endpoint": object_store_endpoint,
"object_store_access_key": object_store_access_key, "object_store_access_key": object_store_access_key,
"cassandra_host": self.cassandra_host, "cassandra_host": self.cassandra_host,
@ -143,28 +143,6 @@ class Processor(WorkspaceProcessor):
} }
) )
librarian_response_metrics = ProducerMetrics(
processor = self.id, flow = None, name = "librarian-response"
)
collection_response_metrics = ProducerMetrics(
processor = self.id, flow = None, name = "collection-response"
)
self.librarian_response_producer = Producer(
backend = self.pubsub,
topic = librarian_response_queue,
schema = LibrarianResponse,
metrics = librarian_response_metrics,
)
self.collection_response_producer = Producer(
backend = self.pubsub,
topic = collection_response_queue,
schema = CollectionManagementResponse,
metrics = collection_response_metrics,
)
# Config service client for collection management # Config service client for collection management
config_request_metrics = ProducerMetrics( config_request_metrics = ProducerMetrics(
processor = id, flow = None, name = "config-request" processor = id, flow = None, name = "config-request"
@ -230,21 +208,49 @@ class Processor(WorkspaceProcessor):
if workspace in self.workspace_consumers: if workspace in self.workspace_consumers:
return return
lib_queue = workspace_queue( lib_req_queue = workspace_queue(
self.librarian_request_queue_base, workspace, self.librarian_request_queue_base, workspace,
) )
col_queue = workspace_queue( lib_resp_queue = workspace_queue(
self.librarian_response_queue_base, workspace,
)
col_req_queue = workspace_queue(
self.collection_request_queue_base, workspace, self.collection_request_queue_base, workspace,
) )
col_resp_queue = workspace_queue(
self.collection_response_queue_base, workspace,
)
await self.pubsub.ensure_topic(lib_queue) await self.pubsub.ensure_topic(lib_req_queue)
await self.pubsub.ensure_topic(col_queue) await self.pubsub.ensure_topic(lib_resp_queue)
await self.pubsub.ensure_topic(col_req_queue)
await self.pubsub.ensure_topic(col_resp_queue)
lib_response_producer = Producer(
backend=self.pubsub,
topic=lib_resp_queue,
schema=LibrarianResponse,
metrics=ProducerMetrics(
processor=self.id, flow=None,
name=f"librarian-response-{workspace}",
),
)
col_response_producer = Producer(
backend=self.pubsub,
topic=col_resp_queue,
schema=CollectionManagementResponse,
metrics=ProducerMetrics(
processor=self.id, flow=None,
name=f"collection-response-{workspace}",
),
)
lib_consumer = Consumer( lib_consumer = Consumer(
taskgroup=self.taskgroup, taskgroup=self.taskgroup,
backend=self.pubsub, backend=self.pubsub,
flow=None, flow=None,
topic=lib_queue, topic=lib_req_queue,
subscriber=self.id, subscriber=self.id,
schema=LibrarianRequest, schema=LibrarianRequest,
handler=partial( handler=partial(
@ -260,7 +266,7 @@ class Processor(WorkspaceProcessor):
taskgroup=self.taskgroup, taskgroup=self.taskgroup,
backend=self.pubsub, backend=self.pubsub,
flow=None, flow=None,
topic=col_queue, topic=col_req_queue,
subscriber=self.id, subscriber=self.id,
schema=CollectionManagementRequest, schema=CollectionManagementRequest,
handler=partial( handler=partial(
@ -272,29 +278,31 @@ class Processor(WorkspaceProcessor):
), ),
) )
await lib_response_producer.start()
await col_response_producer.start()
await lib_consumer.start() await lib_consumer.start()
await col_consumer.start() await col_consumer.start()
self.workspace_consumers[workspace] = { self.workspace_consumers[workspace] = {
"librarian": lib_consumer, "librarian": lib_consumer,
"librarian-response": lib_response_producer,
"collection": col_consumer, "collection": col_consumer,
"collection-response": col_response_producer,
} }
logger.info(f"Subscribed to workspace queues: {workspace}") logger.info(f"Subscribed to workspace queues: {workspace}")
async def on_workspace_deleted(self, workspace): async def on_workspace_deleted(self, workspace):
consumers = self.workspace_consumers.pop(workspace, None) clients = self.workspace_consumers.pop(workspace, None)
if consumers: if clients:
for consumer in consumers.values(): for client in clients.values():
await consumer.stop() await client.stop()
logger.info(f"Unsubscribed from workspace queues: {workspace}") logger.info(f"Unsubscribed from workspace queues: {workspace}")
async def start(self): async def start(self):
await super(Processor, self).start() await super(Processor, self).start()
await self.librarian_response_producer.start()
await self.collection_response_producer.start()
await self.config_request_producer.start() await self.config_request_producer.start()
await self.config_response_consumer.start() await self.config_response_consumer.start()
@ -505,12 +513,14 @@ class Processor(WorkspaceProcessor):
logger.info(f"Handling librarian input {id}...") logger.info(f"Handling librarian input {id}...")
producer = self.workspace_consumers[workspace]["librarian-response"]
try: try:
# Handle streaming operations specially # Handle streaming operations specially
if v.operation == "stream-document": if v.operation == "stream-document":
async for resp in self.librarian.stream_document(v, workspace): async for resp in self.librarian.stream_document(v, workspace):
await self.librarian_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
return return
@ -518,7 +528,7 @@ class Processor(WorkspaceProcessor):
# Non-streaming operations # Non-streaming operations
resp = await self.process_request(v, workspace) resp = await self.process_request(v, workspace)
await self.librarian_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -532,7 +542,7 @@ class Processor(WorkspaceProcessor):
), ),
) )
await self.librarian_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -545,7 +555,7 @@ class Processor(WorkspaceProcessor):
), ),
) )
await self.librarian_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
@ -576,9 +586,11 @@ class Processor(WorkspaceProcessor):
logger.info(f"Handling collection request {id}...") logger.info(f"Handling collection request {id}...")
producer = self.workspace_consumers[workspace]["collection-response"]
try: try:
resp = await self.process_collection_request(v, workspace) resp = await self.process_collection_request(v, workspace)
await self.collection_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
except RequestError as e: except RequestError as e:
@ -589,7 +601,7 @@ class Processor(WorkspaceProcessor):
), ),
timestamp=datetime.now().isoformat() timestamp=datetime.now().isoformat()
) )
await self.collection_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
except Exception as e: except Exception as e:
@ -600,7 +612,7 @@ class Processor(WorkspaceProcessor):
), ),
timestamp=datetime.now().isoformat() timestamp=datetime.now().isoformat()
) )
await self.collection_response_producer.send( await producer.send(
resp, properties={"id": id} resp, properties={"id": id}
) )
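
Note that the request handlers above index self.workspace_consumers[workspace] directly, so a request that races workspace registration would surface as a bare KeyError. A hypothetical defensive variant, not in the diff, would fail with a clearer message:

    clients = self.workspace_consumers.get(workspace)
    if clients is None:
        raise RuntimeError(
            f"no queues registered for workspace {workspace}"
        )
    producer = clients["librarian-response"]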

View file

@ -4,21 +4,16 @@ Simple RAG service, performs query using document RAG and an LLM.
Input is query, output is response. Input is query, output is response.
""" """
import asyncio
import base64
import logging import logging
import uuid
from ... schema import DocumentRagQuery, DocumentRagResponse, Error from ... schema import DocumentRagQuery, DocumentRagResponse, Error
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import Triples, Metadata from ... schema import Triples, Metadata
from ... provenance import GRAPH_RETRIEVAL from ... provenance import GRAPH_RETRIEVAL
from . document_rag import DocumentRag from . document_rag import DocumentRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import DocumentEmbeddingsClientSpec from ... base import DocumentEmbeddingsClientSpec
from ... base import LibrarianClient from ... base import LibrarianSpec
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,58 +80,14 @@ class Processor(FlowProcessor):
) )
) )
# Librarian client self.register_specification(
self.librarian = LibrarianClient( LibrarianSpec()
id=id,
backend=self.pubsub,
taskgroup=self.taskgroup,
) )
async def start(self):
await super(Processor, self).start()
await self.librarian.start()
async def fetch_chunk_content(self, chunk_id, workspace, timeout=120):
"""Fetch chunk content from librarian. Chunks are small so
single request-response is fine."""
return await self.librarian.fetch_document_text(
document_id=chunk_id, workspace=workspace, timeout=timeout,
)
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""Save answer content to the librarian."""
doc_metadata = DocumentMetadata(
id=doc_id,
workspace=workspace,
kind="text/plain",
title=title or "DocumentRAG Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
workspace=workspace,
)
await self.librarian.request(request, timeout=timeout)
return doc_id
async def on_request(self, msg, consumer, flow): async def on_request(self, msg, consumer, flow):
try: try:
self.rag = DocumentRag(
embeddings_client = flow("embeddings-request"),
doc_embeddings_client = flow("document-embeddings-request"),
prompt_client = flow("prompt-request"),
fetch_chunk = self.fetch_chunk_content,
verbose=True,
)
v = msg.value() v = msg.value()
# Sender-produced ID # Sender-produced ID
@ -144,15 +95,25 @@ class Processor(FlowProcessor):
logger.info(f"Handling input {id}...") logger.info(f"Handling input {id}...")
async def fetch_chunk(chunk_id, timeout=120):
return await flow.librarian.fetch_document_text(
document_id=chunk_id, timeout=timeout,
)
self.rag = DocumentRag(
embeddings_client = flow("embeddings-request"),
doc_embeddings_client = flow("document-embeddings-request"),
prompt_client = flow("prompt-request"),
fetch_chunk = fetch_chunk,
verbose=True,
)
if v.doc_limit: if v.doc_limit:
doc_limit = v.doc_limit doc_limit = v.doc_limit
else: else:
doc_limit = self.doc_limit doc_limit = self.doc_limit
# Real-time explainability callback - emits triples and IDs as they're generated
# Triples are stored in the request's collection with a named graph (urn:graph:retrieval)
async def send_explainability(triples, explain_id): async def send_explainability(triples, explain_id):
# Send triples to explainability queue - stores in same collection with named graph
await flow("explainability").send(Triples( await flow("explainability").send(Triples(
metadata=Metadata( metadata=Metadata(
id=explain_id, id=explain_id,
@ -161,7 +122,6 @@ class Processor(FlowProcessor):
triples=triples, triples=triples,
)) ))
# Send explain data to response queue
await flow("response").send( await flow("response").send(
DocumentRagResponse( DocumentRagResponse(
response=None, response=None,
@ -173,13 +133,12 @@ class Processor(FlowProcessor):
properties={"id": id} properties={"id": id}
) )
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text): async def save_answer(doc_id, answer_text):
await self.save_answer_content( await flow.librarian.save_document(
doc_id=doc_id, doc_id=doc_id,
workspace=flow.workspace,
content=answer_text, content=answer_text,
title=f"DocumentRAG Answer: {v.query[:50]}...", title=f"DocumentRAG Answer: {v.query[:50]}...",
document_type="answer",
) )
# Check if streaming is requested # Check if streaming is requested
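
The RAG pipeline is still assembled per request, but its fetch_chunk callback becomes a closure over the current flow rather than a bound method that threads workspace= through every call, so chunk fetches go through the flow's own librarian automatically. The wiring inside on_request, as it appears after the change:

    async def fetch_chunk(chunk_id, timeout=120):
        # Closes over flow: fetches hit this flow's workspace
        return await flow.librarian.fetch_document_text(
            document_id=chunk_id, timeout=timeout,
        )

    self.rag = DocumentRag(
        embeddings_client=flow("embeddings-request"),
        doc_embeddings_client=flow("document-embeddings-request"),
        prompt_client=flow("prompt-request"),
        fetch_chunk=fetch_chunk,
        verbose=True,
    )

The same substitution removes save_answer_content entirely: the save_answer callback now calls flow.librarian.save_document directly, and the base64 encoding the old helper did inline is presumably handled inside the per-flow client.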

View file

@ -4,29 +4,22 @@ Simple RAG service, performs query using graph RAG and an LLM.
Input is query, output is response. Input is query, output is response.
""" """
import asyncio
import base64
import logging import logging
import uuid
from ... schema import GraphRagQuery, GraphRagResponse, Error from ... schema import GraphRagQuery, GraphRagResponse, Error
from ... schema import Triples, Metadata from ... schema import Triples, Metadata
from ... schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... provenance import GRAPH_RETRIEVAL from ... provenance import GRAPH_RETRIEVAL
from . graph_rag import GraphRag from . graph_rag import GraphRag
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec from ... base import FlowProcessor, ConsumerSpec, ProducerSpec
from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
from ... base import Consumer, Producer, ConsumerMetrics, ProducerMetrics from ... base import LibrarianSpec
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
default_ident = "graph-rag" default_ident = "graph-rag"
default_concurrency = 1 default_concurrency = 1
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor): class Processor(FlowProcessor):
@ -117,115 +110,12 @@ class Processor(FlowProcessor):
) )
) )
# Librarian client for storing answer content self.register_specification(
librarian_request_q = params.get( LibrarianSpec()
"librarian_request_queue", default_librarian_request_queue
) )
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_librarian_requests = {}
logger.info("Graph RAG service initialized") logger.info("Graph RAG service initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_librarian_requests:
future = self.pending_librarian_requests.pop(request_id)
future.set_result(response)
async def save_answer_content(self, doc_id, workspace, content, title=None, timeout=120):
"""
Save answer content to the librarian.
Args:
doc_id: ID for the answer document
workspace: Workspace for isolation
content: Answer text content
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
doc_metadata = DocumentMetadata(
id=doc_id,
workspace=workspace,
kind="text/plain",
title=title or "GraphRAG Answer",
document_type="answer",
)
request = LibrarianRequest(
operation="add-document",
document_id=doc_id,
document_metadata=doc_metadata,
content=base64.b64encode(content.encode("utf-8")).decode("utf-8"),
workspace=workspace,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_librarian_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving answer: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_librarian_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving answer document {doc_id}")
async def on_request(self, msg, consumer, flow): async def on_request(self, msg, consumer, flow):
try: try:
@ -306,13 +196,12 @@ class Processor(FlowProcessor):
else: else:
edge_limit = self.default_edge_limit edge_limit = self.default_edge_limit
# Callback to save answer content to librarian
async def save_answer(doc_id, answer_text): async def save_answer(doc_id, answer_text):
await self.save_answer_content( await flow.librarian.save_document(
doc_id=doc_id, doc_id=doc_id,
workspace=flow.workspace,
content=answer_text, content=answer_text,
title=f"GraphRAG Answer: {v.query[:50]}...", title=f"GraphRAG Answer: {v.query[:50]}...",
document_type="answer",
) )
# Check if streaming is requested # Check if streaming is requested
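
A large chunk of hand-rolled correlation machinery disappears from graph-rag here: the dedicated librarian producer/consumer pair and the pending_librarian_requests map of futures keyed by message id, with LibrarianSpec presumably providing the equivalent internally. For reference, the idiom the removed code implemented, shown in isolation (hedged sketch; the real spec's internals are not part of this diff):

    import asyncio

    pending: dict[str, asyncio.Future] = {}

    async def request(producer, req, request_id, timeout=120):
        # Park a future keyed by the request id, send, then wait
        fut = asyncio.get_running_loop().create_future()
        pending[request_id] = fut
        try:
            await producer.send(req, properties={"id": request_id})
            return await asyncio.wait_for(fut, timeout=timeout)
        except asyncio.TimeoutError:
            pending.pop(request_id, None)
            raise

    async def on_response(msg, consumer, flow):
        # Resolve the parked future for the matching request id
        fut = pending.pop(msg.properties().get("id"), None)
        if fut is not None:
            fut.set_result(msg.value())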

View file

@ -13,9 +13,8 @@ import pytesseract
from pdf2image import convert_from_bytes from pdf2image import convert_from_bytes
from ... schema import Document, TextDocument, Metadata from ... schema import Document, TextDocument, Metadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec
from ... provenance import ( from ... provenance import (
document_uri, page_uri as make_page_uri, derived_entity_triples, document_uri, page_uri as make_page_uri, derived_entity_triples,
@ -31,9 +30,6 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder" default_ident = "document-decoder"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class Processor(FlowProcessor): class Processor(FlowProcessor):
def __init__(self, **params): def __init__(self, **params):
@ -68,17 +64,12 @@ class Processor(FlowProcessor):
) )
) )
# Librarian client self.register_specification(
self.librarian = LibrarianClient( LibrarianSpec()
id=id, backend=self.pubsub, taskgroup=self.taskgroup,
) )
logger.info("PDF OCR processor initialized") logger.info("PDF OCR processor initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian.start()
async def on_message(self, msg, consumer, flow): async def on_message(self, msg, consumer, flow):
logger.info("PDF message received") logger.info("PDF message received")
@ -89,9 +80,8 @@ class Processor(FlowProcessor):
# Check MIME type if fetching from librarian # Check MIME type if fetching from librarian
if v.document_id: if v.document_id:
doc_meta = await self.librarian.fetch_document_metadata( doc_meta = await flow.librarian.fetch_document_metadata(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf": if doc_meta and doc_meta.kind and doc_meta.kind != "application/pdf":
logger.error( logger.error(
@ -104,9 +94,8 @@ class Processor(FlowProcessor):
# Get PDF content - fetch from librarian or use inline data # Get PDF content - fetch from librarian or use inline data
if v.document_id: if v.document_id:
logger.info(f"Fetching document {v.document_id} from librarian...") logger.info(f"Fetching document {v.document_id} from librarian...")
content = await self.librarian.fetch_document_content( content = await flow.librarian.fetch_document_content(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
if isinstance(content, str): if isinstance(content, str):
content = content.encode('utf-8') content = content.encode('utf-8')
@ -138,10 +127,9 @@ class Processor(FlowProcessor):
page_content = text.encode("utf-8") page_content = text.encode("utf-8")
# Save page as child document in librarian # Save page as child document in librarian
await self.librarian.save_child_document( await flow.librarian.save_child_document(
doc_id=page_doc_id, doc_id=page_doc_id,
parent_id=source_doc_id, parent_id=source_doc_id,
workspace=flow.workspace,
content=page_content, content=page_content,
document_type="page", document_type="page",
title=f"Page {page_num}", title=f"Page {page_num}",
@ -189,18 +177,6 @@ class Processor(FlowProcessor):
FlowProcessor.add_args(parser) FlowProcessor.add_args(parser)
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue (default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue (default: {default_librarian_response_queue})',
)
def run(): def run():
Processor.launch(default_ident, __doc__) Processor.launch(default_ident, __doc__)

View file

@ -23,9 +23,8 @@ import os
from unstructured.partition.auto import partition from unstructured.partition.auto import partition
from ... schema import Document, TextDocument, Metadata from ... schema import Document, TextDocument, Metadata
from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianSpec
from ... provenance import ( from ... provenance import (
document_uri, page_uri as make_page_uri, document_uri, page_uri as make_page_uri,
@ -44,9 +43,6 @@ logger = logging.getLogger(__name__)
default_ident = "document-decoder" default_ident = "document-decoder"
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
# Mime type to unstructured content_type mapping # Mime type to unstructured content_type mapping
# unstructured auto-detects most formats, but we pass the hint when available # unstructured auto-detects most formats, but we pass the hint when available
MIME_EXTENSIONS = { MIME_EXTENSIONS = {
@ -162,17 +158,12 @@ class Processor(FlowProcessor):
) )
) )
# Librarian client self.register_specification(
self.librarian = LibrarianClient( LibrarianSpec()
id=id, backend=self.pubsub, taskgroup=self.taskgroup,
) )
logger.info("Universal decoder initialized") logger.info("Universal decoder initialized")
async def start(self):
await super(Processor, self).start()
await self.librarian.start()
def extract_elements(self, blob, mime_type=None): def extract_elements(self, blob, mime_type=None):
""" """
Extract elements from a document using unstructured. Extract elements from a document using unstructured.
@ -272,10 +263,9 @@ class Processor(FlowProcessor):
page_content = text.encode("utf-8") page_content = text.encode("utf-8")
# Save to librarian # Save to librarian
await self.librarian.save_child_document( await flow.librarian.save_child_document(
doc_id=doc_id, doc_id=doc_id,
parent_id=parent_doc_id, parent_id=parent_doc_id,
workspace=flow.workspace,
content=page_content, content=page_content,
document_type="page" if is_page else "section", document_type="page" if is_page else "section",
title=label, title=label,
@ -351,10 +341,9 @@ class Processor(FlowProcessor):
# Save to librarian # Save to librarian
if img_content: if img_content:
await self.librarian.save_child_document( await flow.librarian.save_child_document(
doc_id=img_uri, doc_id=img_uri,
parent_id=parent_doc_id, parent_id=parent_doc_id,
workspace=flow.workspace,
content=img_content, content=img_content,
document_type="image", document_type="image",
title=f"Image from page {page_number}" if page_number else "Image", title=f"Image from page {page_number}" if page_number else "Image",
@ -399,15 +388,13 @@ class Processor(FlowProcessor):
f"Fetching document {v.document_id} from librarian..." f"Fetching document {v.document_id} from librarian..."
) )
doc_meta = await self.librarian.fetch_document_metadata( doc_meta = await flow.librarian.fetch_document_metadata(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
mime_type = doc_meta.kind if doc_meta else None mime_type = doc_meta.kind if doc_meta else None
content = await self.librarian.fetch_document_content( content = await flow.librarian.fetch_document_content(
document_id=v.document_id, document_id=v.document_id,
workspace=flow.workspace,
) )
if isinstance(content, str): if isinstance(content, str):
@ -571,19 +558,6 @@ class Processor(FlowProcessor):
help='Apply section strategy within pages too (default: false)', help='Apply section strategy within pages too (default: false)',
) )
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue '
f'(default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue '
f'(default: {default_librarian_response_queue})',
)
def run(): def run():