mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
RabbitMQ pub/sub backend with topic exchange architecture (#752)
Adds a RabbitMQ backend as an alternative to Pulsar, selectable via PUBSUB_BACKEND=rabbitmq. Both backends implement the same PubSubBackend protocol — no application code changes needed to switch. RabbitMQ topology: - Single topic exchange per topicspace (e.g. 'tg') - Routing key derived from queue class and topic name - Shared consumers: named queue bound to exchange (competing, round-robin) - Exclusive consumers: anonymous auto-delete queue (broadcast, each gets every message). Used by Subscriber and config push consumer. - Thread-local producer connections (pika is not thread-safe) - Push-based consumption via basic_consume with process_data_events for heartbeat processing Consumer model changes: - Consumer class creates one backend consumer per concurrent task (required for pika thread safety, harmless for Pulsar) - Consumer class accepts consumer_type parameter - Subscriber passes consumer_type='exclusive' for broadcast semantics - Config push consumer uses consumer_type='exclusive' so every processor instance receives config updates - handle_one_from_queue receives consumer as parameter for correct per-connection ack/nack LibrarianClient: - New shared client class replacing duplicated librarian request-response code across 6+ services (chunking, decoders, RAG, etc.) 
- Uses stream-document instead of get-document-content for fetching document content in 1MB chunks (avoids broker message size limits) - Standalone object (self.librarian = LibrarianClient(...)) not a mixin - get-document-content marked deprecated in schema and OpenAPI spec Serialisation: - Extracted dataclass_to_dict/dict_to_dataclass to shared serialization.py (used by both Pulsar and RabbitMQ backends) Librarian queues: - Changed from flow class (persistent) back to request/response class now that stream-document eliminates large single messages - API upload chunk size reduced from 5MB to 3MB to stay under broker limits after base64 encoding Factory and CLI: - get_pubsub() handles 'rabbitmq' backend with RabbitMQ connection params - add_pubsub_args() includes RabbitMQ options (host, port, credentials) - add_pubsub_args(standalone=True) defaults to localhost for CLI tools - init_trustgraph skips Pulsar admin setup for non-Pulsar backends - tg-dump-queues and tg-monitor-prompts use backend abstraction - BaseClient and ConfigClient accept generic pubsub config
This commit is contained in:
parent
4fb0b4d8e8
commit
24f0190ce7
36 changed files with 1277 additions and 1313 deletions
|
|
@ -14,6 +14,7 @@ dependencies = [
|
|||
"prometheus-client",
|
||||
"requests",
|
||||
"python-logging-loki",
|
||||
"pika",
|
||||
]
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
|
|
|
|||
|
|
@ -22,8 +22,9 @@ logger = logging.getLogger(__name__)
|
|||
# Lower threshold provides progress feedback and resumability on slower connections
|
||||
CHUNKED_UPLOAD_THRESHOLD = 2 * 1024 * 1024
|
||||
|
||||
# Default chunk size (5MB - S3 multipart minimum)
|
||||
DEFAULT_CHUNK_SIZE = 5 * 1024 * 1024
|
||||
# Default chunk size (3MB - stays under broker message size limits
|
||||
# after base64 encoding ~4MB)
|
||||
DEFAULT_CHUNK_SIZE = 3 * 1024 * 1024
|
||||
|
||||
|
||||
def to_value(x):
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from . producer_spec import ProducerSpec
|
|||
from . subscriber_spec import SubscriberSpec
|
||||
from . request_response_spec import RequestResponseSpec
|
||||
from . llm_service import LlmService, LlmResult, LlmChunk
|
||||
from . librarian_client import LibrarianClient
|
||||
from . chunking_service import ChunkingService
|
||||
from . embeddings_service import EmbeddingsService
|
||||
from . embeddings_client import EmbeddingsClientSpec
|
||||
|
|
|
|||
|
|
@ -68,11 +68,12 @@ class AsyncProcessor:
|
|||
processor = self.id, flow = None, name = "config",
|
||||
)
|
||||
|
||||
# Subscribe to config queue
|
||||
# Subscribe to config queue — exclusive so every processor
|
||||
# gets its own copy of config pushes (broadcast pattern)
|
||||
self.config_sub_task = Consumer(
|
||||
|
||||
taskgroup = self.taskgroup,
|
||||
backend = self.pubsub_backend, # Changed from client to backend
|
||||
backend = self.pubsub_backend,
|
||||
subscriber = config_subscriber_id,
|
||||
flow = None,
|
||||
|
||||
|
|
@ -83,9 +84,8 @@ class AsyncProcessor:
|
|||
|
||||
metrics = config_consumer_metrics,
|
||||
|
||||
# This causes new subscriptions to view the entire history of
|
||||
# configuration
|
||||
start_of_messages = True
|
||||
start_of_messages = True,
|
||||
consumer_type = 'exclusive',
|
||||
)
|
||||
|
||||
self.running = True
|
||||
|
|
|
|||
|
|
@ -7,23 +7,14 @@ fetching large document content.
|
|||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from .flow_processor import FlowProcessor
|
||||
from .parameter_spec import ParameterSpec
|
||||
from .consumer import Consumer
|
||||
from .producer import Producer
|
||||
from .metrics import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ..schema import librarian_request_queue, librarian_response_queue
|
||||
from .librarian_client import LibrarianClient
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
default_librarian_request_queue = librarian_request_queue
|
||||
default_librarian_response_queue = librarian_response_queue
|
||||
|
||||
|
||||
class ChunkingService(FlowProcessor):
|
||||
"""Base service for chunking processors with parameter specification support"""
|
||||
|
|
@ -44,155 +35,18 @@ class ChunkingService(FlowProcessor):
|
|||
ParameterSpec(name="chunk-overlap")
|
||||
)
|
||||
|
||||
# Librarian client for fetching document content
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", default_librarian_request_queue
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", default_librarian_response_queue
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request"
|
||||
)
|
||||
|
||||
self.librarian_request_producer = Producer(
|
||||
# Librarian client
|
||||
self.librarian = LibrarianClient(
|
||||
id=id,
|
||||
backend=self.pubsub,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response"
|
||||
)
|
||||
|
||||
self.librarian_response_consumer = Consumer(
|
||||
taskgroup=self.taskgroup,
|
||||
backend=self.pubsub,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self.on_librarian_response,
|
||||
metrics=librarian_response_metrics,
|
||||
)
|
||||
|
||||
# Pending librarian requests: request_id -> asyncio.Future
|
||||
self.pending_requests = {}
|
||||
|
||||
logger.debug("ChunkingService initialized with parameter specifications")
|
||||
|
||||
async def start(self):
|
||||
await super(ChunkingService, self).start()
|
||||
await self.librarian_request_producer.start()
|
||||
await self.librarian_response_consumer.start()
|
||||
|
||||
async def on_librarian_response(self, msg, consumer, flow):
|
||||
"""Handle responses from the librarian service."""
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if request_id and request_id in self.pending_requests:
|
||||
future = self.pending_requests.pop(request_id)
|
||||
future.set_result(response)
|
||||
|
||||
async def fetch_document_content(self, document_id, user, timeout=120):
|
||||
"""
|
||||
Fetch document content from librarian via Pulsar.
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="get-document-content",
|
||||
document_id=document_id,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return response.content
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout fetching document {document_id}")
|
||||
|
||||
async def save_child_document(self, doc_id, parent_id, user, content,
|
||||
document_type="chunk", title=None, timeout=120):
|
||||
"""
|
||||
Save a child document (chunk) to the librarian.
|
||||
|
||||
Args:
|
||||
doc_id: ID for the new child document
|
||||
parent_id: ID of the parent document
|
||||
user: User ID
|
||||
content: Document content (bytes or str)
|
||||
document_type: Type of document ("chunk", etc.)
|
||||
title: Optional title
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
The document ID on success
|
||||
"""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind="text/plain",
|
||||
title=title or doc_id,
|
||||
parent_id=parent_id,
|
||||
document_type=document_type,
|
||||
)
|
||||
|
||||
request = LibrarianRequest(
|
||||
operation="add-child-document",
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content).decode("utf-8"),
|
||||
)
|
||||
|
||||
# Create future for response
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# Send request
|
||||
await self.librarian_request_producer.send(
|
||||
request, properties={"id": request_id}
|
||||
)
|
||||
|
||||
# Wait for response
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error saving chunk: {response.error.type}: {response.error.message}"
|
||||
)
|
||||
|
||||
return doc_id
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(request_id, None)
|
||||
raise RuntimeError(f"Timeout saving chunk {doc_id}")
|
||||
await self.librarian.start()
|
||||
|
||||
async def get_document_text(self, doc):
|
||||
"""
|
||||
|
|
@ -206,14 +60,10 @@ class ChunkingService(FlowProcessor):
|
|||
"""
|
||||
if doc.document_id and not doc.text:
|
||||
logger.info(f"Fetching document {doc.document_id} from librarian...")
|
||||
content = await self.fetch_document_content(
|
||||
text = await self.librarian.fetch_document_text(
|
||||
document_id=doc.document_id,
|
||||
user=doc.metadata.user,
|
||||
)
|
||||
# Content is base64 encoded
|
||||
if isinstance(content, str):
|
||||
content = content.encode('utf-8')
|
||||
text = base64.b64decode(content).decode("utf-8")
|
||||
logger.info(f"Fetched {len(text)} characters from librarian")
|
||||
return text
|
||||
else:
|
||||
|
|
@ -224,41 +74,31 @@ class ChunkingService(FlowProcessor):
|
|||
Extract chunk parameters from flow and return effective values
|
||||
|
||||
Args:
|
||||
msg: The message containing the document to chunk
|
||||
consumer: The consumer spec
|
||||
flow: The flow context
|
||||
default_chunk_size: Default chunk size from processor config
|
||||
default_chunk_overlap: Default chunk overlap from processor config
|
||||
msg: The message being processed
|
||||
consumer: The consumer instance
|
||||
flow: The flow object containing parameters
|
||||
default_chunk_size: Default chunk size if not configured
|
||||
default_chunk_overlap: Default chunk overlap if not configured
|
||||
|
||||
Returns:
|
||||
tuple: (chunk_size, chunk_overlap) - effective values to use
|
||||
tuple: (chunk_size, chunk_overlap) effective values
|
||||
"""
|
||||
# Extract parameters from flow (flow-configurable parameters)
|
||||
chunk_size = flow("chunk-size")
|
||||
chunk_overlap = flow("chunk-overlap")
|
||||
|
||||
# Use provided values or fall back to defaults
|
||||
effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size
|
||||
effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap
|
||||
chunk_size = default_chunk_size
|
||||
chunk_overlap = default_chunk_overlap
|
||||
|
||||
logger.debug(f"Using chunk-size: {effective_chunk_size}")
|
||||
logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}")
|
||||
try:
|
||||
cs = flow.parameters.get("chunk-size")
|
||||
if cs is not None:
|
||||
chunk_size = int(cs)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse chunk-size parameter: {e}")
|
||||
|
||||
return effective_chunk_size, effective_chunk_overlap
|
||||
try:
|
||||
co = flow.parameters.get("chunk-overlap")
|
||||
if co is not None:
|
||||
chunk_overlap = int(co)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not parse chunk-overlap parameter: {e}")
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add chunking service arguments to parser"""
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'--librarian-request-queue',
|
||||
default=default_librarian_request_queue,
|
||||
help=f'Librarian request queue (default: {default_librarian_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--librarian-response-queue',
|
||||
default=default_librarian_response_queue,
|
||||
help=f'Librarian response queue (default: {default_librarian_response_queue})',
|
||||
)
|
||||
return chunk_size, chunk_overlap
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class Consumer:
|
|||
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
|
||||
reconnect_time = 5,
|
||||
concurrency = 1, # Number of concurrent requests to handle
|
||||
consumer_type = 'shared',
|
||||
):
|
||||
|
||||
self.taskgroup = taskgroup
|
||||
|
|
@ -42,6 +43,8 @@ class Consumer:
|
|||
self.schema = schema
|
||||
self.handler = handler
|
||||
|
||||
self.consumer_type = consumer_type
|
||||
|
||||
self.rate_limit_retry_time = rate_limit_retry_time
|
||||
self.rate_limit_timeout = rate_limit_timeout
|
||||
|
||||
|
|
@ -93,33 +96,11 @@ class Consumer:
|
|||
if self.metrics:
|
||||
self.metrics.state("stopped")
|
||||
|
||||
try:
|
||||
|
||||
logger.info(f"Subscribing to topic: {self.topic}")
|
||||
|
||||
# Determine initial position
|
||||
if self.start_of_messages:
|
||||
initial_pos = 'earliest'
|
||||
else:
|
||||
initial_pos = 'latest'
|
||||
|
||||
# Create consumer via backend
|
||||
self.consumer = await asyncio.to_thread(
|
||||
self.backend.create_consumer,
|
||||
topic = self.topic,
|
||||
subscription = self.subscriber,
|
||||
schema = self.schema,
|
||||
initial_position = initial_pos,
|
||||
consumer_type = 'shared',
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"Consumer subscription exception: {e}", exc_info=True)
|
||||
await asyncio.sleep(self.reconnect_time)
|
||||
continue
|
||||
|
||||
logger.info(f"Successfully subscribed to topic: {self.topic}")
|
||||
# Determine initial position
|
||||
if self.start_of_messages:
|
||||
initial_pos = 'earliest'
|
||||
else:
|
||||
initial_pos = 'latest'
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.state("running")
|
||||
|
|
@ -128,14 +109,30 @@ class Consumer:
|
|||
|
||||
logger.info(f"Starting {self.concurrency} receiver threads")
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
|
||||
tasks = []
|
||||
|
||||
for i in range(0, self.concurrency):
|
||||
tasks.append(
|
||||
tg.create_task(self.consume_from_queue())
|
||||
# Create one backend consumer per concurrent task.
|
||||
# Each gets its own connection — required for backends
|
||||
# like RabbitMQ where connections are not thread-safe.
|
||||
consumers = []
|
||||
for i in range(self.concurrency):
|
||||
try:
|
||||
logger.info(f"Subscribing to topic: {self.topic} (worker {i})")
|
||||
c = await asyncio.to_thread(
|
||||
self.backend.create_consumer,
|
||||
topic = self.topic,
|
||||
subscription = self.subscriber,
|
||||
schema = self.schema,
|
||||
initial_position = initial_pos,
|
||||
consumer_type = self.consumer_type,
|
||||
)
|
||||
consumers.append(c)
|
||||
logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})")
|
||||
except Exception as e:
|
||||
logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
for c in consumers:
|
||||
tg.create_task(self.consume_from_queue(c))
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.state("stopped")
|
||||
|
|
@ -143,23 +140,31 @@ class Consumer:
|
|||
except Exception as e:
|
||||
|
||||
logger.error(f"Consumer loop exception: {e}", exc_info=True)
|
||||
self.consumer.unsubscribe()
|
||||
self.consumer.close()
|
||||
self.consumer = None
|
||||
for c in consumers:
|
||||
try:
|
||||
c.unsubscribe()
|
||||
c.close()
|
||||
except Exception:
|
||||
pass
|
||||
consumers = []
|
||||
await asyncio.sleep(self.reconnect_time)
|
||||
continue
|
||||
|
||||
if self.consumer:
|
||||
self.consumer.unsubscribe()
|
||||
self.consumer.close()
|
||||
finally:
|
||||
for c in consumers:
|
||||
try:
|
||||
c.unsubscribe()
|
||||
c.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def consume_from_queue(self):
|
||||
async def consume_from_queue(self, consumer):
|
||||
|
||||
while self.running:
|
||||
|
||||
try:
|
||||
msg = await asyncio.to_thread(
|
||||
self.consumer.receive,
|
||||
consumer.receive,
|
||||
timeout_millis=2000
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -168,9 +173,9 @@ class Consumer:
|
|||
continue
|
||||
raise e
|
||||
|
||||
await self.handle_one_from_queue(msg)
|
||||
await self.handle_one_from_queue(msg, consumer)
|
||||
|
||||
async def handle_one_from_queue(self, msg):
|
||||
async def handle_one_from_queue(self, msg, consumer):
|
||||
|
||||
expiry = time.time() + self.rate_limit_timeout
|
||||
|
||||
|
|
@ -183,7 +188,7 @@ class Consumer:
|
|||
|
||||
# Message failed to be processed, this causes it to
|
||||
# be retried
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
consumer.negative_acknowledge(msg)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.process("error")
|
||||
|
|
@ -206,7 +211,7 @@ class Consumer:
|
|||
logger.debug("Message processed successfully")
|
||||
|
||||
# Acknowledge successful processing of the message
|
||||
self.consumer.acknowledge(msg)
|
||||
consumer.acknowledge(msg)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.process("success")
|
||||
|
|
@ -233,7 +238,7 @@ class Consumer:
|
|||
|
||||
# Message failed to be processed, this causes it to
|
||||
# be retried
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
consumer.negative_acknowledge(msg)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.process("error")
|
||||
|
|
|
|||
246
trustgraph-base/trustgraph/base/librarian_client.py
Normal file
246
trustgraph-base/trustgraph/base/librarian_client.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
"""
|
||||
Shared librarian client for services that need to communicate
|
||||
with the librarian via pub/sub.
|
||||
|
||||
Provides request-response and streaming operations over the message
|
||||
broker, with proper support for large documents via stream-document.
|
||||
|
||||
Usage:
|
||||
self.librarian = LibrarianClient(
|
||||
id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
|
||||
)
|
||||
await self.librarian.start()
|
||||
content = await self.librarian.fetch_document_content(doc_id, user)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from .consumer import Consumer
|
||||
from .producer import Producer
|
||||
from .metrics import ConsumerMetrics, ProducerMetrics
|
||||
|
||||
from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
|
||||
from ..schema import librarian_request_queue, librarian_response_queue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LibrarianClient:
|
||||
"""Client for librarian request-response over the message broker."""
|
||||
|
||||
def __init__(self, id, backend, taskgroup, **params):
|
||||
|
||||
librarian_request_q = params.get(
|
||||
"librarian_request_queue", librarian_request_queue,
|
||||
)
|
||||
librarian_response_q = params.get(
|
||||
"librarian_response_queue", librarian_response_queue,
|
||||
)
|
||||
|
||||
librarian_request_metrics = ProducerMetrics(
|
||||
processor=id, flow=None, name="librarian-request",
|
||||
)
|
||||
|
||||
self._producer = Producer(
|
||||
backend=backend,
|
||||
topic=librarian_request_q,
|
||||
schema=LibrarianRequest,
|
||||
metrics=librarian_request_metrics,
|
||||
)
|
||||
|
||||
librarian_response_metrics = ConsumerMetrics(
|
||||
processor=id, flow=None, name="librarian-response",
|
||||
)
|
||||
|
||||
self._consumer = Consumer(
|
||||
taskgroup=taskgroup,
|
||||
backend=backend,
|
||||
flow=None,
|
||||
topic=librarian_response_q,
|
||||
subscriber=f"{id}-librarian",
|
||||
schema=LibrarianResponse,
|
||||
handler=self._on_response,
|
||||
metrics=librarian_response_metrics,
|
||||
consumer_type='exclusive',
|
||||
)
|
||||
|
||||
# Single-response requests: request_id -> asyncio.Future
|
||||
self._pending = {}
|
||||
# Streaming requests: request_id -> asyncio.Queue
|
||||
self._streams = {}
|
||||
|
||||
async def start(self):
|
||||
"""Start the librarian producer and consumer."""
|
||||
await self._producer.start()
|
||||
await self._consumer.start()
|
||||
|
||||
async def _on_response(self, msg, consumer, flow):
|
||||
"""Route librarian responses to the right waiter."""
|
||||
response = msg.value()
|
||||
request_id = msg.properties().get("id")
|
||||
|
||||
if not request_id:
|
||||
return
|
||||
|
||||
if request_id in self._pending:
|
||||
future = self._pending.pop(request_id)
|
||||
future.set_result(response)
|
||||
elif request_id in self._streams:
|
||||
await self._streams[request_id].put(response)
|
||||
|
||||
async def request(self, request, timeout=120):
|
||||
"""Send a request to the librarian and wait for a single response."""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._pending[request_id] = future
|
||||
|
||||
try:
|
||||
await self._producer.send(
|
||||
request, properties={"id": request_id},
|
||||
)
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error: {response.error.type}: "
|
||||
f"{response.error.message}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._pending.pop(request_id, None)
|
||||
raise RuntimeError("Timeout waiting for librarian response")
|
||||
|
||||
async def stream(self, request, timeout=120):
|
||||
"""Send a request and collect streamed response chunks."""
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
q = asyncio.Queue()
|
||||
self._streams[request_id] = q
|
||||
|
||||
try:
|
||||
await self._producer.send(
|
||||
request, properties={"id": request_id},
|
||||
)
|
||||
|
||||
chunks = []
|
||||
while True:
|
||||
response = await asyncio.wait_for(q.get(), timeout=timeout)
|
||||
|
||||
if response.error:
|
||||
raise RuntimeError(
|
||||
f"Librarian error: {response.error.type}: "
|
||||
f"{response.error.message}"
|
||||
)
|
||||
|
||||
chunks.append(response)
|
||||
|
||||
if response.is_final:
|
||||
break
|
||||
|
||||
return chunks
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._streams.pop(request_id, None)
|
||||
raise RuntimeError("Timeout waiting for librarian stream")
|
||||
finally:
|
||||
self._streams.pop(request_id, None)
|
||||
|
||||
async def fetch_document_content(self, document_id, user, timeout=120):
|
||||
"""Fetch document content using streaming.
|
||||
|
||||
Returns base64-encoded content. Caller is responsible for decoding.
|
||||
"""
|
||||
req = LibrarianRequest(
|
||||
operation="stream-document",
|
||||
document_id=document_id,
|
||||
user=user,
|
||||
)
|
||||
chunks = await self.stream(req, timeout=timeout)
|
||||
|
||||
# Decode each chunk's base64 to raw bytes, concatenate,
|
||||
# re-encode for the caller.
|
||||
raw = b""
|
||||
for chunk in chunks:
|
||||
if chunk.content:
|
||||
if isinstance(chunk.content, bytes):
|
||||
raw += base64.b64decode(chunk.content)
|
||||
else:
|
||||
raw += base64.b64decode(
|
||||
chunk.content.encode("utf-8")
|
||||
)
|
||||
|
||||
return base64.b64encode(raw)
|
||||
|
||||
async def fetch_document_text(self, document_id, user, timeout=120):
|
||||
"""Fetch document content and decode as UTF-8 text."""
|
||||
content = await self.fetch_document_content(
|
||||
document_id, user, timeout=timeout,
|
||||
)
|
||||
return base64.b64decode(content).decode("utf-8")
|
||||
|
||||
async def fetch_document_metadata(self, document_id, user, timeout=120):
|
||||
"""Fetch document metadata from the librarian."""
|
||||
req = LibrarianRequest(
|
||||
operation="get-document-metadata",
|
||||
document_id=document_id,
|
||||
user=user,
|
||||
)
|
||||
response = await self.request(req, timeout=timeout)
|
||||
return response.document_metadata
|
||||
|
||||
async def save_child_document(self, doc_id, parent_id, user, content,
|
||||
document_type="chunk", title=None,
|
||||
kind="text/plain", timeout=120):
|
||||
"""Save a child document to the librarian."""
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind=kind,
|
||||
title=title or doc_id,
|
||||
parent_id=parent_id,
|
||||
document_type=document_type,
|
||||
)
|
||||
|
||||
req = LibrarianRequest(
|
||||
operation="add-child-document",
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content).decode("utf-8"),
|
||||
)
|
||||
|
||||
await self.request(req, timeout=timeout)
|
||||
return doc_id
|
||||
|
||||
async def save_document(self, doc_id, user, content, title=None,
|
||||
document_type="answer", kind="text/plain",
|
||||
timeout=120):
|
||||
"""Save a document to the librarian."""
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
|
||||
doc_metadata = DocumentMetadata(
|
||||
id=doc_id,
|
||||
user=user,
|
||||
kind=kind,
|
||||
title=title or doc_id,
|
||||
document_type=document_type,
|
||||
)
|
||||
|
||||
req = LibrarianRequest(
|
||||
operation="add-document",
|
||||
document_id=doc_id,
|
||||
document_metadata=doc_metadata,
|
||||
content=base64.b64encode(content).decode("utf-8"),
|
||||
user=user,
|
||||
)
|
||||
|
||||
await self.request(req, timeout=timeout)
|
||||
return doc_id
|
||||
|
|
@ -8,6 +8,12 @@ logger = logging.getLogger(__name__)
|
|||
DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
|
||||
DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None)
|
||||
|
||||
DEFAULT_RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", 'rabbitmq')
|
||||
DEFAULT_RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", '5672'))
|
||||
DEFAULT_RABBITMQ_USERNAME = os.getenv("RABBITMQ_USERNAME", 'guest')
|
||||
DEFAULT_RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", 'guest')
|
||||
DEFAULT_RABBITMQ_VHOST = os.getenv("RABBITMQ_VHOST", '/')
|
||||
|
||||
|
||||
def get_pubsub(**config):
|
||||
"""
|
||||
|
|
@ -29,6 +35,15 @@ def get_pubsub(**config):
|
|||
api_key=config.get('pulsar_api_key', DEFAULT_PULSAR_API_KEY),
|
||||
listener=config.get('pulsar_listener'),
|
||||
)
|
||||
elif backend_type == 'rabbitmq':
|
||||
from .rabbitmq_backend import RabbitMQBackend
|
||||
return RabbitMQBackend(
|
||||
host=config.get('rabbitmq_host', DEFAULT_RABBITMQ_HOST),
|
||||
port=config.get('rabbitmq_port', DEFAULT_RABBITMQ_PORT),
|
||||
username=config.get('rabbitmq_username', DEFAULT_RABBITMQ_USERNAME),
|
||||
password=config.get('rabbitmq_password', DEFAULT_RABBITMQ_PASSWORD),
|
||||
vhost=config.get('rabbitmq_vhost', DEFAULT_RABBITMQ_VHOST),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown pub/sub backend: {backend_type}")
|
||||
|
||||
|
|
@ -44,8 +59,9 @@ def add_pubsub_args(parser, standalone=False):
|
|||
standalone: If True, default host is localhost (for CLI tools
|
||||
that run outside containers)
|
||||
"""
|
||||
host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST
|
||||
listener_default = 'localhost' if standalone else None
|
||||
pulsar_host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST
|
||||
pulsar_listener = 'localhost' if standalone else None
|
||||
rabbitmq_host = 'localhost' if standalone else DEFAULT_RABBITMQ_HOST
|
||||
|
||||
parser.add_argument(
|
||||
'--pubsub-backend',
|
||||
|
|
@ -53,10 +69,11 @@ def add_pubsub_args(parser, standalone=False):
|
|||
help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
|
||||
)
|
||||
|
||||
# Pulsar options
|
||||
parser.add_argument(
|
||||
'-p', '--pulsar-host',
|
||||
default=host,
|
||||
help=f'Pulsar host (default: {host})',
|
||||
default=pulsar_host,
|
||||
help=f'Pulsar host (default: {pulsar_host})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -67,6 +84,38 @@ def add_pubsub_args(parser, standalone=False):
|
|||
|
||||
parser.add_argument(
|
||||
'--pulsar-listener',
|
||||
default=listener_default,
|
||||
help=f'Pulsar listener (default: {listener_default or "none"})',
|
||||
default=pulsar_listener,
|
||||
help=f'Pulsar listener (default: {pulsar_listener or "none"})',
|
||||
)
|
||||
|
||||
# RabbitMQ options
|
||||
parser.add_argument(
|
||||
'--rabbitmq-host',
|
||||
default=rabbitmq_host,
|
||||
help=f'RabbitMQ host (default: {rabbitmq_host})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rabbitmq-port',
|
||||
type=int,
|
||||
default=DEFAULT_RABBITMQ_PORT,
|
||||
help=f'RabbitMQ port (default: {DEFAULT_RABBITMQ_PORT})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rabbitmq-username',
|
||||
default=DEFAULT_RABBITMQ_USERNAME,
|
||||
help='RabbitMQ username',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rabbitmq-password',
|
||||
default=DEFAULT_RABBITMQ_PASSWORD,
|
||||
help='RabbitMQ password',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rabbitmq-vhost',
|
||||
default=DEFAULT_RABBITMQ_VHOST,
|
||||
help=f'RabbitMQ vhost (default: {DEFAULT_RABBITMQ_VHOST})',
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,122 +9,14 @@ import pulsar
|
|||
import _pulsar
|
||||
import json
|
||||
import logging
|
||||
import base64
|
||||
import types
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import Any, get_type_hints
|
||||
from typing import Any
|
||||
|
||||
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
|
||||
from .serialization import dataclass_to_dict, dict_to_dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def dataclass_to_dict(obj: Any) -> dict:
    """
    Recursively convert a dataclass (or container of dataclasses)
    into plain JSON-serializable values.

    Bytes values are decoded as UTF-8 strings (matching Pulsar
    behavior). Nested dataclasses, lists, and dicts are processed
    recursively. None values are passed through unchanged (they
    serialize to JSON null). Primitives are returned as-is.
    """
    # None propagates unchanged
    if obj is None:
        return None

    # JSON has no bytes type: decode to a UTF-8 string
    if isinstance(obj, bytes):
        return obj.decode('utf-8')

    # Dataclass: flatten via asdict(), then recurse into each field
    if is_dataclass(obj):
        converted = {}
        for name, field_value in asdict(obj).items():
            converted[name] = (
                None if field_value is None
                else dataclass_to_dict(field_value)
            )
        return converted

    # Sequences and mappings: recurse element-wise
    if isinstance(obj, list):
        return [dataclass_to_dict(element) for element in obj]

    if isinstance(obj, dict):
        return {key: dataclass_to_dict(val) for key, val in obj.items()}

    # Anything else (int, str, float, bool, ...) is already serializable
    return obj
|
||||
|
||||
|
||||
def dict_to_dataclass(data: dict, cls: type) -> Any:
    """
    Convert a dictionary back to a dataclass instance.

    Handles nested dataclasses, lists of dataclasses, Optional fields
    (both ``X | None`` and old-style ``Optional[X]``), and bytes fields
    that were serialized as UTF-8 strings. Keys in *data* that are not
    declared fields of *cls* are silently dropped; declared fields
    missing from *data* are left to the dataclass's own defaults.
    Uses get_type_hints() to resolve forward references (string
    annotations).
    """
    if data is None:
        return None

    # Non-dataclass targets (e.g. plain dict payloads) pass through
    if not is_dataclass(cls):
        return data

    # Get field types from the dataclass, resolving forward references
    # get_type_hints() evaluates string annotations like "Triple | None"
    try:
        field_types = get_type_hints(cls)
    except Exception:
        # Fallback if get_type_hints fails (shouldn't happen normally)
        field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
    kwargs = {}

    for key, value in data.items():
        if key in field_types:
            field_type = field_types[key]

            # Handle modern union types (X | Y)
            if isinstance(field_type, types.UnionType):
                # Check if it's Optional (X | None)
                if type(None) in field_type.__args__:
                    # Get the non-None type
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    # Non-Optional unions are passed through untouched
                    kwargs[key] = value
            # Check if this is a generic type (list, dict, etc.)
            elif hasattr(field_type, '__origin__'):
                # Handle list[T]
                if field_type.__origin__ == list:
                    item_type = field_type.__args__[0] if field_type.__args__ else None
                    if item_type and is_dataclass(item_type) and isinstance(value, list):
                        kwargs[key] = [
                            dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
                            for item in value
                        ]
                    else:
                        kwargs[key] = value
                # Handle old-style Optional[T] (which is Union[T, None])
                elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
                    # Get the non-None type from Union
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    # Other generics (dict[K, V], tuple[...], ...) pass through
                    kwargs[key] = value
            # Handle direct dataclass fields
            elif is_dataclass(field_type) and isinstance(value, dict):
                kwargs[key] = dict_to_dataclass(value, field_type)
            # Handle bytes fields (UTF-8 encoded strings from JSON)
            elif field_type == bytes and isinstance(value, str):
                kwargs[key] = value.encode('utf-8')
            else:
                kwargs[key] = value

    return cls(**kwargs)
|
||||
|
||||
|
||||
class PulsarMessage:
|
||||
"""Wrapper for Pulsar messages to match Message protocol."""
|
||||
|
||||
|
|
|
|||
390
trustgraph-base/trustgraph/base/rabbitmq_backend.py
Normal file
390
trustgraph-base/trustgraph/base/rabbitmq_backend.py
Normal file
|
|
@ -0,0 +1,390 @@
|
|||
"""
|
||||
RabbitMQ backend implementation for pub/sub abstraction.
|
||||
|
||||
Uses a single topic exchange per topicspace. The logical queue name
|
||||
becomes the routing key. Consumer behavior is determined by the
|
||||
subscription name:
|
||||
|
||||
- Same subscription + same topic = shared queue (competing consumers)
|
||||
- Different subscriptions = separate queues (broadcast / fan-out)
|
||||
|
||||
This mirrors Pulsar's subscription model using idiomatic RabbitMQ.
|
||||
|
||||
Architecture:
|
||||
Producer --> [tg exchange] --routing key--> [named queue] --> Consumer
|
||||
--routing key--> [named queue] --> Consumer
|
||||
--routing key--> [exclusive q] --> Subscriber
|
||||
|
||||
Uses basic_consume (push) instead of basic_get (polling) for
|
||||
efficient message delivery.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import pika
|
||||
from typing import Any
|
||||
|
||||
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
|
||||
from .serialization import dataclass_to_dict, dict_to_dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RabbitMQMessage:
    """Wrapper for RabbitMQ messages to match the Message protocol.

    Deserialization is lazy: the JSON body is parsed on the first
    value() call and the result memoized.
    """

    def __init__(self, method, properties, body, schema_cls):
        self._method = method          # pika Basic.Deliver (carries delivery_tag for ack/nack)
        self._properties = properties  # pika BasicProperties (headers, content type, ...)
        self._body = body              # raw payload bytes (JSON-encoded)
        self._schema_cls = schema_cls  # dataclass type to deserialize into
        self._value = None
        # Explicit flag rather than "is self._value None": a body of
        # JSON "null" legitimately decodes to None, and a None sentinel
        # would re-parse it on every value() call.
        self._decoded = False

    def value(self) -> Any:
        """Deserialize and return the message value as a dataclass (memoized)."""
        if not self._decoded:
            data_dict = json.loads(self._body.decode('utf-8'))
            self._value = dict_to_dataclass(data_dict, self._schema_cls)
            self._decoded = True
        return self._value

    def properties(self) -> dict:
        """Return message properties from AMQP headers ({} when none were set)."""
        headers = self._properties.headers or {}
        return dict(headers)
|
||||
|
||||
|
||||
class RabbitMQBackendProducer:
    """Publishes messages to a topic exchange with a routing key.

    Uses thread-local connections so each thread gets its own
    connection/channel. This avoids wire corruption from concurrent
    threads writing to the same socket (pika is not thread-safe).
    """

    def __init__(self, connection_params, exchange_name, routing_key,
                 durable):
        self._connection_params = connection_params
        self._exchange_name = exchange_name
        self._routing_key = routing_key
        # durable=True selects persistent delivery mode on publish
        self._durable = durable
        # Per-thread connection/channel storage (pika is not thread-safe)
        self._local = threading.local()

    def _get_channel(self):
        """Get or create a thread-local connection and channel."""
        conn = getattr(self._local, 'connection', None)
        chan = getattr(self._local, 'channel', None)

        if conn is None or not conn.is_open or chan is None or not chan.is_open:
            # Close stale connection if any
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

            conn = pika.BlockingConnection(self._connection_params)
            chan = conn.channel()
            # Idempotent declare: ensures the exchange exists before publishing
            chan.exchange_declare(
                exchange=self._exchange_name,
                exchange_type='topic',
                durable=True,
            )
            self._local.connection = conn
            self._local.channel = chan

        return chan

    def send(self, message: Any, properties: dict | None = None) -> None:
        """Serialize *message* (a dataclass) to JSON and publish it.

        *properties*, when provided and non-empty, are attached as AMQP
        headers. Default changed from a shared mutable ``{}`` to ``None``
        (classic Python pitfall); the publish path already treated an
        empty dict and None identically, so behavior is unchanged.

        Retries once on failure, forcing a reconnect in between;
        re-raises the second failure.
        """
        data_dict = dataclass_to_dict(message)
        json_data = json.dumps(data_dict)

        amqp_properties = pika.BasicProperties(
            delivery_mode=2 if self._durable else 1,  # 2 = persistent
            content_type='application/json',
            headers=properties if properties else None,
        )

        for attempt in range(2):
            try:
                channel = self._get_channel()
                channel.basic_publish(
                    exchange=self._exchange_name,
                    routing_key=self._routing_key,
                    body=json_data.encode('utf-8'),
                    properties=amqp_properties,
                )
                return
            except Exception as e:
                logger.warning(
                    f"RabbitMQ send failed (attempt {attempt + 1}): {e}"
                )
                # Force reconnect on next attempt
                self._local.connection = None
                self._local.channel = None
                if attempt == 1:
                    raise

    def flush(self) -> None:
        """No-op: BlockingConnection publishes synchronously."""
        pass

    def close(self) -> None:
        """Close the thread-local connection if any."""
        conn = getattr(self._local, 'connection', None)
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass
        self._local.connection = None
        self._local.channel = None
|
||||
|
||||
|
||||
class RabbitMQBackendConsumer:
    """Consumes from a queue bound to a topic exchange.

    Uses basic_consume (push model) with messages delivered to an
    internal thread-safe queue. process_data_events() drives both
    message delivery and heartbeat processing.
    """

    def __init__(self, connection_params, exchange_name, routing_key,
                 queue_name, schema_cls, durable, exclusive=False,
                 auto_delete=False):
        # Connection is opened lazily on the first receive(), so this
        # object can be constructed without a reachable broker.
        self._connection_params = connection_params
        self._exchange_name = exchange_name
        self._routing_key = routing_key
        self._queue_name = queue_name      # '' asks the broker to generate a name
        self._schema_cls = schema_cls      # dataclass type for deserialization
        self._durable = durable
        self._exclusive = exclusive        # exclusive + auto_delete => broadcast queue
        self._auto_delete = auto_delete
        self._connection = None
        self._channel = None
        self._consumer_tag = None
        # Buffer between pika's delivery callback and receive()
        self._incoming = queue.Queue()

    def _connect(self):
        """Open connection/channel, declare topology, and start consuming."""
        self._connection = pika.BlockingConnection(self._connection_params)
        self._channel = self._connection.channel()

        # Declare the topic exchange
        self._channel.exchange_declare(
            exchange=self._exchange_name,
            exchange_type='topic',
            durable=True,
        )

        # Declare the queue — anonymous if exclusive
        result = self._channel.queue_declare(
            queue=self._queue_name,
            durable=self._durable,
            exclusive=self._exclusive,
            auto_delete=self._auto_delete,
        )
        # Capture actual name (important for anonymous queues where name='')
        self._queue_name = result.method.queue

        self._channel.queue_bind(
            queue=self._queue_name,
            exchange=self._exchange_name,
            routing_key=self._routing_key,
        )

        # Limit to one unacknowledged message per consumer at a time
        self._channel.basic_qos(prefetch_count=1)

        # Register push-based consumer
        self._consumer_tag = self._channel.basic_consume(
            queue=self._queue_name,
            on_message_callback=self._on_message,
            auto_ack=False,
        )

    def _on_message(self, channel, method, properties, body):
        """Callback invoked by pika when a message arrives."""
        self._incoming.put((method, properties, body))

    def _is_alive(self):
        # True only if both the connection and the channel exist and are open
        return (
            self._connection is not None
            and self._connection.is_open
            and self._channel is not None
            and self._channel.is_open
        )

    def receive(self, timeout_millis: int = 2000) -> Message:
        """Receive a message. Raises TimeoutError if none available."""
        # (Re)connect lazily on first call or after the link dropped
        if not self._is_alive():
            self._connect()

        timeout_seconds = timeout_millis / 1000.0
        deadline = time.monotonic() + timeout_seconds

        while time.monotonic() < deadline:
            # Check if a message was already delivered
            try:
                method, properties, body = self._incoming.get_nowait()
                return RabbitMQMessage(
                    method, properties, body, self._schema_cls,
                )
            except queue.Empty:
                pass

            # Drive pika's I/O — delivers messages and processes heartbeats.
            # Short slices (<= 100ms) keep us responsive to the deadline.
            remaining = deadline - time.monotonic()
            if remaining > 0:
                self._connection.process_data_events(
                    time_limit=min(0.1, remaining),
                )

        raise TimeoutError("No message received within timeout")

    def acknowledge(self, message: Message) -> None:
        """Ack *message* on this consumer's channel, removing it from the queue."""
        if isinstance(message, RabbitMQMessage) and message._method:
            self._channel.basic_ack(
                delivery_tag=message._method.delivery_tag,
            )

    def negative_acknowledge(self, message: Message) -> None:
        """Nack *message* with requeue=True so it can be redelivered."""
        if isinstance(message, RabbitMQMessage) and message._method:
            self._channel.basic_nack(
                delivery_tag=message._method.delivery_tag,
                requeue=True,
            )

    def unsubscribe(self) -> None:
        """Cancel the broker-side consumer registration (best-effort)."""
        if self._consumer_tag and self._channel and self._channel.is_open:
            try:
                self._channel.basic_cancel(self._consumer_tag)
            except Exception:
                pass
        self._consumer_tag = None

    def close(self) -> None:
        """Cancel the consumer and close channel/connection (best-effort)."""
        self.unsubscribe()
        try:
            if self._channel and self._channel.is_open:
                self._channel.close()
        except Exception:
            pass
        try:
            if self._connection and self._connection.is_open:
                self._connection.close()
        except Exception:
            pass
        self._channel = None
        self._connection = None
|
||||
|
||||
|
||||
class RabbitMQBackend:
    """RabbitMQ pub/sub backend using a topic exchange per topicspace."""

    def __init__(self, host='localhost', port=5672, username='guest',
                 password='guest', vhost='/'):
        creds = pika.PlainCredentials(username, password)
        self._connection_params = pika.ConnectionParameters(
            host=host,
            port=port,
            virtual_host=vhost,
            credentials=creds,
        )
        logger.info(f"RabbitMQ backend: {host}:{port} vhost={vhost}")

    def _parse_queue_id(self, queue_id: str) -> tuple[str, str, str, bool]:
        """
        Parse queue identifier into exchange, routing key, and durability.

        Format: class:topicspace:topic
        Returns: (exchange_name, routing_key, class, durable)
        """
        # Bare names (no class/topicspace) fall back to an ephemeral
        # flow queue on the default 'tg' exchange.
        if ':' not in queue_id:
            return 'tg', queue_id, 'flow', False

        fields = queue_id.split(':', 2)
        if len(fields) != 3:
            raise ValueError(
                f"Invalid queue format: {queue_id}, "
                f"expected class:topicspace:topic"
            )

        cls, topicspace, topic = fields

        # flow/state traffic must survive broker restarts; request/response
        # round-trips are transient.
        durability = {
            'flow': True,
            'state': True,
            'request': False,
            'response': False,
        }
        if cls not in durability:
            raise ValueError(
                f"Invalid queue class: {cls}, "
                f"expected flow, request, response, or state"
            )

        # Exchange per topicspace, routing key includes class
        return topicspace, f"{cls}.{topic}", cls, durability[cls]

    # Keep map_queue_name for backward compatibility with tests
    def map_queue_name(self, queue_id: str) -> tuple[str, bool]:
        """Return the fully-qualified queue name and its durability."""
        exchange, routing_key, _cls, durable = self._parse_queue_id(queue_id)
        return f"{exchange}.{routing_key}", durable

    def create_producer(self, topic: str, schema: type,
                        **options) -> BackendProducer:
        """Create a producer publishing to the topic exchange for *topic*."""
        exchange, routing_key, _cls, durable = self._parse_queue_id(topic)
        logger.debug(
            f"Creating producer: exchange={exchange}, "
            f"routing_key={routing_key}"
        )
        return RabbitMQBackendProducer(
            self._connection_params, exchange, routing_key, durable,
        )

    def create_consumer(self, topic: str, subscription: str, schema: type,
                        initial_position: str = 'latest',
                        consumer_type: str = 'shared',
                        **options) -> BackendConsumer:
        """Create a consumer with a queue bound to the topic exchange.

        consumer_type='shared': Named durable queue. Multiple consumers
        with the same subscription compete (round-robin).
        consumer_type='exclusive': Anonymous ephemeral queue. Each
        consumer gets its own copy of every message (broadcast).
        """
        exchange, routing_key, cls, durable = self._parse_queue_id(topic)

        broadcast = consumer_type == 'exclusive'
        named_queue = f"{exchange}.{routing_key}.{subscription}"

        if broadcast and cls == 'state':
            # State broadcast: named durable queue per subscriber.
            # Retains messages so late-starting processors see current state.
            queue_name, queue_durable = named_queue, True
            exclusive, auto_delete = False, False
        elif broadcast:
            # Broadcast: anonymous queue, auto-deleted on disconnect
            queue_name, queue_durable = '', False
            exclusive, auto_delete = True, True
        else:
            # Shared: named queue, competing consumers
            queue_name, queue_durable = named_queue, durable
            exclusive, auto_delete = False, False

        logger.debug(
            f"Creating consumer: exchange={exchange}, "
            f"routing_key={routing_key}, queue={queue_name or '(anonymous)'}, "
            f"type={consumer_type}"
        )

        return RabbitMQBackendConsumer(
            self._connection_params, exchange, routing_key,
            queue_name, schema, queue_durable, exclusive, auto_delete,
        )

    def close(self) -> None:
        """No shared resources: connections live in producers/consumers."""
        pass
|
||||
115
trustgraph-base/trustgraph/base/serialization.py
Normal file
115
trustgraph-base/trustgraph/base/serialization.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
JSON serialization helpers for dataclass ↔ dict conversion.
|
||||
|
||||
Used by pub/sub backends that use JSON as their wire format.
|
||||
"""
|
||||
|
||||
import types
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from typing import Any, get_type_hints
|
||||
|
||||
|
||||
def dataclass_to_dict(obj: Any) -> dict:
    """
    Recursively convert a dataclass to a dictionary, handling None values and bytes.

    None values are preserved as None (serialized as JSON null) — they are
    not stripped from the output.
    Bytes values are decoded as UTF-8 strings for JSON serialization.
    Handles nested dataclasses, lists, and dictionaries recursively.
    Primitive values are returned unchanged.
    """
    if obj is None:
        return None

    # Handle bytes - decode to UTF-8 for JSON serialization
    if isinstance(obj, bytes):
        return obj.decode('utf-8')

    # Handle dataclass - convert to dict then recursively process all values
    if is_dataclass(obj):
        result = {}
        for key, value in asdict(obj).items():
            result[key] = dataclass_to_dict(value) if value is not None else None
        return result

    # Handle list - recursively process all items
    if isinstance(obj, list):
        return [dataclass_to_dict(item) for item in obj]

    # Handle dict - recursively process all values
    if isinstance(obj, dict):
        return {k: dataclass_to_dict(v) for k, v in obj.items()}

    # Return primitive types as-is
    return obj
|
||||
|
||||
|
||||
def dict_to_dataclass(data: dict, cls: type) -> Any:
    """
    Convert a dictionary back to a dataclass instance.

    Recursively rebuilds nested dataclasses, lists of dataclasses,
    Optional fields (both ``X | None`` and old-style ``Optional[X]``),
    and bytes fields stored as UTF-8 strings on the wire. Keys in
    *data* that are not declared fields of *cls* are silently dropped;
    declared fields absent from *data* fall back to the dataclass's
    own defaults.
    Uses get_type_hints() to resolve forward references (string annotations).
    """
    if data is None:
        return None

    # Non-dataclass targets pass through unchanged
    if not is_dataclass(cls):
        return data

    # Get field types from the dataclass, resolving forward references
    # get_type_hints() evaluates string annotations like "Triple | None"
    try:
        field_types = get_type_hints(cls)
    except Exception:
        # Fallback if get_type_hints fails (shouldn't happen normally)
        field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
    kwargs = {}

    for key, value in data.items():
        if key in field_types:
            field_type = field_types[key]

            # Handle modern union types (X | Y)
            if isinstance(field_type, types.UnionType):
                # Check if it's Optional (X | None)
                if type(None) in field_type.__args__:
                    # Get the non-None type
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    # Non-Optional unions are passed through untouched
                    kwargs[key] = value
            # Check if this is a generic type (list, dict, etc.)
            elif hasattr(field_type, '__origin__'):
                # Handle list[T]
                if field_type.__origin__ == list:
                    item_type = field_type.__args__[0] if field_type.__args__ else None
                    if item_type and is_dataclass(item_type) and isinstance(value, list):
                        kwargs[key] = [
                            dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
                            for item in value
                        ]
                    else:
                        kwargs[key] = value
                # Handle old-style Optional[T] (which is Union[T, None])
                elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
                    # Get the non-None type from Union
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    # Other generics (dict[K, V], tuple[...], ...) pass through
                    kwargs[key] = value
            # Handle direct dataclass fields
            elif is_dataclass(field_type) and isinstance(value, dict):
                kwargs[key] = dict_to_dataclass(value, field_type)
            # Handle bytes fields (UTF-8 encoded strings from JSON)
            elif field_type == bytes and isinstance(value, str):
                kwargs[key] = value.encode('utf-8')
            else:
                kwargs[key] = value

    return cls(**kwargs)
|
||||
|
|
@ -51,7 +51,7 @@ class Subscriber:
|
|||
topic=self.topic,
|
||||
subscription=self.subscription,
|
||||
schema=self.schema,
|
||||
consumer_type='shared',
|
||||
consumer_type='exclusive',
|
||||
)
|
||||
|
||||
self.task = asyncio.create_task(self.run())
|
||||
|
|
|
|||
|
|
@ -18,9 +18,7 @@ class BaseClient:
|
|||
output_queue=None,
|
||||
input_schema=None,
|
||||
output_schema=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
pulsar_api_key=None,
|
||||
listener=None,
|
||||
**pubsub_config,
|
||||
):
|
||||
|
||||
if input_queue == None: raise RuntimeError("Need input_queue")
|
||||
|
|
@ -32,12 +30,7 @@ class BaseClient:
|
|||
subscriber = str(uuid.uuid4())
|
||||
|
||||
# Create backend using factory
|
||||
self.backend = get_pubsub(
|
||||
pulsar_host=pulsar_host,
|
||||
pulsar_api_key=pulsar_api_key,
|
||||
pulsar_listener=listener,
|
||||
pubsub_backend='pulsar'
|
||||
)
|
||||
self.backend = get_pubsub(**pubsub_config)
|
||||
|
||||
self.producer = self.backend.create_producer(
|
||||
topic=input_queue,
|
||||
|
|
|
|||
|
|
@ -33,9 +33,7 @@ class ConfigClient(BaseClient):
|
|||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
listener=None,
|
||||
pulsar_api_key=None,
|
||||
**pubsub_config,
|
||||
):
|
||||
|
||||
if input_queue == None:
|
||||
|
|
@ -48,11 +46,9 @@ class ConfigClient(BaseClient):
|
|||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
pulsar_host=pulsar_host,
|
||||
pulsar_api_key=pulsar_api_key,
|
||||
input_schema=ConfigRequest,
|
||||
output_schema=ConfigResponse,
|
||||
listener=listener,
|
||||
**pubsub_config,
|
||||
)
|
||||
|
||||
def get(self, keys, timeout=300):
|
||||
|
|
|
|||
|
|
@ -24,10 +24,13 @@ from ..core.metadata import Metadata
|
|||
# <- (document_metadata)
|
||||
# <- (error)
|
||||
|
||||
# get-document-content
|
||||
# get-document-content [DEPRECATED — use stream-document instead]
|
||||
# -> (document_id)
|
||||
# <- (content)
|
||||
# <- (error)
|
||||
# NOTE: Returns entire document in a single message. Fails for documents
|
||||
# exceeding the broker's max message size. Use stream-document which
|
||||
# returns content in chunks.
|
||||
|
||||
# add-processing
|
||||
# -> (processing_id, processing_metadata)
|
||||
|
|
@ -220,5 +223,5 @@ class LibrarianResponse:
|
|||
# Librarian uses request/response (non-persistent) queues: stream-document
# returns document content in bounded chunks, so persistent flow-class
# queues are no longer needed to carry large single messages.
|
||||
|
||||
librarian_request_queue = queue('librarian-request', cls='flow')
|
||||
librarian_response_queue = queue('librarian-response', cls='flow')
|
||||
librarian_request_queue = queue('librarian', cls='request')
|
||||
librarian_response_queue = queue('librarian', cls='response')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue