RabbitMQ pub/sub backend with topic exchange architecture (#752)

Adds a RabbitMQ backend as an alternative to Pulsar, selectable via
PUBSUB_BACKEND=rabbitmq. Both backends implement the same PubSubBackend
protocol — no application code changes needed to switch.

RabbitMQ topology:
- Single topic exchange per topicspace (e.g. 'tg')
- Routing key derived from queue class and topic name
- Shared consumers: named queue bound to exchange (competing, round-robin)
- Exclusive consumers: anonymous auto-delete queue (broadcast, each gets
  every message). Used by Subscriber and config push consumer.
- Thread-local producer connections (pika is not thread-safe)
- Push-based consumption via basic_consume with process_data_events
  for heartbeat processing

Consumer model changes:
- Consumer class creates one backend consumer per concurrent task
  (required for pika thread safety, harmless for Pulsar)
- Consumer class accepts consumer_type parameter
- Subscriber passes consumer_type='exclusive' for broadcast semantics
- Config push consumer uses consumer_type='exclusive' so every
  processor instance receives config updates
- handle_one_from_queue receives consumer as parameter for correct
  per-connection ack/nack

LibrarianClient:
- New shared client class replacing duplicated librarian request-response
  code across 6+ services (chunking, decoders, RAG, etc.)
- Uses stream-document instead of get-document-content for fetching
  document content in 1MB chunks (avoids broker message size limits)
- Standalone object (self.librarian = LibrarianClient(...)) not a mixin
- get-document-content marked deprecated in schema and OpenAPI spec

Serialization:
- Extracted dataclass_to_dict/dict_to_dataclass to shared
  serialization.py (used by both Pulsar and RabbitMQ backends)

Librarian queues:
- Changed from flow class (persistent) back to request/response class
  now that stream-document eliminates large single messages
- API upload chunk size reduced from 5MB to 3MB to stay under broker
  limits after base64 encoding

Factory and CLI:
- get_pubsub() handles 'rabbitmq' backend with RabbitMQ connection params
- add_pubsub_args() includes RabbitMQ options (host, port, credentials)
- add_pubsub_args(standalone=True) defaults to localhost for CLI tools
- init_trustgraph skips Pulsar admin setup for non-Pulsar backends
- tg-dump-queues and tg-monitor-prompts use backend abstraction
- BaseClient and ConfigClient accept generic pubsub config
This commit is contained in:
cybermaggedon 2026-04-02 12:47:16 +01:00 committed by GitHub
parent 4fb0b4d8e8
commit 24f0190ce7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 1277 additions and 1313 deletions

View file

@ -14,6 +14,7 @@ from . producer_spec import ProducerSpec
from . subscriber_spec import SubscriberSpec
from . request_response_spec import RequestResponseSpec
from . llm_service import LlmService, LlmResult, LlmChunk
from . librarian_client import LibrarianClient
from . chunking_service import ChunkingService
from . embeddings_service import EmbeddingsService
from . embeddings_client import EmbeddingsClientSpec

View file

@ -68,11 +68,12 @@ class AsyncProcessor:
processor = self.id, flow = None, name = "config",
)
# Subscribe to config queue
# Subscribe to config queue — exclusive so every processor
# gets its own copy of config pushes (broadcast pattern)
self.config_sub_task = Consumer(
taskgroup = self.taskgroup,
backend = self.pubsub_backend, # Changed from client to backend
backend = self.pubsub_backend,
subscriber = config_subscriber_id,
flow = None,
@ -83,9 +84,8 @@ class AsyncProcessor:
metrics = config_consumer_metrics,
# This causes new subscriptions to view the entire history of
# configuration
start_of_messages = True
start_of_messages = True,
consumer_type = 'exclusive',
)
self.running = True

View file

