Messaging fabric plugins (#592)

* Plugin architecture for messaging fabric

* Schemas use a technology neutral expression

* Schema strictness has uncovered some incorrect schema use, which is fixed
This commit is contained in:
cybermaggedon 2025-12-17 21:40:43 +00:00 committed by GitHub
parent 1865b3f3c8
commit 34eb083836
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
100 changed files with 2342 additions and 828 deletions

View file

@ -15,7 +15,7 @@ from prometheus_client import start_http_server, Info
from .. schema import ConfigPush, config_push_queue
from .. log_level import LogLevel
from . pubsub import PulsarClient
from . pubsub import PulsarClient, get_pubsub
from . producer import Producer
from . consumer import Consumer
from . metrics import ProcessorMetrics, ConsumerMetrics
@ -34,8 +34,11 @@ class AsyncProcessor:
# Store the identity
self.id = params.get("id")
# Register a pulsar client
self.pulsar_client_object = PulsarClient(**params)
# Create pub/sub backend via factory
self.pubsub_backend = get_pubsub(**params)
# Store pulsar_host for backward compatibility
self._pulsar_host = params.get("pulsar_host", "pulsar://pulsar:6650")
# Initialise metrics, records the parameters
ProcessorMetrics(processor = self.id).info({
@ -70,7 +73,7 @@ class AsyncProcessor:
self.config_sub_task = Consumer(
taskgroup = self.taskgroup,
client = self.pulsar_client,
backend = self.pubsub_backend, # Changed from client to backend
subscriber = config_subscriber_id,
flow = None,
@ -96,16 +99,16 @@ class AsyncProcessor:
# This is called to stop all threads. An over-ride point for extra
# functionality
def stop(self):
self.pulsar_client.close()
self.pubsub_backend.close()
self.running = False
# Returns the pulsar host
# Returns the pub/sub backend (new interface)
@property
def pulsar_host(self): return self.pulsar_client_object.pulsar_host
def pubsub(self): return self.pubsub_backend
# Returns the pulsar client
# Returns the pulsar host (backward compatibility)
@property
def pulsar_client(self): return self.pulsar_client_object.client
def pulsar_host(self): return self._pulsar_host
# Register a new event handler for configuration change
def register_config_handler(self, handler):
@ -247,6 +250,14 @@ class AsyncProcessor:
@staticmethod
def add_args(parser):
# Pub/sub backend selection
parser.add_argument(
'--pubsub-backend',
default=os.getenv('PUBSUB_BACKEND', 'pulsar'),
choices=['pulsar', 'mqtt'],
help='Pub/sub backend (default: pulsar, env: PUBSUB_BACKEND)',
)
PulsarClient.add_args(parser)
add_logging_args(parser)

View file

@ -0,0 +1,148 @@
"""
Backend abstraction interfaces for pub/sub systems.
This module defines Protocol classes that all pub/sub backends must implement,
allowing TrustGraph to work with different messaging systems (Pulsar, MQTT, Kafka, etc.)
"""
from typing import Protocol, Any, runtime_checkable
@runtime_checkable
class Message(Protocol):
    """
    Structural interface for a message delivered by any backend.

    A backend hands consumers objects satisfying this protocol, so callers
    never touch the underlying transport's message class directly.
    """

    def value(self) -> Any:
        """
        Return the deserialized message content.

        Returns:
            Dataclass instance representing the message
        """
        ...

    def properties(self) -> dict:
        """
        Return the metadata properties attached to the message.

        Returns:
            Dictionary of message properties
        """
        ...
@runtime_checkable
class BackendProducer(Protocol):
    """
    Structural interface for a producer created by a backend.

    Implementations serialize dataclass messages and publish them to the
    topic the producer was created for.
    """

    def send(self, message: Any, properties: dict = {}) -> None:
        """
        Publish *message* (a dataclass instance) with optional metadata.

        Args:
            message: Dataclass instance to send
            properties: Optional metadata properties
        """
        ...

    def flush(self) -> None:
        """Flush any buffered messages."""
        ...

    def close(self) -> None:
        """Close the producer."""
        ...
@runtime_checkable
class BackendConsumer(Protocol):
    """
    Structural interface for a consumer created by a backend.

    Concrete backends return objects matching this protocol from
    create_consumer(); callers drive the receive/acknowledge cycle
    through it.
    """

    def receive(self, timeout_millis: int = 2000) -> Message:
        """
        Wait for the next message on the subscribed topic.

        Args:
            timeout_millis: How long to wait, in milliseconds

        Returns:
            Message object

        Raises:
            TimeoutError: If no message received within timeout
        """
        ...

    def acknowledge(self, message: Message) -> None:
        """
        Mark *message* as successfully processed.

        Args:
            message: The message to acknowledge
        """
        ...

    def negative_acknowledge(self, message: Message) -> None:
        """
        Reject *message* so the backend redelivers it.

        Args:
            message: The message to negatively acknowledge
        """
        ...

    def unsubscribe(self) -> None:
        """Drop the subscription on the topic."""
        ...

    def close(self) -> None:
        """Close the consumer."""
        ...
@runtime_checkable
class PubSubBackend(Protocol):
    """
    Interface every pub/sub backend must satisfy.

    A backend owns the connection to the messaging system and hands out
    producers and consumers bound to generic topic names.
    """

    def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
        """
        Create a producer for *topic*.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            schema: Dataclass type for messages
            **options: Backend-specific options (e.g., chunking_enabled)

        Returns:
            Backend-specific producer instance
        """
        ...

    def create_consumer(
        self,
        topic: str,
        subscription: str,
        schema: type,
        initial_position: str = 'latest',
        consumer_type: str = 'shared',
        **options
    ) -> BackendConsumer:
        """
        Create a consumer for *topic*.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            subscription: Subscription/consumer group name
            schema: Dataclass type for messages
            initial_position: 'earliest' or 'latest' (some backends may ignore)
            consumer_type: 'shared', 'exclusive', 'failover' (some backends may ignore)
            **options: Backend-specific options

        Returns:
            Backend-specific consumer instance
        """
        ...

    def close(self) -> None:
        """Close the backend connection."""
        ...

View file

@ -9,9 +9,6 @@
# one handler, and a single thread of concurrency, nothing too outrageous
# will happen if synchronous / blocking code is used
from pulsar.schema import JsonSchema
import pulsar
import _pulsar
import asyncio
import time
import logging
@ -21,11 +18,15 @@ from .. exceptions import TooManyRequests
# Module logger
logger = logging.getLogger(__name__)
# Timeout exception - can come from different backends
# NOTE(review): this class shadows the builtin TimeoutError, and nothing
# visible in this module raises it - the receive loop detects timeouts by
# inspecting the exception text instead. Confirm whether it is still needed.
class TimeoutError(Exception):
    pass
class Consumer:
def __init__(
self, taskgroup, flow, client, topic, subscriber, schema,
handler,
self, taskgroup, flow, backend, topic, subscriber, schema,
handler,
metrics = None,
start_of_messages=False,
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
@ -35,7 +36,7 @@ class Consumer:
self.taskgroup = taskgroup
self.flow = flow
self.client = client
self.backend = backend # Changed from 'client' to 'backend'
self.topic = topic
self.subscriber = subscriber
self.schema = schema
@ -96,18 +97,20 @@ class Consumer:
logger.info(f"Subscribing to topic: {self.topic}")
# Determine initial position
if self.start_of_messages:
pos = pulsar.InitialPosition.Earliest
initial_pos = 'earliest'
else:
pos = pulsar.InitialPosition.Latest
initial_pos = 'latest'
# Create consumer via backend
self.consumer = await asyncio.to_thread(
self.client.subscribe,
self.backend.create_consumer,
topic = self.topic,
subscription_name = self.subscriber,
schema = JsonSchema(self.schema),
initial_position = pos,
consumer_type = pulsar.ConsumerType.Shared,
subscription = self.subscriber,
schema = self.schema,
initial_position = initial_pos,
consumer_type = 'shared',
)
except Exception as e:
@ -159,9 +162,10 @@ class Consumer:
self.consumer.receive,
timeout_millis=2000
)
except _pulsar.Timeout:
continue
except Exception as e:
# Handle timeout from any backend
if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower():
continue
raise e
await self.handle_one_from_queue(msg)

View file

@ -19,7 +19,7 @@ class ConsumerSpec(Spec):
consumer = Consumer(
taskgroup = processor.taskgroup,
flow = flow,
client = processor.pulsar_client,
backend = processor.pubsub,
topic = definition[self.name],
subscriber = processor.id + "--" + flow.name + "--" + self.name,
schema = self.schema,

View file

@ -1,5 +1,4 @@
from pulsar.schema import JsonSchema
import asyncio
import logging
@ -8,10 +7,10 @@ logger = logging.getLogger(__name__)
class Producer:
def __init__(self, client, topic, schema, metrics=None,
def __init__(self, backend, topic, schema, metrics=None,
chunking_enabled=True):
self.client = client
self.backend = backend # Changed from 'client' to 'backend'
self.topic = topic
self.schema = schema
@ -44,9 +43,9 @@ class Producer:
try:
logger.info(f"Connecting publisher to {self.topic}...")
self.producer = self.client.create_producer(
self.producer = self.backend.create_producer(
topic = self.topic,
schema = JsonSchema(self.schema),
schema = self.schema,
chunking_enabled = self.chunking_enabled,
)
logger.info(f"Connected publisher to {self.topic}")

View file

@ -15,7 +15,7 @@ class ProducerSpec(Spec):
)
producer = Producer(
client = processor.pulsar_client,
backend = processor.pubsub,
topic = definition[self.name],
schema = self.schema,
metrics = producer_metrics,

View file

@ -37,21 +37,20 @@ class PromptClient(RequestResponse):
else:
logger.info("DEBUG prompt_client: Streaming path")
# Streaming path - collect all chunks
full_text = ""
full_object = None
# Streaming path - just forward chunks, don't accumulate
last_text = ""
last_object = None
async def collect_chunks(resp):
nonlocal full_text, full_object
logger.info(f"DEBUG prompt_client: collect_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")
async def forward_chunks(resp):
nonlocal last_text, last_object
logger.info(f"DEBUG prompt_client: forward_chunks called, resp.text={resp.text[:50] if resp.text else None}, end_of_stream={getattr(resp, 'end_of_stream', False)}")
if resp.error:
logger.error(f"DEBUG prompt_client: Error in response: {resp.error.message}")
raise RuntimeError(resp.error.message)
if resp.text:
full_text += resp.text
logger.info(f"DEBUG prompt_client: Accumulated {len(full_text)} chars")
last_text = resp.text
# Call chunk callback if provided
if chunk_callback:
logger.info(f"DEBUG prompt_client: Calling chunk_callback")
@ -61,7 +60,7 @@ class PromptClient(RequestResponse):
chunk_callback(resp.text)
elif resp.object:
logger.info(f"DEBUG prompt_client: Got object response")
full_object = resp.object
last_object = resp.object
end_stream = getattr(resp, 'end_of_stream', False)
logger.info(f"DEBUG prompt_client: Returning end_of_stream={end_stream}")
@ -79,17 +78,17 @@ class PromptClient(RequestResponse):
logger.info(f"DEBUG prompt_client: About to call self.request with recipient, timeout={timeout}")
await self.request(
req,
recipient=collect_chunks,
recipient=forward_chunks,
timeout=timeout
)
logger.info(f"DEBUG prompt_client: self.request returned, full_text has {len(full_text)} chars")
logger.info(f"DEBUG prompt_client: self.request returned, last_text={last_text[:50] if last_text else None}")
if full_text:
logger.info("DEBUG prompt_client: Returning full_text")
return full_text
if last_text:
logger.info("DEBUG prompt_client: Returning last_text")
return last_text
logger.info("DEBUG prompt_client: Returning parsed full_object")
return json.loads(full_object)
logger.info("DEBUG prompt_client: Returning parsed last_object")
return json.loads(last_object) if last_object else None
async def extract_definitions(self, text, timeout=600):
return await self.prompt(

View file

@ -1,9 +1,6 @@
from pulsar.schema import JsonSchema
import asyncio
import time
import pulsar
import logging
# Module logger
@ -11,9 +8,9 @@ logger = logging.getLogger(__name__)
class Publisher:
def __init__(self, client, topic, schema=None, max_size=10,
def __init__(self, backend, topic, schema=None, max_size=10,
chunking_enabled=True, drain_timeout=5.0):
self.client = client
self.backend = backend # Changed from 'client' to 'backend'
self.topic = topic
self.schema = schema
self.q = asyncio.Queue(maxsize=max_size)
@ -47,9 +44,9 @@ class Publisher:
try:
producer = self.client.create_producer(
producer = self.backend.create_producer(
topic=self.topic,
schema=JsonSchema(self.schema),
schema=self.schema,
chunking_enabled=self.chunking_enabled,
)

View file

@ -4,8 +4,45 @@ import pulsar
import _pulsar
import uuid
from pulsar.schema import JsonSchema
import logging
from .. log_level import LogLevel
from .pulsar_backend import PulsarBackend
logger = logging.getLogger(__name__)
def get_pubsub(**config):
    """
    Build a pub/sub backend instance from keyword configuration.

    Args:
        config: Configuration dictionary from command-line args.
            The 'pubsub_backend' key selects the backend and defaults
            to 'pulsar' when absent.

    Returns:
        Backend instance (PulsarBackend, MQTTBackend, etc.)

    Raises:
        NotImplementedError: for the not-yet-available 'mqtt' backend
        ValueError: for any unrecognised backend name

    Example:
        backend = get_pubsub(
            pubsub_backend='pulsar',
            pulsar_host='pulsar://localhost:6650'
        )
    """
    selected = config.get('pubsub_backend', 'pulsar')

    if selected == 'mqtt':
        # TODO: Implement MQTT backend
        raise NotImplementedError("MQTT backend not yet implemented")

    if selected != 'pulsar':
        raise ValueError(f"Unknown pub/sub backend: {selected}")

    return PulsarBackend(
        host=config.get('pulsar_host', PulsarClient.default_pulsar_host),
        api_key=config.get('pulsar_api_key', PulsarClient.default_pulsar_api_key),
        listener=config.get('pulsar_listener'),
    )
class PulsarClient:

View file

@ -0,0 +1,350 @@
"""
Pulsar backend implementation for pub/sub abstraction.
This module provides a Pulsar-specific implementation of the backend interfaces,
handling topic mapping, serialization, and Pulsar client management.
"""
import pulsar
import _pulsar
import json
import logging
import base64
import types
from dataclasses import asdict, is_dataclass
from typing import Any
from .backend import PubSubBackend, BackendProducer, BackendConsumer, Message
logger = logging.getLogger(__name__)
def dataclass_to_dict(obj: Any) -> Any:
    """
    Recursively convert a dataclass instance to a JSON-serializable dict.

    - Fields whose value is None are excluded (not serialized).
    - bytes values are decoded as UTF-8 strings, matching the previous
      Pulsar JSON schema behaviour.
    - Nested dataclasses, lists, and dicts are converted recursively, so
      the None-stripping and bytes-decoding rules apply at every depth.

    Note: the previous implementation used dataclasses.asdict(), which
    eagerly converts *nested* dataclasses to plain dicts - the nested
    is_dataclass() branches were therefore dead code, None stripping only
    applied at the top level, and nested bytes fields reached json.dumps()
    undecoded and crashed serialization. Walking the fields directly via
    __dataclass_fields__ fixes all three.

    Args:
        obj: A dataclass instance, or any other value.

    Returns:
        A dict for dataclass input (None-valued fields removed), None for
        None input, otherwise the value converted per the rules above.
    """
    if obj is None:
        return None
    if is_dataclass(obj):
        result = {}
        for f in obj.__dataclass_fields__.values():
            value = getattr(obj, f.name)
            if value is not None:
                result[f.name] = _convert_field_value(value)
        return result
    return obj


def _convert_field_value(value: Any) -> Any:
    """Convert one field value recursively (bytes -> str, dataclass -> dict)."""
    if isinstance(value, bytes):
        # Decode bytes as UTF-8 for JSON serialization (like Pulsar did)
        return value.decode('utf-8')
    if is_dataclass(value):
        return dataclass_to_dict(value)
    if isinstance(value, list):
        return [_convert_field_value(item) for item in value]
    if isinstance(value, dict):
        return {k: _convert_field_value(v) for k, v in value.items()}
    return value
def dict_to_dataclass(data: dict, cls: type) -> Any:
    """
    Convert a dictionary back to a dataclass instance.

    Handles nested dataclasses and missing fields. Dispatch rules, per key
    of *data* that names a field of *cls*:
      - modern unions (``X | None``): recurse when the non-None member is a
        dataclass and the value is a dict, otherwise pass through
      - ``list[T]``: recurse element-wise when T is a dataclass
      - old-style ``Optional[T]`` (``Union[T, None]``): same as modern unions
      - direct dataclass fields: recurse
      - ``bytes`` fields: re-encode the UTF-8 string produced on serialization
      - anything else: pass the value through unchanged

    NOTE(review): keys in *data* with no matching field are silently dropped,
    and ``cls(**kwargs)`` raises if a required field is absent from *data*.
    Also assumes field annotations are real type objects, not strings -
    confirm no schema module enables ``from __future__ import annotations``.

    Args:
        data: Dictionary parsed from the JSON payload (or None)
        cls: Target dataclass type; non-dataclass types pass data through

    Returns:
        Instance of cls, or data unchanged when cls is not a dataclass,
        or None when data is None
    """
    if data is None:
        return None
    if not is_dataclass(cls):
        return data
    # Get field types from the dataclass
    field_types = {f.name: f.type for f in cls.__dataclass_fields__.values()}
    kwargs = {}
    for key, value in data.items():
        if key in field_types:
            field_type = field_types[key]
            # Handle modern union types (X | Y)
            if isinstance(field_type, types.UnionType):
                # Check if it's Optional (X | None)
                if type(None) in field_type.__args__:
                    # Get the non-None type
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    # Non-optional union: no single target type, pass through
                    kwargs[key] = value
            # Check if this is a generic type (list, dict, etc.)
            elif hasattr(field_type, '__origin__'):
                # Handle list[T]
                if field_type.__origin__ == list:
                    item_type = field_type.__args__[0] if field_type.__args__ else None
                    if item_type and is_dataclass(item_type) and isinstance(value, list):
                        kwargs[key] = [
                            dict_to_dataclass(item, item_type) if isinstance(item, dict) else item
                            for item in value
                        ]
                    else:
                        kwargs[key] = value
                # Handle old-style Optional[T] (which is Union[T, None])
                elif hasattr(field_type, '__args__') and type(None) in field_type.__args__:
                    # Get the non-None type from Union
                    actual_type = next((t for t in field_type.__args__ if t is not type(None)), None)
                    if actual_type and is_dataclass(actual_type) and isinstance(value, dict):
                        kwargs[key] = dict_to_dataclass(value, actual_type)
                    else:
                        kwargs[key] = value
                else:
                    kwargs[key] = value
            # Handle direct dataclass fields
            elif is_dataclass(field_type) and isinstance(value, dict):
                kwargs[key] = dict_to_dataclass(value, field_type)
            # Handle bytes fields (UTF-8 encoded strings from JSON)
            elif field_type == bytes and isinstance(value, str):
                kwargs[key] = value.encode('utf-8')
            else:
                kwargs[key] = value
    return cls(**kwargs)
class PulsarMessage:
    """Adapts a raw Pulsar message to the backend Message protocol."""

    def __init__(self, pulsar_msg, schema_cls):
        # Underlying Pulsar message and the dataclass type used to decode it.
        self._msg = pulsar_msg
        self._schema_cls = schema_cls
        # Decoded payload, cached after the first value() call.
        self._value = None

    def value(self) -> Any:
        """Deserialize and return the message value as a dataclass (cached)."""
        if self._value is None:
            # Payload is UTF-8 JSON produced by the matching producer.
            payload = json.loads(self._msg.data().decode('utf-8'))
            self._value = dict_to_dataclass(payload, self._schema_cls)
        return self._value

    def properties(self) -> dict:
        """Expose the underlying message's properties unchanged."""
        return self._msg.properties()
class PulsarBackendProducer:
"""Pulsar-specific producer implementation."""
def __init__(self, pulsar_producer, schema_cls):
self._producer = pulsar_producer
self._schema_cls = schema_cls
def send(self, message: Any, properties: dict = {}) -> None:
"""Send a dataclass message."""
# Convert dataclass to dict, excluding None values
data_dict = dataclass_to_dict(message)
# Serialize to JSON
json_data = json.dumps(data_dict)
# Send via Pulsar
self._producer.send(json_data.encode('utf-8'), properties=properties)
def flush(self) -> None:
"""Flush buffered messages."""
self._producer.flush()
def close(self) -> None:
"""Close the producer."""
self._producer.close()
class PulsarBackendConsumer:
    """Adapts a raw Pulsar consumer to the backend consumer interface."""

    def __init__(self, pulsar_consumer, schema_cls):
        self._consumer = pulsar_consumer
        self._schema_cls = schema_cls

    def receive(self, timeout_millis: int = 2000) -> Message:
        """Receive one message, wrapped as a PulsarMessage."""
        raw = self._consumer.receive(timeout_millis=timeout_millis)
        return PulsarMessage(raw, self._schema_cls)

    def acknowledge(self, message: Message) -> None:
        """Acknowledge a message; non-Pulsar message objects are ignored."""
        if isinstance(message, PulsarMessage):
            self._consumer.acknowledge(message._msg)

    def negative_acknowledge(self, message: Message) -> None:
        """Request redelivery; non-Pulsar message objects are ignored."""
        if isinstance(message, PulsarMessage):
            self._consumer.negative_acknowledge(message._msg)

    def unsubscribe(self) -> None:
        """Drop the subscription on the topic."""
        self._consumer.unsubscribe()

    def close(self) -> None:
        """Close the consumer."""
        self._consumer.close()
class PulsarBackend:
    """
    Pulsar backend implementation.

    Owns the Pulsar client and maps generic topic names to Pulsar URIs
    when creating backend-specific producers and consumers.
    """

    def __init__(self, host: str, api_key: str = None, listener: str = None):
        """
        Initialize Pulsar backend.

        Args:
            host: Pulsar broker URL (e.g., pulsar://localhost:6650)
            api_key: Optional API key for authentication
            listener: Optional listener name for multi-homed setups
        """
        self.host = host
        self.api_key = api_key
        self.listener = listener
        # Assemble client arguments, including only the optional pieces
        # that were actually supplied.
        client_args = {'service_url': host}
        if listener:
            client_args['listener_name'] = listener
        if api_key:
            client_args['authentication'] = pulsar.AuthenticationToken(api_key)
        self.client = pulsar.Client(**client_args)
        logger.info(f"Pulsar client connected to {host}")

    def map_topic(self, generic_topic: str) -> str:
        """
        Map generic topic format to a Pulsar URI.

        Format: qos/tenant/namespace/queue
        Example: q1/tg/flow/my-queue -> persistent://tg/flow/my-queue

        Args:
            generic_topic: Generic topic string or already-formatted Pulsar URI

        Returns:
            Pulsar topic URI

        Raises:
            ValueError: on a malformed topic or unknown QoS level
        """
        # Already a fully-qualified Pulsar URI - pass through untouched.
        if '://' in generic_topic:
            return generic_topic
        parts = generic_topic.split('/', 3)
        if len(parts) != 4:
            raise ValueError(f"Invalid topic format: {generic_topic}, expected qos/tenant/namespace/queue")
        qos, tenant, namespace, queue = parts
        # q0 maps to non-persistent delivery; q1/q2 both map to persistent.
        persistence_by_qos = {
            'q0': 'non-persistent',
            'q1': 'persistent',
            'q2': 'persistent',
        }
        persistence = persistence_by_qos.get(qos)
        if persistence is None:
            raise ValueError(f"Invalid QoS level: {qos}, expected q0, q1, or q2")
        return f"{persistence}://{tenant}/{namespace}/{queue}"

    def create_producer(self, topic: str, schema: type, **options) -> BackendProducer:
        """
        Create a Pulsar producer.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            schema: Dataclass type for messages
            **options: Backend-specific options (e.g., chunking_enabled)

        Returns:
            PulsarBackendProducer instance
        """
        pulsar_topic = self.map_topic(topic)
        # Serialization is handled in the wrapper, so the raw producer
        # uses a plain bytes schema.
        producer_args = {
            'topic': pulsar_topic,
            'schema': pulsar.schema.BytesSchema(),
        }
        if 'chunking_enabled' in options:
            producer_args['chunking_enabled'] = options['chunking_enabled']
        raw_producer = self.client.create_producer(**producer_args)
        logger.debug(f"Created producer for topic: {pulsar_topic}")
        return PulsarBackendProducer(raw_producer, schema)

    def create_consumer(
        self,
        topic: str,
        subscription: str,
        schema: type,
        initial_position: str = 'latest',
        consumer_type: str = 'shared',
        **options
    ) -> BackendConsumer:
        """
        Create a Pulsar consumer.

        Args:
            topic: Generic topic format (qos/tenant/namespace/queue)
            subscription: Subscription name
            schema: Dataclass type for messages
            initial_position: 'earliest' or 'latest'
            consumer_type: 'shared', 'exclusive', or 'failover'
            **options: Backend-specific options

        Returns:
            PulsarBackendConsumer instance
        """
        pulsar_topic = self.map_topic(topic)
        # 'earliest' starts from the oldest retained message; any other
        # value (including the default) starts at the tail.
        pos = (pulsar.InitialPosition.Earliest
               if initial_position == 'earliest'
               else pulsar.InitialPosition.Latest)
        # Unrecognised consumer types fall back to shared delivery.
        consumer_type_map = {
            'exclusive': pulsar.ConsumerType.Exclusive,
            'failover': pulsar.ConsumerType.Failover,
        }
        ctype = consumer_type_map.get(consumer_type, pulsar.ConsumerType.Shared)
        raw_consumer = self.client.subscribe(
            topic=pulsar_topic,
            subscription_name=subscription,
            # Deserialization is handled in the wrapper
            schema=pulsar.schema.BytesSchema(),
            initial_position=pos,
            consumer_type=ctype,
        )
        logger.debug(f"Created consumer for topic: {pulsar_topic}, subscription: {subscription}")
        return PulsarBackendConsumer(raw_consumer, schema)

    def close(self) -> None:
        """Close the Pulsar client."""
        self.client.close()
        logger.info("Pulsar client closed")

View file

@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class RequestResponse(Subscriber):
def __init__(
self, client, subscription, consumer_name,
self, backend, subscription, consumer_name,
request_topic, request_schema,
request_metrics,
response_topic, response_schema,
@ -22,7 +22,7 @@ class RequestResponse(Subscriber):
):
super(RequestResponse, self).__init__(
client = client,
backend = backend,
subscription = subscription,
consumer_name = consumer_name,
topic = response_topic,
@ -31,7 +31,7 @@ class RequestResponse(Subscriber):
)
self.producer = Producer(
client = client,
backend = backend,
topic = request_topic,
schema = request_schema,
metrics = request_metrics,
@ -126,7 +126,7 @@ class RequestResponseSpec(Spec):
)
rr = self.impl(
client = processor.pulsar_client,
backend = processor.pubsub,
# Make subscription names unique, so that all subscribers get
# to see all response messages

View file

@ -3,9 +3,7 @@
# off of a queue and make it available using an internal broker system,
# so suitable for when multiple recipients are reading from the same queue
from pulsar.schema import JsonSchema
import asyncio
import _pulsar
import time
import logging
import uuid
@ -13,12 +11,16 @@ import uuid
# Module logger
logger = logging.getLogger(__name__)
# Timeout exception - can come from different backends
# NOTE(review): this class shadows the builtin TimeoutError, and nothing
# visible in this module raises it - the receive loop detects timeouts by
# inspecting the exception text instead. Confirm whether it is still needed.
class TimeoutError(Exception):
    pass
class Subscriber:
def __init__(self, client, topic, subscription, consumer_name,
def __init__(self, backend, topic, subscription, consumer_name,
schema=None, max_size=100, metrics=None,
backpressure_strategy="block", drain_timeout=5.0):
self.client = client
self.backend = backend # Changed from 'client' to 'backend'
self.topic = topic
self.subscription = subscription
self.consumer_name = consumer_name
@ -43,18 +45,14 @@ class Subscriber:
async def start(self):
# Build subscribe arguments
subscribe_args = {
'topic': self.topic,
'subscription_name': self.subscription,
'consumer_name': self.consumer_name,
}
# Only add schema if provided (omit if None)
if self.schema is not None:
subscribe_args['schema'] = JsonSchema(self.schema)
self.consumer = self.client.subscribe(**subscribe_args)
# Create consumer via backend
self.consumer = await asyncio.to_thread(
self.backend.create_consumer,
topic=self.topic,
subscription=self.subscription,
schema=self.schema,
consumer_type='shared',
)
self.task = asyncio.create_task(self.run())
@ -94,12 +92,13 @@ class Subscriber:
drain_end_time = time.time() + self.drain_timeout
logger.info(f"Subscriber entering drain mode, timeout={self.drain_timeout}s")
# Stop accepting new messages from Pulsar during drain
if self.consumer:
# Stop accepting new messages during drain
# Note: Not all backends support pausing message listeners
if self.consumer and hasattr(self.consumer, 'pause_message_listener'):
try:
self.consumer.pause_message_listener()
except _pulsar.InvalidConfiguration:
# Not all consumers have message listeners (e.g., blocking receive mode)
except Exception:
# Not all consumers support message listeners
pass
# Check drain timeout
@ -133,9 +132,10 @@ class Subscriber:
self.consumer.receive,
timeout_millis=250
)
except _pulsar.Timeout:
continue
except Exception as e:
# Handle timeout from any backend
if 'timeout' in str(type(e)).lower() or 'timeout' in str(e).lower():
continue
logger.error(f"Exception in subscriber receive: {e}", exc_info=True)
raise e
@ -157,19 +157,20 @@ class Subscriber:
for msg in self.pending_acks.values():
try:
self.consumer.negative_acknowledge(msg)
except _pulsar.AlreadyClosed:
pass # Consumer already closed
except Exception:
pass # Consumer already closed or error
self.pending_acks.clear()
if self.consumer:
try:
self.consumer.unsubscribe()
except _pulsar.AlreadyClosed:
pass # Already closed
if hasattr(self.consumer, 'unsubscribe'):
try:
self.consumer.unsubscribe()
except Exception:
pass # Already closed or error
try:
self.consumer.close()
except _pulsar.AlreadyClosed:
pass # Already closed
except Exception:
pass # Already closed or error
self.consumer = None

View file

@ -16,7 +16,7 @@ class SubscriberSpec(Spec):
)
subscriber = Subscriber(
client = processor.pulsar_client,
backend = processor.pubsub,
topic = definition[self.name],
subscription = flow.id,
consumer_name = flow.id,