@ -7,23 +7,14 @@ fetching large document content.
import asyncio
import base64
import logging
import uuid
from .flow_processor import FlowProcessor
from .parameter_spec import ParameterSpec
from .consumer import Consumer
from .producer import Producer
from .metrics import ConsumerMetrics, ProducerMetrics
from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ..schema import librarian_request_queue, librarian_response_queue
from .librarian_client import LibrarianClient
# Module logger
logger = logging.getLogger(__name__)
default_librarian_request_queue = librarian_request_queue
default_librarian_response_queue = librarian_response_queue
class ChunkingService(FlowProcessor):
"""Base service for chunking processors with parameter specification support"""
@ -44,155 +35,18 @@ class ChunkingService(FlowProcessor):
ParameterSpec(name="chunk-overlap")
)
# Librarian client for fetching document content
librarian_request_q = params.get(
"librarian_request_queue", default_librarian_request_queue
)
librarian_response_q = params.get(
"librarian_response_queue", default_librarian_response_queue
)
librarian_request_metrics = ProducerMetrics(
processor=id, flow=None, name="librarian-request"
)
self.librarian_request_producer = Producer(
# Librarian client
self.librarian = LibrarianClient(
id=id,
backend=self.pubsub,
topic=librarian_request_q,
schema=LibrarianRequest,
metrics=librarian_request_metrics,
)
librarian_response_metrics = ConsumerMetrics(
processor=id, flow=None, name="librarian-response"
)
self.librarian_response_consumer = Consumer(
taskgroup=self.taskgroup,
backend=self.pubsub,
flow=None,
topic=librarian_response_q,
subscriber=f"{id}-librarian",
schema=LibrarianResponse,
handler=self.on_librarian_response,
metrics=librarian_response_metrics,
)
# Pending librarian requests: request_id -> asyncio.Future
self.pending_requests = {}
logger.debug("ChunkingService initialized with parameter specifications")
async def start(self):
await super(ChunkingService, self).start()
await self.librarian_request_producer.start()
await self.librarian_response_consumer.start()
async def on_librarian_response(self, msg, consumer, flow):
"""Handle responses from the librarian service."""
response = msg.value()
request_id = msg.properties().get("id")
if request_id and request_id in self.pending_requests:
future = self.pending_requests.pop(request_id)
future.set_result(response)
async def fetch_document_content(self, document_id, user, timeout=120):
"""
Fetch document content from librarian via Pulsar.
"""
request_id = str(uuid.uuid4())
request = LibrarianRequest(
operation="get-document-content",
document_id=document_id,
user=user,
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error: {response.error.type}: {response.error.message}"
)
return response.content
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout fetching document {document_id}")
async def save_child_document(self, doc_id, parent_id, user, content,
document_type="chunk", title=None, timeout=120):
"""
Save a child document (chunk) to the librarian.
Args:
doc_id: ID for the new child document
parent_id: ID of the parent document
user: User ID
content: Document content (bytes or str)
document_type: Type of document ("chunk", etc.)
title: Optional title
timeout: Request timeout in seconds
Returns:
The document ID on success
"""
request_id = str(uuid.uuid4())
if isinstance(content, str):
content = content.encode("utf-8")
doc_metadata = DocumentMetadata(
id=doc_id,
user=user,
kind="text/plain",
title=title or doc_id,
parent_id=parent_id,
document_type=document_type,
)
request = LibrarianRequest(
operation="add-child-document",
document_metadata=doc_metadata,
content=base64.b64encode(content).decode("utf-8"),
)
# Create future for response
future = asyncio.get_event_loop().create_future()
self.pending_requests[request_id] = future
try:
# Send request
await self.librarian_request_producer.send(
request, properties={"id": request_id}
)
# Wait for response
response = await asyncio.wait_for(future, timeout=timeout)
if response.error:
raise RuntimeError(
f"Librarian error saving chunk: {response.error.type}: {response.error.message}"
)
return doc_id
except asyncio.TimeoutError:
self.pending_requests.pop(request_id, None)
raise RuntimeError(f"Timeout saving chunk {doc_id}")
await self.librarian.start()
async def get_document_text(self, doc):
"""
@ -206,14 +60,10 @@ class ChunkingService(FlowProcessor):
"""
if doc.document_id and not doc.text:
logger.info(f"Fetching document {doc.document_id} from librarian...")
content = await self.fetch_document_content(
text = await self.librarian.fetch_document_text(
document_id=doc.document_id,
user=doc.metadata.user,
)
# Content is base64 encoded
if isinstance(content, str):
content = content.encode('utf-8')
text = base64.b64decode(content).decode("utf-8")
logger.info(f"Fetched {len(text)} characters from librarian")
return text
else:
@ -224,41 +74,31 @@ class ChunkingService(FlowProcessor):
Extract chunk parameters from flow and return effective values
Args:
msg: The message containing the document to chunk
consumer: The consumer spec
flow: The flow context
default_chunk_size: Default chunk size from processor config
default_chunk_overlap: Default chunk overlap from processor config
msg: The message being processed
consumer: The consumer instance
flow: The flow object containing parameters
default_chunk_size: Default chunk size if not configured
default_chunk_overlap: Default chunk overlap if not configured
Returns:
tuple: (chunk_size, chunk_overlap) - effective values to use
tuple: (chunk_size, chunk_overlap) effective values
"""
# Extract parameters from flow (flow-configurable parameters)
chunk_size = flow("chunk-size")
chunk_overlap = flow("chunk-overlap")
# Use provided values or fall back to defaults
effective_chunk_size = chunk_size if chunk_size is not None else default_chunk_size
effective_chunk_overlap = chunk_overlap if chunk_overlap is not None else default_chunk_overlap
chunk_size = default_chunk_size
chunk_overlap = default_chunk_overlap
logger.debug(f"Using chunk-size: {effective_chunk_size}")
logger.debug(f"Using chunk-overlap: {effective_chunk_overlap}")
try:
cs = flow.parameters.get("chunk-size")
if cs is not None:
chunk_size = int(cs)
except Exception as e:
logger.warning(f"Could not parse chunk-size parameter: {e}")
return effective_chunk_size, effective_chunk_overlap
try:
co = flow.parameters.get("chunk-overlap")
if co is not None:
chunk_overlap = int(co)
except Exception as e:
logger.warning(f"Could not parse chunk-overlap parameter: {e}")
@staticmethod
def add_args(parser):
"""Add chunking service arguments to parser"""
FlowProcessor.add_args(parser)
parser.add_argument(
'--librarian-request-queue',
default=default_librarian_request_queue,
help=f'Librarian request queue (default: {default_librarian_request_queue})',
)
parser.add_argument(
'--librarian-response-queue',
default=default_librarian_response_queue,
help=f'Librarian response queue (default: {default_librarian_response_queue})',
)
return chunk_size, chunk_overlap

View file

@ -32,6 +32,7 @@ class Consumer:
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
reconnect_time = 5,
concurrency = 1, # Number of concurrent requests to handle
consumer_type = 'shared',
):
self.taskgroup = taskgroup
@ -42,6 +43,8 @@ class Consumer:
self.schema = schema
self.handler = handler
self.consumer_type = consumer_type
self.rate_limit_retry_time = rate_limit_retry_time
self.rate_limit_timeout = rate_limit_timeout
@ -93,33 +96,11 @@ class Consumer:
if self.metrics:
self.metrics.state("stopped")
try:
logger.info(f"Subscribing to topic: {self.topic}")
# Determine initial position
if self.start_of_messages:
initial_pos = 'earliest'
else:
initial_pos = 'latest'
# Create consumer via backend
self.consumer = await asyncio.to_thread(
self.backend.create_consumer,
topic = self.topic,
subscription = self.subscriber,
schema = self.schema,
initial_position = initial_pos,
consumer_type = 'shared',
)
except Exception as e:
logger.error(f"Consumer subscription exception: {e}", exc_info=True)
await asyncio.sleep(self.reconnect_time)
continue
logger.info(f"Successfully subscribed to topic: {self.topic}")
# Determine initial position
if self.start_of_messages:
initial_pos = 'earliest'
else:
initial_pos = 'latest'
if self.metrics:
self.metrics.state("running")
@ -128,14 +109,30 @@ class Consumer:
logger.info(f"Starting {self.concurrency} receiver threads")
async with asyncio.TaskGroup() as tg:
tasks = []
for i in range(0, self.concurrency):
tasks.append(
tg.create_task(self.consume_from_queue())
# Create one backend consumer per concurrent task.
# Each gets its own connection — required for backends
# like RabbitMQ where connections are not thread-safe.
consumers = []
for i in range(self.concurrency):
try:
logger.info(f"Subscribing to topic: {self.topic} (worker {i})")
c = await asyncio.to_thread(
self.backend.create_consumer,
topic = self.topic,
subscription = self.subscriber,
schema = self.schema,
initial_position = initial_pos,
consumer_type = self.consumer_type,
)
consumers.append(c)
logger.info(f"Successfully subscribed to topic: {self.topic} (worker {i})")
except Exception as e:
logger.error(f"Consumer subscription exception (worker {i}): {e}", exc_info=True)
raise
async with asyncio.TaskGroup() as tg:
for c in consumers:
tg.create_task(self.consume_from_queue(c))
if self.metrics:
self.metrics.state("stopped")
@ -143,23 +140,31 @@ class Consumer:
except Exception as e:
logger.error(f"Consumer loop exception: {e}", exc_info=True)
self.consumer.unsubscribe()
self.consumer.close()
self.consumer = None
for c in consumers:
try:
c.unsubscribe()
c.close()
except Exception:
pass
consumers = []
await asyncio.sleep(self.reconnect_time)
continue
if self.consumer:
self.consumer.unsubscribe()
self.consumer.close()
finally:
for c in consumers:
try:
c.unsubscribe()
c.close()
except Exception:
pass
async def consume_from_queue(self):
async def consume_from_queue(self, consumer):
while self.running:
try:
msg = await asyncio.to_thread(
self.consumer.receive,
consumer.receive,
timeout_millis=2000
)
except Exception as e:
@ -168,9 +173,9 @@ class Consumer:
continue
raise e
await self.handle_one_from_queue(msg)
await self.handle_one_from_queue(msg, consumer)
async def handle_one_from_queue(self, msg):
async def handle_one_from_queue(self, msg, consumer):
expiry = time.time() + self.rate_limit_timeout
@ -183,7 +188,7 @@ class Consumer:
# Message failed to be processed, this causes it to
# be retried
self.consumer.negative_acknowledge(msg)
consumer.negative_acknowledge(msg)
if self.metrics:
self.metrics.process("error")
@ -206,7 +211,7 @@ class Consumer:
logger.debug("Message processed successfully")
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
consumer.acknowledge(msg)
if self.metrics:
self.metrics.process("success")
@ -233,7 +238,7 @@ class Consumer:
# Message failed to be processed, this causes it to
# be retried
self.consumer.negative_acknowledge(msg)
consumer.negative_acknowledge(msg)
if self.metrics:
self.metrics.process("error")

View file

@ -0,0 +1,246 @@
"""
Shared librarian client for services that need to communicate
with the librarian via pub/sub.
Provides request-response and streaming operations over the message
broker, with proper support for large documents via stream-document.
Usage:
self.librarian = LibrarianClient(
id=id, backend=self.pubsub, taskgroup=self.taskgroup, **params
)
await self.librarian.start()
content = await self.librarian.fetch_document_content(doc_id, user)
"""
import asyncio
import base64
import logging
import uuid
from .consumer import Consumer
from .producer import Producer
from .metrics import ConsumerMetrics, ProducerMetrics
from ..schema import LibrarianRequest, LibrarianResponse, DocumentMetadata
from ..schema import librarian_request_queue, librarian_response_queue
logger = logging.getLogger(__name__)
class LibrarianClient:
    """Client for librarian request-response over the message broker.

    Owns one producer (librarian requests) and one exclusive consumer
    (librarian responses), correlating responses back to waiting callers
    via a per-request UUID carried in message properties.
    """

    def __init__(self, id, backend, taskgroup, **params):
        """
        Args:
            id: processor id, used for metrics names and the subscriber id
            backend: pub/sub backend (Pulsar or RabbitMQ)
            taskgroup: task group the response consumer runs under
            **params: optional 'librarian_request_queue' /
                'librarian_response_queue' overrides
        """
        librarian_request_q = params.get(
            "librarian_request_queue", librarian_request_queue,
        )
        librarian_response_q = params.get(
            "librarian_response_queue", librarian_response_queue,
        )

        librarian_request_metrics = ProducerMetrics(
            processor=id, flow=None, name="librarian-request",
        )

        self._producer = Producer(
            backend=backend,
            topic=librarian_request_q,
            schema=LibrarianRequest,
            metrics=librarian_request_metrics,
        )

        librarian_response_metrics = ConsumerMetrics(
            processor=id, flow=None, name="librarian-response",
        )

        # Exclusive consumer: every client instance must see every
        # response so it can pick out the ones it is waiting for
        self._consumer = Consumer(
            taskgroup=taskgroup,
            backend=backend,
            flow=None,
            topic=librarian_response_q,
            subscriber=f"{id}-librarian",
            schema=LibrarianResponse,
            handler=self._on_response,
            metrics=librarian_response_metrics,
            consumer_type='exclusive',
        )

        # Single-response requests: request_id -> asyncio.Future
        self._pending = {}

        # Streaming requests: request_id -> asyncio.Queue
        self._streams = {}

    async def start(self):
        """Start the librarian producer and consumer."""
        await self._producer.start()
        await self._consumer.start()

    async def _on_response(self, msg, consumer, flow):
        """Route librarian responses to the right waiter.

        Responses with no correlation id, or with an id we are no longer
        waiting on, are silently dropped.
        """
        response = msg.value()
        request_id = msg.properties().get("id")

        if not request_id:
            return

        if request_id in self._pending:
            future = self._pending.pop(request_id)
            # The waiter may have timed out (wait_for cancels the
            # future) between our membership check and here; calling
            # set_result on a done future raises InvalidStateError
            if not future.done():
                future.set_result(response)
        elif request_id in self._streams:
            await self._streams[request_id].put(response)

    async def request(self, request, timeout=120):
        """Send a request to the librarian and wait for a single response.

        Raises:
            RuntimeError: on a librarian-reported error, or on timeout.
        """
        request_id = str(uuid.uuid4())

        # get_running_loop: get_event_loop is deprecated inside coroutines
        future = asyncio.get_running_loop().create_future()
        self._pending[request_id] = future

        try:
            await self._producer.send(
                request, properties={"id": request_id},
            )

            response = await asyncio.wait_for(future, timeout=timeout)

            if response.error:
                raise RuntimeError(
                    f"Librarian error: {response.error.type}: "
                    f"{response.error.message}"
                )

            return response

        except asyncio.TimeoutError:
            self._pending.pop(request_id, None)
            raise RuntimeError("Timeout waiting for librarian response")

    async def stream(self, request, timeout=120):
        """Send a request and collect streamed response chunks.

        Waits up to `timeout` seconds for each chunk; stops at the chunk
        flagged is_final.

        Raises:
            RuntimeError: on a librarian-reported error, or on timeout.
        """
        request_id = str(uuid.uuid4())
        q = asyncio.Queue()
        self._streams[request_id] = q

        try:
            await self._producer.send(
                request, properties={"id": request_id},
            )

            chunks = []
            while True:
                response = await asyncio.wait_for(q.get(), timeout=timeout)
                if response.error:
                    raise RuntimeError(
                        f"Librarian error: {response.error.type}: "
                        f"{response.error.message}"
                    )
                chunks.append(response)
                if response.is_final:
                    break

            return chunks

        except asyncio.TimeoutError:
            raise RuntimeError("Timeout waiting for librarian stream")

        finally:
            # Covers success, error and timeout paths alike
            self._streams.pop(request_id, None)

    async def fetch_document_content(self, document_id, user, timeout=120):
        """Fetch document content using streaming.

        Returns base64-encoded content (bytes). Caller is responsible
        for decoding.
        """
        req = LibrarianRequest(
            operation="stream-document",
            document_id=document_id,
            user=user,
        )

        chunks = await self.stream(req, timeout=timeout)

        # Decode each chunk's base64 to raw bytes, concatenate,
        # re-encode for the caller.
        raw = b""
        for chunk in chunks:
            if chunk.content:
                if isinstance(chunk.content, bytes):
                    raw += base64.b64decode(chunk.content)
                else:
                    raw += base64.b64decode(
                        chunk.content.encode("utf-8")
                    )

        return base64.b64encode(raw)

    async def fetch_document_text(self, document_id, user, timeout=120):
        """Fetch document content and decode as UTF-8 text."""
        content = await self.fetch_document_content(
            document_id, user, timeout=timeout,
        )
        return base64.b64decode(content).decode("utf-8")

    async def fetch_document_metadata(self, document_id, user, timeout=120):
        """Fetch document metadata from the librarian."""
        req = LibrarianRequest(
            operation="get-document-metadata",
            document_id=document_id,
            user=user,
        )
        response = await self.request(req, timeout=timeout)
        return response.document_metadata

    async def save_child_document(self, doc_id, parent_id, user, content,
                                  document_type="chunk", title=None,
                                  kind="text/plain", timeout=120):
        """Save a child document to the librarian.

        Args:
            doc_id: id for the new child document
            parent_id: id of the parent document
            user: user id
            content: document content (bytes or str; str is UTF-8 encoded)
            document_type: document type ("chunk", etc.)
            title: optional title (defaults to doc_id)
            kind: MIME type of the content
            timeout: request timeout in seconds

        Returns:
            The document id on success.
        """
        if isinstance(content, str):
            content = content.encode("utf-8")

        doc_metadata = DocumentMetadata(
            id=doc_id,
            user=user,
            kind=kind,
            title=title or doc_id,
            parent_id=parent_id,
            document_type=document_type,
        )

        req = LibrarianRequest(
            operation="add-child-document",
            document_metadata=doc_metadata,
            content=base64.b64encode(content).decode("utf-8"),
        )

        await self.request(req, timeout=timeout)
        return doc_id

    async def save_document(self, doc_id, user, content, title=None,
                            document_type="answer", kind="text/plain",
                            timeout=120):
        """Save a top-level document to the librarian.

        Args:
            doc_id: id for the new document
            user: user id
            content: document content (bytes or str; str is UTF-8 encoded)
            title: optional title (defaults to doc_id)
            document_type: document type ("answer", etc.)
            kind: MIME type of the content
            timeout: request timeout in seconds

        Returns:
            The document id on success.
        """
        if isinstance(content, str):
            content = content.encode("utf-8")

        doc_metadata = DocumentMetadata(
            id=doc_id,
            user=user,
            kind=kind,
            title=title or doc_id,
            document_type=document_type,
        )

        req = LibrarianRequest(
            operation="add-document",
            document_id=doc_id,
            document_metadata=doc_metadata,
            content=base64.b64encode(content).decode("utf-8"),
            user=user,
        )

        await self.request(req, timeout=timeout)
        return doc_id

View file

@ -8,6 +8,12 @@ logger = logging.getLogger(__name__)
DEFAULT_PULSAR_HOST = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
DEFAULT_PULSAR_API_KEY = os.getenv("PULSAR_API_KEY", None)
DEFAULT_RABBITMQ_HOST = os.getenv("RABBITMQ_HOST", 'rabbitmq')
DEFAULT_RABBITMQ_PORT = int(os.getenv("RABBITMQ_PORT", '5672'))
DEFAULT_RABBITMQ_USERNAME = os.getenv("RABBITMQ_USERNAME", 'guest')
DEFAULT_RABBITMQ_PASSWORD = os.getenv("RABBITMQ_PASSWORD", 'guest')
DEFAULT_RABBITMQ_VHOST = os.getenv("RABBITMQ_VHOST", '/')
def get_pubsub(**config):
"""
@ -29,6 +35,15 @@ def get_pubsub(**config):
api_key=config.get('pulsar_api_key', DEFAULT_PULSAR_API_KEY),
listener=config.get('pulsar_listener'),
)
elif backend_type == 'rabbitmq':
from .rabbitmq_backend import RabbitMQBackend
return RabbitMQBackend(
host=config.get('rabbitmq_host', DEFAULT_RABBITMQ_HOST),
port=config.get('rabbitmq_port', DEFAULT_RABBITMQ_PORT),
username=config.get('rabbitmq_username', DEFAULT_RABBITMQ_USERNAME),
password=config.get('rabbitmq_password', DEFAULT_RABBITMQ_PASSWORD),
vhost=config.get('rabbitmq_vhost', DEFAULT_RABBITMQ_VHOST),
)
else:
raise ValueError(f"Unknown pub/sub backend: {backend_type}")
@ -44,8 +59,9 @@ def add_pubsub_args(parser, standalone=False):
standalone: If True, default host is localhost (for CLI tools
that run outside containers)
"""
host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST
listener_default = 'localhost' if standalone else None
pulsar_host = STANDALONE_PULSAR_HOST if standalone else DEFAULT_PULSAR_HOST
pulsar_listener = 'localhost' if standalone else None
rabbitmq_host = 'localhost' if standalone else DEFAULT_RABBITMQ_HOST
parser.add_argument(
'--pubsub-backend',
@ -53,10 +69,11 @@ def add_pubsub_args(parser, standalone=False):
help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
)
# Pulsar options
parser.add_argument(
'-p', '--pulsar-host',
default=host,
help=f'Pulsar host (default: {host})',
default=pulsar_host,
help=f'Pulsar host (default: {pulsar_host})',
)
parser.add_argument(
@ -67,6 +84,38 @@ def add_pubsub_args(parser, standalone=False):
parser.add_argument(
'--pulsar-listener',
default=listener_default,
help=f'Pulsar listener (default: {listener_default or "none"})',
default=pulsar_listener,
help=f'Pulsar listener (default: {pulsar_listener or "none"})',
)
# RabbitMQ options
parser.add_argument(
'--rabbitmq-host',
default=rabbitmq_host,
help=f'RabbitMQ host (default: {rabbitmq_host})',
)
parser.add_argument(
'--rabbitmq-port',
type=int,
default=DEFAULT_RABBITMQ_PORT,
help=f'RabbitMQ port (default: {DEFAULT_RABBITMQ_PORT})',
)
parser.add_argument(
'--rabbitmq-username',
default=DEFAULT_RABBITMQ_USERNAME,
help='RabbitMQ username',
)
parser.add_argument(
'--rabbitmq-password',
default=DEFAULT_RABBITMQ_PASSWORD,
help='RabbitMQ password',
)
parser.add_argument(
'--rabbitmq-vhost',
default=DEFAULT_RABBITMQ_VHOST,
help=f'RabbitMQ vhost (default: {DEFAULT_RABBITMQ_VHOST})',
)

View file

@ -9,122 +9,14 @@ import pulsar
import _pulsar
import json
import logging
import base64
import types
from dataclasses import asdict, is_dataclass
from typing import Any, get_type_hints
from typing import Any
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
from .serialization import dataclass_to_dict, dict_to_dataclass
logger = logging.getLogger(__name__)
def dataclass_to_dict(obj: Any) -> dict:
    """
    Recursively convert a value into a JSON-serializable structure.

    Bytes are decoded as UTF-8 strings (matching Pulsar behavior);
    dataclasses become plain dicts; lists and dicts are processed
    element by element. None values pass through unchanged.
    """
    if obj is None:
        return None

    # Decode bytes so json.dumps can handle the value
    if isinstance(obj, bytes):
        return obj.decode('utf-8')

    # Dataclass: flatten via asdict(), then normalize each field value
    if is_dataclass(obj):
        converted = {}
        for name, field_value in asdict(obj).items():
            if field_value is None:
                converted[name] = None
            else:
                converted[name] = dataclass_to_dict(field_value)
        return converted

    if isinstance(obj, list):
        return [dataclass_to_dict(element) for element in obj]

    if isinstance(obj, dict):
        return {key: dataclass_to_dict(val) for key, val in obj.items()}

    # Primitives (str, int, float, bool) are returned as-is
    return obj
def dict_to_dataclass(data: dict, cls: type) -> Any:
    """
    Convert a dictionary back to a dataclass instance.

    Handles nested dataclasses, Optional fields (both ``X | None`` and
    old-style ``Optional[X]``), ``list[T]`` of dataclasses, and bytes
    fields stored as UTF-8 strings. Keys in ``data`` that are not fields
    of ``cls`` are silently dropped; missing fields fall back to the
    dataclass defaults (or raise from ``cls(**kwargs)`` if required).
    Uses get_type_hints() to resolve forward references (string
    annotations).
    """
    if data is None:
        return None

    # Non-dataclass targets are returned untouched
    if not is_dataclass(cls):
        return data

    # Get field types from the dataclass, resolving forward references.
    # get_type_hints() evaluates string annotations like "Triple | None".
    try:
        field_types = get_type_hints(cls)
    except Exception:
        # Fallback if get_type_hints fails (shouldn't happen normally);
        # field types stay as raw (possibly string) annotations here
        field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}

    kwargs = {}
    for key, value in data.items():
        if key in field_types:
            field_type = field_types[key]

            # Handle modern union types (X | Y), i.e. types.UnionType
            if isinstance(field_type, types.UnionType):
                # Check if it's Optional (X | None)
                if type(None) in field_type.__args__:
                    # Get the non-None type
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    # Non-Optional union: no single target type to
                    # recurse into, keep the raw value
                    kwargs[key] = value

            # Check if this is a generic type (list, dict, Optional[T], ...)
            elif hasattr(field_type, '__origin__'):
                # Handle list[T] — recurse into dataclass items only
                if field_type.__origin__ == list:
                    item_type = field_type.__args__[0] if field_type.__args__ else None
                    if item_type and is_dataclass(item_type) and isinstance(value, list):
                        kwargs[key] = [
                            dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
                            for item in value
                        ]
                    else:
                        kwargs[key] = value
                # Handle old-style Optional[T] (which is Union[T, None])
                elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
                    # Get the non-None type from Union
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    # Other generics (dict[...], etc.): keep raw value
                    kwargs[key] = value

            # Handle direct (non-Optional) dataclass fields
            elif is_dataclass(field_type) and isinstance(value, dict):
                kwargs[key] = dict_to_dataclass(value, field_type)

            # Handle bytes fields (UTF-8 encoded strings from JSON)
            elif field_type == bytes and isinstance(value, str):
                kwargs[key] = value.encode('utf-8')

            else:
                kwargs[key] = value

    return cls(**kwargs)
class PulsarMessage:
"""Wrapper for Pulsar messages to match Message protocol."""

View file

@ -0,0 +1,390 @@
"""
RabbitMQ backend implementation for pub/sub abstraction.
Uses a single topic exchange per topicspace. The logical queue name
becomes the routing key. Consumer behavior is determined by the
subscription name:
- Same subscription + same topic = shared queue (competing consumers)
- Different subscriptions = separate queues (broadcast / fan-out)
This mirrors Pulsar's subscription model using idiomatic RabbitMQ.
Architecture:
Producer --> [tg exchange] --routing key--> [named queue] --> Consumer
--routing key--> [named queue] --> Consumer
--routing key--> [exclusive q] --> Subscriber
Uses basic_consume (push) instead of basic_get (polling) for
efficient message delivery.
"""
import json
import time
import logging
import queue
import threading
import pika
from typing import Any
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
from .serialization import dataclass_to_dict, dict_to_dataclass
logger = logging.getLogger(__name__)
class RabbitMQMessage:
    """Wrapper for RabbitMQ messages to match the Message protocol."""

    def __init__(self, method, properties, body, schema_cls):
        # method: pika delivery method frame (carries the delivery tag)
        # properties: pika BasicProperties (headers hold app properties)
        # body: raw message bytes (JSON)
        # schema_cls: dataclass type used to deserialize the body
        self._method = method
        self._properties = properties
        self._body = body
        self._schema_cls = schema_cls
        # Deserialization is lazy. A separate flag marks whether _value
        # is valid — an `is None` check alone would re-decode the body
        # on every call when the payload is a legitimate JSON null.
        self._value = None
        self._decoded = False

    def value(self) -> Any:
        """Deserialize and return the message value as a dataclass."""
        if not self._decoded:
            data_dict = json.loads(self._body.decode('utf-8'))
            self._value = dict_to_dataclass(data_dict, self._schema_cls)
            self._decoded = True
        return self._value

    def properties(self) -> dict:
        """Return application properties from AMQP headers ({} if none)."""
        headers = self._properties.headers or {}
        return dict(headers)
class RabbitMQBackendProducer:
    """Publishes messages to a topic exchange with a routing key.

    Uses thread-local connections so each thread gets its own
    connection/channel. This avoids wire corruption from concurrent
    threads writing to the same socket (pika is not thread-safe).
    """

    def __init__(self, connection_params, exchange_name, routing_key,
                 durable):
        # connection_params: pika connection parameters for new connections
        # exchange_name: topic exchange published to
        # routing_key: routing key derived from queue class / topic name
        # durable: if True, publish persistent messages (delivery_mode 2)
        self._connection_params = connection_params
        self._exchange_name = exchange_name
        self._routing_key = routing_key
        self._durable = durable
        # Per-thread connection/channel storage
        self._local = threading.local()

    def _get_channel(self):
        """Get or create a thread-local connection and channel."""
        conn = getattr(self._local, 'connection', None)
        chan = getattr(self._local, 'channel', None)

        if conn is None or not conn.is_open or chan is None or not chan.is_open:
            # Close stale connection if any
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

            conn = pika.BlockingConnection(self._connection_params)
            chan = conn.channel()
            # Idempotent declare: ensures the exchange exists before publish
            chan.exchange_declare(
                exchange=self._exchange_name,
                exchange_type='topic',
                durable=True,
            )
            self._local.connection = conn
            self._local.channel = chan

        return chan

    # Default is None, not {} — a mutable default dict would be shared
    # across all calls. Falsy handling below keeps behavior identical.
    def send(self, message: Any, properties: dict = None) -> None:
        """Serialize `message` to JSON and publish it.

        Args:
            message: dataclass instance to serialize
            properties: optional application properties, carried as
                AMQP headers

        Retries once on failure with a fresh connection, then re-raises.
        """
        data_dict = dataclass_to_dict(message)
        json_data = json.dumps(data_dict)

        amqp_properties = pika.BasicProperties(
            delivery_mode=2 if self._durable else 1,
            content_type='application/json',
            headers=properties if properties else None,
        )

        for attempt in range(2):
            try:
                channel = self._get_channel()
                channel.basic_publish(
                    exchange=self._exchange_name,
                    routing_key=self._routing_key,
                    body=json_data.encode('utf-8'),
                    properties=amqp_properties,
                )
                return
            except Exception as e:
                logger.warning(
                    f"RabbitMQ send failed (attempt {attempt + 1}): {e}"
                )
                # Close the broken connection rather than just dropping
                # the reference (which would leak the socket), then
                # force a reconnect on the next attempt
                broken = getattr(self._local, 'connection', None)
                if broken is not None:
                    try:
                        broken.close()
                    except Exception:
                        pass
                self._local.connection = None
                self._local.channel = None
                if attempt == 1:
                    raise

    def flush(self) -> None:
        """No-op: basic_publish on a blocking connection writes immediately."""
        pass

    def close(self) -> None:
        """Close the thread-local connection if any."""
        conn = getattr(self._local, 'connection', None)
        if conn is not None:
            try:
                conn.close()
            except Exception:
                pass
        self._local.connection = None
        self._local.channel = None
class RabbitMQBackendConsumer:
    """Consumes from a queue bound to a topic exchange.

    Uses basic_consume (push model) with messages delivered to an
    internal thread-safe queue. process_data_events() drives both
    message delivery and heartbeat processing.

    The connection is opened lazily on the first receive() call and
    transparently re-opened if it drops.
    """
    def __init__(self, connection_params, exchange_name, routing_key,
                 queue_name, schema_cls, durable, exclusive=False,
                 auto_delete=False):
        # Topology settings captured for the lazy _connect() below.
        self._connection_params = connection_params
        self._exchange_name = exchange_name
        self._routing_key = routing_key
        self._queue_name = queue_name
        # Dataclass type used to decode message bodies.
        self._schema_cls = schema_cls
        self._durable = durable
        # exclusive=True + auto_delete=True => anonymous broadcast queue.
        self._exclusive = exclusive
        self._auto_delete = auto_delete
        self._connection = None
        self._channel = None
        self._consumer_tag = None
        # Messages pushed by pika's callback, drained by receive().
        self._incoming = queue.Queue()
    def _connect(self):
        """Open a connection, declare topology, and start consuming."""
        self._connection = pika.BlockingConnection(self._connection_params)
        self._channel = self._connection.channel()
        # Declare the topic exchange (idempotent; mirrors the producer side)
        self._channel.exchange_declare(
            exchange=self._exchange_name,
            exchange_type='topic',
            durable=True,
        )
        # Declare the queue — anonymous if exclusive
        result = self._channel.queue_declare(
            queue=self._queue_name,
            durable=self._durable,
            exclusive=self._exclusive,
            auto_delete=self._auto_delete,
        )
        # Capture actual name (important for anonymous queues where name='')
        self._queue_name = result.method.queue
        self._channel.queue_bind(
            queue=self._queue_name,
            exchange=self._exchange_name,
            routing_key=self._routing_key,
        )
        # At most one unacked message per consumer — fair round-robin
        # across competing (shared) consumers.
        self._channel.basic_qos(prefetch_count=1)
        # Register push-based consumer
        self._consumer_tag = self._channel.basic_consume(
            queue=self._queue_name,
            on_message_callback=self._on_message,
            auto_ack=False,
        )
    def _on_message(self, channel, method, properties, body):
        """Callback invoked by pika when a message arrives.

        Only enqueues — decoding happens in receive() on the caller's side.
        """
        self._incoming.put((method, properties, body))
    def _is_alive(self):
        # True only when both connection and channel exist and are open.
        return (
            self._connection is not None
            and self._connection.is_open
            and self._channel is not None
            and self._channel.is_open
        )
    def receive(self, timeout_millis: int = 2000) -> Message:
        """Receive a message. Raises TimeoutError if none available.

        Polls the internal queue and drives pika's I/O loop in short
        slices until the deadline, so heartbeats are processed even
        while idle.
        """
        if not self._is_alive():
            self._connect()
        timeout_seconds = timeout_millis / 1000.0
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            # Check if a message was already delivered
            try:
                method, properties, body = self._incoming.get_nowait()
                return RabbitMQMessage(
                    method, properties, body, self._schema_cls,
                )
            except queue.Empty:
                pass
            # Drive pika's I/O — delivers messages and processes heartbeats
            remaining = deadline - time.monotonic()
            if remaining > 0:
                self._connection.process_data_events(
                    time_limit=min(0.1, remaining),
                )
        raise TimeoutError("No message received within timeout")
    def acknowledge(self, message: Message) -> None:
        # NOTE(review): delivery tags are only valid on the channel that
        # delivered them — after a reconnect a tag from the old channel
        # would be stale; confirm callers ack before any reconnect.
        if isinstance(message, RabbitMQMessage) and message._method:
            self._channel.basic_ack(
                delivery_tag=message._method.delivery_tag,
            )
    def negative_acknowledge(self, message: Message) -> None:
        # requeue=True returns the message to the queue for redelivery.
        if isinstance(message, RabbitMQMessage) and message._method:
            self._channel.basic_nack(
                delivery_tag=message._method.delivery_tag,
                requeue=True,
            )
    def unsubscribe(self) -> None:
        """Cancel the consumer registration; safe to call repeatedly."""
        if self._consumer_tag and self._channel and self._channel.is_open:
            try:
                self._channel.basic_cancel(self._consumer_tag)
            except Exception:
                pass
        self._consumer_tag = None
    def close(self) -> None:
        """Cancel the consumer and close channel and connection.

        Close errors are swallowed — best-effort teardown. Unacked
        messages are requeued by the broker when the channel closes.
        """
        self.unsubscribe()
        try:
            if self._channel and self._channel.is_open:
                self._channel.close()
        except Exception:
            pass
        try:
            if self._connection and self._connection.is_open:
                self._connection.close()
        except Exception:
            pass
        self._channel = None
        self._connection = None
class RabbitMQBackend:
    """RabbitMQ pub/sub backend using a topic exchange per topicspace."""

    def __init__(self, host='localhost', port=5672, username='guest',
                 password='guest', vhost='/'):
        """Build the connection parameters shared by producers/consumers."""
        credentials = pika.PlainCredentials(username, password)
        self._connection_params = pika.ConnectionParameters(
            host=host,
            port=port,
            virtual_host=vhost,
            credentials=credentials,
        )
        logger.info(f"RabbitMQ backend: {host}:{port} vhost={vhost}")

    def _parse_queue_id(self, queue_id: str) -> tuple[str, str, str, bool]:
        """
        Parse queue identifier into exchange, routing key, and durability.
        Format: class:topicspace:topic
        Returns: (exchange_name, routing_key, class, durable)
        """
        # Legacy unqualified names default to the 'tg' topicspace.
        # NOTE(review): this fallback reports class 'flow' with
        # durable=False, while a parsed 'flow' class is durable=True —
        # confirm the ephemeral default is intentional.
        if ':' not in queue_id:
            return 'tg', queue_id, 'flow', False
        parts = queue_id.split(':', 2)
        if len(parts) != 3:
            raise ValueError(
                f"Invalid queue format: {queue_id}, "
                f"expected class:topicspace:topic"
            )
        cls, topicspace, topic = parts
        # flow/state must survive broker restarts; request/response are
        # short-lived and stay ephemeral.
        durability = {
            'flow': True,
            'state': True,
            'request': False,
            'response': False,
        }
        if cls not in durability:
            raise ValueError(
                f"Invalid queue class: {cls}, "
                f"expected flow, request, response, or state"
            )
        # One exchange per topicspace; the class prefixes the routing key.
        return topicspace, f"{cls}.{topic}", cls, durability[cls]

    # Retained for backward compatibility with existing tests.
    def map_queue_name(self, queue_id: str) -> tuple[str, bool]:
        exchange, routing_key, _cls, durable = self._parse_queue_id(queue_id)
        return f"{exchange}.{routing_key}", durable

    def create_producer(self, topic: str, schema: type,
                        **options) -> BackendProducer:
        """Create a producer publishing to *topic*'s exchange/routing key."""
        exchange, routing_key, _cls, durable = self._parse_queue_id(topic)
        logger.debug(
            f"Creating producer: exchange={exchange}, "
            f"routing_key={routing_key}"
        )
        return RabbitMQBackendProducer(
            self._connection_params, exchange, routing_key, durable,
        )

    def create_consumer(self, topic: str, subscription: str, schema: type,
                        initial_position: str = 'latest',
                        consumer_type: str = 'shared',
                        **options) -> BackendConsumer:
        """Create a consumer with a queue bound to the topic exchange.

        consumer_type='shared' binds a named queue, so consumers with the
        same subscription compete round-robin. consumer_type='exclusive'
        gives each consumer its own queue, so every message is broadcast
        to all of them.
        """
        exchange, routing_key, cls, durable = self._parse_queue_id(topic)
        named_queue = f"{exchange}.{routing_key}.{subscription}"
        if consumer_type != 'exclusive':
            # Shared: named queue, competing consumers.
            queue_name, queue_durable = named_queue, durable
            exclusive = auto_delete = False
        elif cls == 'state':
            # State broadcast: named durable queue per subscriber, so
            # late-starting processors still see the current state.
            queue_name, queue_durable = named_queue, True
            exclusive = auto_delete = False
        else:
            # Broadcast: anonymous queue, auto-deleted on disconnect.
            queue_name, queue_durable = '', False
            exclusive = auto_delete = True
        logger.debug(
            f"Creating consumer: exchange={exchange}, "
            f"routing_key={routing_key}, queue={queue_name or '(anonymous)'}, "
            f"type={consumer_type}"
        )
        return RabbitMQBackendConsumer(
            self._connection_params, exchange, routing_key,
            queue_name, schema, queue_durable, exclusive, auto_delete,
        )

    def close(self) -> None:
        """Nothing to release here: connections live in the producers
        and consumers this backend hands out."""
        pass

View file

@ -0,0 +1,115 @@
"""
JSON serialization helpers for dataclass dict conversion.
Used by pub/sub backends that use JSON as their wire format.
"""
import types
from dataclasses import asdict, is_dataclass
from typing import Any, get_type_hints
def dataclass_to_dict(obj: Any) -> dict:
"""
Recursively convert a dataclass to a dictionary, handling None values and bytes.
None values are excluded from the dictionary (not serialized).
Bytes values are decoded as UTF-8 strings for JSON serialization.
Handles nested dataclasses, lists, and dictionaries recursively.
"""
if obj is None:
return None
# Handle bytes - decode to UTF-8 for JSON serialization
if isinstance(obj, bytes):
return obj.decode('utf-8')
# Handle dataclass - convert to dict then recursively process all values
if is_dataclass(obj):
result = {}
for key, value in asdict(obj).items():
result[key] = dataclass_to_dict(value) if value is not None else None
return result
# Handle list - recursively process all items
if isinstance(obj, list):
return [dataclass_to_dict(item) for item in obj]
# Handle dict - recursively process all values
if isinstance(obj, dict):
return {k: dataclass_to_dict(v) for k, v in obj.items()}
# Return primitive types as-is
return obj
def dict_to_dataclass(data: dict, cls: type) -> Any:
"""
Convert a dictionary back to a dataclass instance.
Handles nested dataclasses and missing fields.
Uses get_type_hints() to resolve forward references (string annotations).
"""
if data is None:
return None
if not is_dataclass(cls):
return data
# Get field types from the dataclass, resolving forward references
# get_type_hints() evaluates string annotations like "Triple | None"
try:
field_types = get_type_hints(cls)
except Exception:
# Fallback if get_type_hints fails (shouldn't happen normally)
field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
kwargs = {}
for key, value in data.items():
if key in field_types:
field_type = field_types[key]
# Handle modern union types (X | Y)
if isinstance(field_type, types.UnionType):
# Check if it's Optional (X | None)
if type(None) in field_type.__args__:
# Get the non-None type
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, actual_type)
else:
kwargs[key] = value
else:
kwargs[key] = value
# Check if this is a generic type (list, dict, etc.)
elif hasattr(field_type, '__origin__'):
# Handle list[T]
if field_type.__origin__ == list:
item_type = field_type.__args__[0] if field_type.__args__ else None
if item_type and is_dataclass(item_type) and isinstance(value, list):
kwargs[key] = [
dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
for item in value
]
else:
kwargs[key] = value
# Handle old-style Optional[T] (which is Union[T, None])
elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
# Get the non-None type from Union
actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, actual_type)
else:
kwargs[key] = value
else:
kwargs[key] = value
# Handle direct dataclass fields
elif is_dataclass(field_type) and isinstance(value, dict):
kwargs[key] = dict_to_dataclass(value, field_type)
# Handle bytes fields (UTF-8 encoded strings from JSON)
elif field_type == bytes and isinstance(value, str):
kwargs[key] = value.encode('utf-8')
else:
kwargs[key] = value
return cls(**kwargs)

View file

@ -51,7 +51,7 @@ class Subscriber:
topic=self.topic,
subscription=self.subscription,
schema=self.schema,
consumer_type='shared',
consumer_type='exclusive',
)
self.task = asyncio.create_task(self.run())