diff --git a/containers/Containerfile.ocr b/containers/Containerfile.ocr index 661b8e17..43b66463 100644 --- a/containers/Containerfile.ocr +++ b/containers/Containerfile.ocr @@ -9,6 +9,7 @@ FROM docker.io/fedora:42 AS base ENV PIP_BREAK_SYSTEM_PACKAGES=1 RUN dnf install -y python3.12 && \ + dnf install -y tesseract poppler-utils && \ alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \ python -m ensurepip --upgrade && \ pip3 install --no-cache-dir wheel aiohttp && \ diff --git a/trustgraph-base/trustgraph/base/async_processor.py b/trustgraph-base/trustgraph/base/async_processor.py index ba1d4e1a..545220c4 100644 --- a/trustgraph-base/trustgraph/base/async_processor.py +++ b/trustgraph-base/trustgraph/base/async_processor.py @@ -13,7 +13,6 @@ from prometheus_client import start_http_server, Info from .. schema import ConfigPush, config_push_queue from .. log_level import LogLevel -from .. exceptions import TooManyRequests from . pubsub import PulsarClient from . producer import Producer from . consumer import Consumer diff --git a/trustgraph-base/trustgraph/base/consumer.py b/trustgraph-base/trustgraph/base/consumer.py index 162e10eb..8b7b2b0d 100644 --- a/trustgraph-base/trustgraph/base/consumer.py +++ b/trustgraph-base/trustgraph/base/consumer.py @@ -1,4 +1,14 @@ +# Consumer is similar to subscriber: It takes information from a queue +# and passes on to a processor function. This is the main receiving +# loop for TrustGraph processors. Incorporates retry functionality + +# Note: there is a 'defect' in the system which is tolerated, althought +# the processing handlers are async functions, ideally implementation +# would use all async code. In practice if the processor only implements +# one handler, and a single thread of concurrency, nothing too outrageous +# will happen if synchronous / blocking code is used + from pulsar.schema import JsonSchema import pulsar import _pulsar @@ -16,6 +26,7 @@ class Consumer: start_of_messages=False, rate_limit_retry_time = 10, rate_limit_timeout = 7200, reconnect_time = 5, + concurrency = 1, # Number of concurrent requests to handle ): self.taskgroup = taskgroup @@ -34,7 +45,9 @@ class Consumer: self.start_of_messages = start_of_messages self.running = True - self.task = None + self.consumer_task = None + + self.concurrency = concurrency self.metrics = metrics @@ -52,7 +65,11 @@ class Consumer: async def stop(self): self.running = False - await self.task + + if self.consumer_task: + await self.consumer_task + + self.consumer_task = None async def start(self): @@ -62,9 +79,9 @@ class Consumer: if self.metrics: self.metrics.state("stopped") - self.task = self.taskgroup.create_task(self.run()) + self.consumer_task = self.taskgroup.create_task(self.consumer_run()) - async def run(self): + async def consumer_run(self): while self.running: @@ -102,7 +119,19 @@ class Consumer: try: - await self.consume() + print( + "Starting", self.concurrency, "receiver threads", + flush=True + ) + + async with asyncio.TaskGroup() as tg: + + tasks = [] + + for i in range(0, self.concurrency): + tasks.append( + tg.create_task(self.consume_from_queue()) + ) if self.metrics: self.metrics.state("stopped") @@ -120,7 +149,7 @@ class Consumer: self.consumer.unsubscribe() self.consumer.close() - async def consume(self): + async def consume_from_queue(self): while self.running: @@ -134,71 +163,75 @@ class Consumer: except Exception as e: raise e - expiry = time.time() + self.rate_limit_timeout + await self.handle_one_from_queue(msg) - # This loop is for retry on rate-limit / resource limits - while self.running: + async def handle_one_from_queue(self, msg): - if time.time() > expiry: + expiry = time.time() + self.rate_limit_timeout - print("Gave up waiting for rate-limit retry", flush=True) + # This loop is for retry on rate-limit / resource limits + while self.running: - # Message failed to be processed, this causes it to - # be retried - self.consumer.negative_acknowledge(msg) + if time.time() > expiry: - if self.metrics: - self.metrics.process("error") + print("Gave up waiting for rate-limit retry", flush=True) - # Break out of retry loop, processes next message - break + # Message failed to be processed, this causes it to + # be retried + self.consumer.negative_acknowledge(msg) - try: + if self.metrics: + self.metrics.process("error") - print("Handle...", flush=True) + # Break out of retry loop, processes next message + break - if self.metrics: + try: - with self.metrics.record_time(): - await self.handler(msg, self, self.flow) + print("Handle...", flush=True) - else: + if self.metrics: + + with self.metrics.record_time(): await self.handler(msg, self, self.flow) - print("Handled.", flush=True) + else: + await self.handler(msg, self, self.flow) - # Acknowledge successful processing of the message - self.consumer.acknowledge(msg) + print("Handled.", flush=True) - if self.metrics: - self.metrics.process("success") + # Acknowledge successful processing of the message + self.consumer.acknowledge(msg) - # Break out of retry loop - break + if self.metrics: + self.metrics.process("success") - except TooManyRequests: + # Break out of retry loop + break - print("TooManyRequests: will retry...", flush=True) + except TooManyRequests: - if self.metrics: - self.metrics.rate_limit() + print("TooManyRequests: will retry...", flush=True) - # Sleep - await asyncio.sleep(self.rate_limit_retry_time) + if self.metrics: + self.metrics.rate_limit() - # Contine from retry loop, just causes a reprocessing - continue + # Sleep + await asyncio.sleep(self.rate_limit_retry_time) - except Exception as e: + # Contine from retry loop, just causes a reprocessing + continue - print("consume exception:", e, flush=True) + except Exception as e: - # Message failed to be processed, this causes it to - # be retried - self.consumer.negative_acknowledge(msg) + print("consume exception:", e, flush=True) - if self.metrics: - self.metrics.process("error") + # Message failed to be processed, this causes it to + # be retried + self.consumer.negative_acknowledge(msg) - # Break out of retry loop, processes next message - break + if self.metrics: + self.metrics.process("error") + + # Break out of retry loop, processes next message + break diff --git a/trustgraph-base/trustgraph/base/consumer_spec.py b/trustgraph-base/trustgraph/base/consumer_spec.py index 93665476..89581b02 100644 --- a/trustgraph-base/trustgraph/base/consumer_spec.py +++ b/trustgraph-base/trustgraph/base/consumer_spec.py @@ -4,10 +4,11 @@ from . consumer import Consumer from . spec import Spec class ConsumerSpec(Spec): - def __init__(self, name, schema, handler): + def __init__(self, name, schema, handler, concurrency = 1): self.name = name self.schema = schema self.handler = handler + self.concurrency = concurrency def add(self, flow, processor, definition): @@ -24,6 +25,7 @@ class ConsumerSpec(Spec): schema = self.schema, handler = self.handler, metrics = consumer_metrics, + concurrency = self.concurrency ) # Consumer handle gets access to producers and other diff --git a/trustgraph-base/trustgraph/base/embeddings_service.py b/trustgraph-base/trustgraph/base/embeddings_service.py index c6befdb7..c0dd3978 100644 --- a/trustgraph-base/trustgraph/base/embeddings_service.py +++ b/trustgraph-base/trustgraph/base/embeddings_service.py @@ -11,20 +11,26 @@ from .. exceptions import TooManyRequests from .. base import FlowProcessor, ConsumerSpec, ProducerSpec default_ident = "embeddings" +default_concurrency = 1 class EmbeddingsService(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", 1) - super(EmbeddingsService, self).__init__(**params | { "id": id }) + super(EmbeddingsService, self).__init__(**params | { + "id": id, + "concurrency": concurrency, + }) self.register_specification( ConsumerSpec( name = "request", schema = EmbeddingsRequest, - handler = self.on_request + handler = self.on_request, + concurrency = concurrency, ) ) @@ -84,6 +90,13 @@ class EmbeddingsService(FlowProcessor): @staticmethod def add_args(parser): + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + FlowProcessor.add_args(parser) diff --git a/trustgraph-base/trustgraph/base/llm_service.py b/trustgraph-base/trustgraph/base/llm_service.py index c79b819b..fddbdf3e 100644 --- a/trustgraph-base/trustgraph/base/llm_service.py +++ b/trustgraph-base/trustgraph/base/llm_service.py @@ -11,9 +11,13 @@ from .. exceptions import TooManyRequests from .. base import FlowProcessor, ConsumerSpec, ProducerSpec default_ident = "text-completion" +default_concurrency = 1 class LlmResult: - def __init__(self, text=None, in_token=None, out_token=None, model=None): + def __init__( + self, text = None, in_token = None, out_token = None, + model = None, + ): self.text = text self.in_token = in_token self.out_token = out_token @@ -25,14 +29,19 @@ class LlmService(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", 1) - super(LlmService, self).__init__(**params | { "id": id }) + super(LlmService, self).__init__(**params | { + "id": id, + "concurrency": concurrency, + }) self.register_specification( ConsumerSpec( name = "request", schema = TextCompletionRequest, - handler = self.on_request + handler = self.on_request, + concurrency = concurrency, ) ) @@ -115,5 +124,12 @@ class LlmService(FlowProcessor): @staticmethod def add_args(parser): + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + FlowProcessor.add_args(parser) diff --git a/trustgraph-base/trustgraph/base/subscriber.py b/trustgraph-base/trustgraph/base/subscriber.py index 8467b0bf..6e79adab 100644 --- a/trustgraph-base/trustgraph/base/subscriber.py +++ b/trustgraph-base/trustgraph/base/subscriber.py @@ -1,4 +1,8 @@ +# Subscriber is similar to consumer: It provides a service to take stuff +# off of a queue and make it available using an internal broker system, +# so suitable for when multiple recipients are reading from the same queue + from pulsar.schema import JsonSchema import asyncio import _pulsar diff --git a/trustgraph-base/trustgraph/base/triples_store_service.py b/trustgraph-base/trustgraph/base/triples_store_service.py index 74f95f57..c33c2801 100644 --- a/trustgraph-base/trustgraph/base/triples_store_service.py +++ b/trustgraph-base/trustgraph/base/triples_store_service.py @@ -5,6 +5,7 @@ Triples store base class from .. schema import Triples from .. base import FlowProcessor, ConsumerSpec +from .. exceptions import TooManyRequests default_ident = "triples-write" diff --git a/trustgraph-base/trustgraph/messaging/__init__.py b/trustgraph-base/trustgraph/messaging/__init__.py new file mode 100644 index 00000000..a9caf950 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/__init__.py @@ -0,0 +1,105 @@ +from .registry import TranslatorRegistry +from .translators import * + +# Auto-register all translators +from .translators.agent import AgentRequestTranslator, AgentResponseTranslator +from .translators.embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator +from .translators.text_completion import TextCompletionRequestTranslator, TextCompletionResponseTranslator +from .translators.retrieval import ( + DocumentRagRequestTranslator, DocumentRagResponseTranslator, + GraphRagRequestTranslator, GraphRagResponseTranslator +) +from .translators.triples import TriplesQueryRequestTranslator, TriplesQueryResponseTranslator +from .translators.knowledge import KnowledgeRequestTranslator, KnowledgeResponseTranslator +from .translators.library import LibraryRequestTranslator, LibraryResponseTranslator +from .translators.document_loading import DocumentTranslator, TextDocumentTranslator +from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator +from .translators.flow import FlowRequestTranslator, FlowResponseTranslator +from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator +from .translators.embeddings_query import ( + DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator, + GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator +) + +# Register all service translators +TranslatorRegistry.register_service( + "agent", + AgentRequestTranslator(), + AgentResponseTranslator() +) + +TranslatorRegistry.register_service( + "embeddings", + EmbeddingsRequestTranslator(), + EmbeddingsResponseTranslator() +) + +TranslatorRegistry.register_service( + "text-completion", + TextCompletionRequestTranslator(), + TextCompletionResponseTranslator() +) + +TranslatorRegistry.register_service( + "document-rag", + DocumentRagRequestTranslator(), + DocumentRagResponseTranslator() +) + +TranslatorRegistry.register_service( + "graph-rag", + GraphRagRequestTranslator(), + GraphRagResponseTranslator() +) + +TranslatorRegistry.register_service( + "triples-query", + TriplesQueryRequestTranslator(), + TriplesQueryResponseTranslator() +) + +TranslatorRegistry.register_service( + "knowledge", + KnowledgeRequestTranslator(), + KnowledgeResponseTranslator() +) + +TranslatorRegistry.register_service( + "librarian", + LibraryRequestTranslator(), + LibraryResponseTranslator() +) + +TranslatorRegistry.register_service( + "config", + ConfigRequestTranslator(), + ConfigResponseTranslator() +) + +TranslatorRegistry.register_service( + "flow", + FlowRequestTranslator(), + FlowResponseTranslator() +) + +TranslatorRegistry.register_service( + "prompt", + PromptRequestTranslator(), + PromptResponseTranslator() +) + +TranslatorRegistry.register_service( + "document-embeddings-query", + DocumentEmbeddingsRequestTranslator(), + DocumentEmbeddingsResponseTranslator() +) + +TranslatorRegistry.register_service( + "graph-embeddings-query", + GraphEmbeddingsRequestTranslator(), + GraphEmbeddingsResponseTranslator() +) + +# Register single-direction translators for document loading +TranslatorRegistry.register_request("document", DocumentTranslator()) +TranslatorRegistry.register_request("text-document", TextDocumentTranslator()) diff --git a/trustgraph-base/trustgraph/messaging/registry.py b/trustgraph-base/trustgraph/messaging/registry.py new file mode 100644 index 00000000..f42c53bb --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/registry.py @@ -0,0 +1,51 @@ +from typing import Dict, List, Union +from .translators.base import MessageTranslator + + +class TranslatorRegistry: + """Registry for service translators""" + + _request_translators: Dict[str, MessageTranslator] = {} + _response_translators: Dict[str, MessageTranslator] = {} + + @classmethod + def register_request(cls, service_name: str, translator: MessageTranslator): + """Register a request translator for a service""" + cls._request_translators[service_name] = translator + + @classmethod + def register_response(cls, service_name: str, translator: MessageTranslator): + """Register a response translator for a service""" + cls._response_translators[service_name] = translator + + @classmethod + def register_service(cls, service_name: str, request_translator: MessageTranslator, + response_translator: MessageTranslator): + """Register both request and response translators for a service""" + cls.register_request(service_name, request_translator) + cls.register_response(service_name, response_translator) + + @classmethod + def get_request_translator(cls, service_name: str) -> MessageTranslator: + """Get request translator for a service""" + if service_name not in cls._request_translators: + raise KeyError(f"No request translator registered for service: {service_name}") + return cls._request_translators[service_name] + + @classmethod + def get_response_translator(cls, service_name: str) -> MessageTranslator: + """Get response translator for a service""" + if service_name not in cls._response_translators: + raise KeyError(f"No response translator registered for service: {service_name}") + return cls._response_translators[service_name] + + @classmethod + def list_services(cls) -> List[str]: + """List all registered services""" + return sorted(set(cls._request_translators.keys()) | set(cls._response_translators.keys())) + + @classmethod + def has_service(cls, service_name: str) -> bool: + """Check if a service is registered""" + return (service_name in cls._request_translators or + service_name in cls._response_translators) \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/__init__.py b/trustgraph-base/trustgraph/messaging/translators/__init__.py new file mode 100644 index 00000000..fb487281 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/__init__.py @@ -0,0 +1,19 @@ +from .base import Translator, MessageTranslator +from .primitives import ValueTranslator, TripleTranslator, SubgraphTranslator +from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator +from .agent import AgentRequestTranslator, AgentResponseTranslator +from .embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator +from .text_completion import TextCompletionRequestTranslator, TextCompletionResponseTranslator +from .retrieval import DocumentRagRequestTranslator, DocumentRagResponseTranslator +from .retrieval import GraphRagRequestTranslator, GraphRagResponseTranslator +from .triples import TriplesQueryRequestTranslator, TriplesQueryResponseTranslator +from .knowledge import KnowledgeRequestTranslator, KnowledgeResponseTranslator +from .library import LibraryRequestTranslator, LibraryResponseTranslator +from .document_loading import DocumentTranslator, TextDocumentTranslator, ChunkTranslator, DocumentEmbeddingsTranslator +from .config import ConfigRequestTranslator, ConfigResponseTranslator +from .flow import FlowRequestTranslator, FlowResponseTranslator +from .prompt import PromptRequestTranslator, PromptResponseTranslator +from .embeddings_query import ( + DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator, + GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator +) diff --git a/trustgraph-base/trustgraph/messaging/translators/agent.py b/trustgraph-base/trustgraph/messaging/translators/agent.py new file mode 100644 index 00000000..5529a1a2 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/agent.py @@ -0,0 +1,44 @@ +from typing import Dict, Any, Tuple +from ...schema import AgentRequest, AgentResponse +from .base import MessageTranslator + + +class AgentRequestTranslator(MessageTranslator): + """Translator for AgentRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest: + return AgentRequest( + question=data["question"], + plan=data.get("plan", ""), + state=data.get("state", ""), + history=data.get("history", []) + ) + + def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]: + return { + "question": obj.question, + "plan": obj.plan, + "state": obj.state, + "history": obj.history + } + + +class AgentResponseTranslator(MessageTranslator): + """Translator for AgentResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> AgentResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: AgentResponse) -> Dict[str, Any]: + result = {} + if obj.answer: + result["answer"] = obj.answer + if obj.thought: + result["thought"] = obj.thought + if obj.observation: + result["observation"] = obj.observation + return result + + def from_response_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), (obj.answer is not None) \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/base.py b/trustgraph-base/trustgraph/messaging/translators/base.py new file mode 100644 index 00000000..64e2b635 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/base.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Dict, Any, Tuple +from pulsar.schema import Record + + +class Translator(ABC): + """Base class for bidirectional Pulsar ↔ dict translation""" + + @abstractmethod + def to_pulsar(self, data: Dict[str, Any]) -> Record: + """Convert dict to Pulsar schema object""" + pass + + @abstractmethod + def from_pulsar(self, obj: Record) -> Dict[str, Any]: + """Convert Pulsar schema object to dict""" + pass + + +class MessageTranslator(Translator): + """For complete request/response message translation""" + + def from_response_with_completion(self, obj: Record) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final) - for streaming responses""" + return self.from_pulsar(obj), True + + +class SendTranslator(Translator): + """For fire-and-forget send operations (like ServiceSender)""" + + def from_pulsar(self, obj: Record) -> Dict[str, Any]: + """Usually not needed for send-only operations""" + raise NotImplementedError("Send translators typically don't need from_pulsar") + + +def handle_optional_fields(obj: Record, fields: list) -> Dict[str, Any]: + """Helper to extract optional fields from Pulsar object""" + result = {} + for field in fields: + value = getattr(obj, field, None) + if value is not None: + result[field] = value + return result \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/config.py b/trustgraph-base/trustgraph/messaging/translators/config.py new file mode 100644 index 00000000..10e023f6 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/config.py @@ -0,0 +1,100 @@ +from typing import Dict, Any, Tuple +from ...schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue +from .base import MessageTranslator + + +class ConfigRequestTranslator(MessageTranslator): + """Translator for ConfigRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> ConfigRequest: + keys = None + if "keys" in data: + keys = [ + ConfigKey( + type=k["type"], + key=k["key"] + ) + for k in data["keys"] + ] + + values = None + if "values" in data: + values = [ + ConfigValue( + type=v["type"], + key=v["key"], + value=v["value"] + ) + for v in data["values"] + ] + + return ConfigRequest( + operation=data.get("operation"), + keys=keys, + type=data.get("type"), + values=values + ) + + def from_pulsar(self, obj: ConfigRequest) -> Dict[str, Any]: + result = {} + + if obj.operation: + result["operation"] = obj.operation + if obj.type: + result["type"] = obj.type + + if obj.keys: + result["keys"] = [ + { + "type": k.type, + "key": k.key + } + for k in obj.keys + ] + + if obj.values: + result["values"] = [ + { + "type": v.type, + "key": v.key, + "value": v.value + } + for v in obj.values + ] + + return result + + +class ConfigResponseTranslator(MessageTranslator): + """Translator for ConfigResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> ConfigResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: ConfigResponse) -> Dict[str, Any]: + result = {} + + if obj.version is not None: + result["version"] = obj.version + + if obj.values: + result["values"] = [ + { + "type": v.type, + "key": v.key, + "value": v.value + } + for v in obj.values + ] + + if obj.directory: + result["directory"] = obj.directory + + if obj.config: + result["config"] = obj.config + + return result + + def from_response_with_completion(self, obj: ConfigResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/document_loading.py b/trustgraph-base/trustgraph/messaging/translators/document_loading.py new file mode 100644 index 00000000..3dfef718 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/document_loading.py @@ -0,0 +1,191 @@ +import base64 +from typing import Dict, Any +from ...schema import Document, TextDocument, Chunk, DocumentEmbeddings, ChunkEmbeddings +from .base import SendTranslator +from .metadata import DocumentMetadataTranslator +from .primitives import SubgraphTranslator + + +class DocumentTranslator(SendTranslator): + """Translator for Document schema objects (PDF docs etc.)""" + + def __init__(self): + self.subgraph_translator = SubgraphTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> Document: + metadata = data.get("metadata", []) + + # Handle base64 content validation + doc = base64.b64decode(data["data"]) + + from ...schema import Metadata + return Document( + metadata=Metadata( + id=data.get("id"), + metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [], + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default"), + ), + data=base64.b64encode(doc).decode("utf-8") + ) + + def from_pulsar(self, obj: Document) -> Dict[str, Any]: + result = { + "data": obj.data + } + + if obj.metadata: + metadata_dict = {} + if obj.metadata.id: + metadata_dict["id"] = obj.metadata.id + if obj.metadata.user: + metadata_dict["user"] = obj.metadata.user + if obj.metadata.collection: + metadata_dict["collection"] = obj.metadata.collection + if obj.metadata.metadata: + metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata) + + result["metadata"] = metadata_dict + + return result + + +class TextDocumentTranslator(SendTranslator): + """Translator for TextDocument schema objects""" + + def __init__(self): + self.subgraph_translator = SubgraphTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> TextDocument: + metadata = data.get("metadata", []) + charset = data.get("charset", "utf-8") + + # Text is base64 encoded in input + text = base64.b64decode(data["text"]).decode(charset) + + from ...schema import Metadata + return TextDocument( + metadata=Metadata( + id=data.get("id"), + metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [], + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default"), + ), + text=text.encode("utf-8") + ) + + def from_pulsar(self, obj: TextDocument) -> Dict[str, Any]: + result = { + "text": obj.text.decode("utf-8") if isinstance(obj.text, bytes) else obj.text + } + + if obj.metadata: + metadata_dict = {} + if obj.metadata.id: + metadata_dict["id"] = obj.metadata.id + if obj.metadata.user: + metadata_dict["user"] = obj.metadata.user + if obj.metadata.collection: + metadata_dict["collection"] = obj.metadata.collection + if obj.metadata.metadata: + metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata) + + result["metadata"] = metadata_dict + + return result + + +class ChunkTranslator(SendTranslator): + """Translator for Chunk schema objects""" + + def __init__(self): + self.subgraph_translator = SubgraphTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> Chunk: + metadata = data.get("metadata", []) + + from ...schema import Metadata + return Chunk( + metadata=Metadata( + id=data.get("id"), + metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [], + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default"), + ), + chunk=data["chunk"].encode("utf-8") if isinstance(data["chunk"], str) else data["chunk"] + ) + + def from_pulsar(self, obj: Chunk) -> Dict[str, Any]: + result = { + "chunk": obj.chunk.decode("utf-8") if isinstance(obj.chunk, bytes) else obj.chunk + } + + if obj.metadata: + metadata_dict = {} + if obj.metadata.id: + metadata_dict["id"] = obj.metadata.id + if obj.metadata.user: + metadata_dict["user"] = obj.metadata.user + if obj.metadata.collection: + metadata_dict["collection"] = obj.metadata.collection + if obj.metadata.metadata: + metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata) + + result["metadata"] = metadata_dict + + return result + + +class DocumentEmbeddingsTranslator(SendTranslator): + """Translator for DocumentEmbeddings schema objects""" + + def __init__(self): + self.subgraph_translator = SubgraphTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddings: + metadata = data.get("metadata", {}) + + chunks = [ + ChunkEmbeddings( + chunk=chunk["chunk"].encode("utf-8") if isinstance(chunk["chunk"], str) else chunk["chunk"], + vectors=chunk["vectors"] + ) + for chunk in data.get("chunks", []) + ] + + from ...schema import Metadata + return DocumentEmbeddings( + metadata=Metadata( + id=metadata.get("id"), + metadata=self.subgraph_translator.to_pulsar(metadata.get("metadata", [])), + user=metadata.get("user", "trustgraph"), + collection=metadata.get("collection", "default"), + ), + chunks=chunks + ) + + def from_pulsar(self, obj: DocumentEmbeddings) -> Dict[str, Any]: + result = { + "chunks": [ + { + "chunk": chunk.chunk.decode("utf-8") if isinstance(chunk.chunk, bytes) else chunk.chunk, + "vectors": chunk.vectors + } + for chunk in obj.chunks + ] + } + + if obj.metadata: + metadata_dict = {} + if obj.metadata.id: + metadata_dict["id"] = obj.metadata.id + if obj.metadata.user: + metadata_dict["user"] = obj.metadata.user + if obj.metadata.collection: + metadata_dict["collection"] = obj.metadata.collection + if obj.metadata.metadata: + metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata) + + result["metadata"] = metadata_dict + + return result \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings.py b/trustgraph-base/trustgraph/messaging/translators/embeddings.py new file mode 100644 index 00000000..7e6eff83 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings.py @@ -0,0 +1,33 @@ +from typing import Dict, Any, Tuple +from ...schema import EmbeddingsRequest, EmbeddingsResponse +from .base import MessageTranslator + + +class EmbeddingsRequestTranslator(MessageTranslator): + """Translator for EmbeddingsRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsRequest: + return EmbeddingsRequest( + text=data["text"] + ) + + def from_pulsar(self, obj: EmbeddingsRequest) -> Dict[str, Any]: + return { + "text": obj.text + } + + +class EmbeddingsResponseTranslator(MessageTranslator): + """Translator for EmbeddingsResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: EmbeddingsResponse) -> Dict[str, Any]: + return { + "vectors": obj.vectors + } + + def from_response_with_completion(self, obj: EmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py new file mode 100644 index 00000000..d69e7bef --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/embeddings_query.py @@ -0,0 +1,94 @@ +from typing import Dict, Any, Tuple +from ...schema import ( + DocumentEmbeddingsRequest, DocumentEmbeddingsResponse, + GraphEmbeddingsRequest, GraphEmbeddingsResponse +) +from .base import MessageTranslator +from .primitives import ValueTranslator + + +class DocumentEmbeddingsRequestTranslator(MessageTranslator): + """Translator for DocumentEmbeddingsRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest: + return DocumentEmbeddingsRequest( + vectors=data["vectors"], + limit=int(data.get("limit", 10)), + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default") + ) + + def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]: + return { + "vectors": obj.vectors, + "limit": obj.limit, + "user": obj.user, + "collection": obj.collection + } + + +class DocumentEmbeddingsResponseTranslator(MessageTranslator): + """Translator for DocumentEmbeddingsResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]: + result = {} + + if obj.documents: + result["documents"] = [ + doc.decode("utf-8") if isinstance(doc, bytes) else doc + for doc in obj.documents + ] + + return result + + def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True + + +class GraphEmbeddingsRequestTranslator(MessageTranslator): + """Translator for GraphEmbeddingsRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest: + return GraphEmbeddingsRequest( + vectors=data["vectors"], + limit=int(data.get("limit", 10)), + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default") + ) + + def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]: + return { + "vectors": obj.vectors, + "limit": obj.limit, + "user": obj.user, + "collection": obj.collection + } + + +class GraphEmbeddingsResponseTranslator(MessageTranslator): + """Translator for GraphEmbeddingsResponse schema objects""" + + def __init__(self): + self.value_translator = ValueTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]: + result = {} + + if obj.entities: + result["entities"] = [ + self.value_translator.from_pulsar(entity) + for entity in obj.entities + ] + + return result + + def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/flow.py b/trustgraph-base/trustgraph/messaging/translators/flow.py new file mode 100644 index 00000000..212a9992 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/flow.py @@ -0,0 +1,59 @@ +from typing import Dict, Any, Tuple +from ...schema import FlowRequest, FlowResponse +from .base import MessageTranslator + + +class FlowRequestTranslator(MessageTranslator): + """Translator for FlowRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> FlowRequest: + return FlowRequest( + operation=data.get("operation"), + class_name=data.get("class-name"), + class_definition=data.get("class-definition"), + description=data.get("description"), + flow_id=data.get("flow-id") + ) + + def from_pulsar(self, obj: FlowRequest) -> Dict[str, Any]: + result = {} + + if obj.operation: + result["operation"] = obj.operation + if obj.class_name: + result["class-name"] = obj.class_name + if obj.class_definition: + result["class-definition"] = obj.class_definition + if obj.description: + result["description"] = obj.description + if obj.flow_id: + result["flow-id"] = obj.flow_id + + return result + + +class FlowResponseTranslator(MessageTranslator): + """Translator for FlowResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> FlowResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: FlowResponse) -> Dict[str, Any]: + result = {} + + if obj.class_names: + result["class-names"] = obj.class_names + if obj.flow_ids: + result["flow-ids"] = obj.flow_ids + if obj.class_definition: + result["class-definition"] = obj.class_definition + if obj.flow: + result["flow"] = obj.flow + if obj.description: + result["description"] = obj.description + + return result + + def from_response_with_completion(self, obj: FlowResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/knowledge.py b/trustgraph-base/trustgraph/messaging/translators/knowledge.py new file mode 100644 index 00000000..5377cbd4 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/knowledge.py @@ -0,0 +1,183 @@ +from typing import Dict, Any, Tuple, Optional +from ...schema import ( + KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings, + Metadata, EntityEmbeddings +) +from .base import MessageTranslator +from .primitives import ValueTranslator, SubgraphTranslator +from .metadata import DocumentMetadataTranslator + + +class KnowledgeRequestTranslator(MessageTranslator): + """Translator for KnowledgeRequest schema objects""" + + def __init__(self): + self.value_translator = ValueTranslator() + self.subgraph_translator = SubgraphTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeRequest: + triples = None + if "triples" in data: + triples = Triples( + metadata=Metadata( + id=data["triples"]["metadata"]["id"], + metadata=self.subgraph_translator.to_pulsar( + data["triples"]["metadata"]["metadata"] + ), + user=data["triples"]["metadata"]["user"], + collection=data["triples"]["metadata"]["collection"] + ), + triples=self.subgraph_translator.to_pulsar(data["triples"]["triples"]), + ) + + graph_embeddings = None + if "graph-embeddings" in data: + graph_embeddings = GraphEmbeddings( + metadata=Metadata( + id=data["graph-embeddings"]["metadata"]["id"], + metadata=self.subgraph_translator.to_pulsar( + data["graph-embeddings"]["metadata"]["metadata"] + ), + user=data["graph-embeddings"]["metadata"]["user"], + collection=data["graph-embeddings"]["metadata"]["collection"] + ), + entities=[ + EntityEmbeddings( + entity=self.value_translator.to_pulsar(ent["entity"]), + vectors=ent["vectors"], + ) + for ent in data["graph-embeddings"]["entities"] + ] + ) + + return KnowledgeRequest( + operation=data.get("operation"), + user=data.get("user"), + id=data.get("id"), + flow=data.get("flow"), + collection=data.get("collection"), + triples=triples, + graph_embeddings=graph_embeddings, + ) + + def from_pulsar(self, obj: KnowledgeRequest) -> Dict[str, Any]: + result = {} + + if obj.operation: + result["operation"] = obj.operation + if obj.user: + result["user"] = obj.user + if obj.id: + result["id"] = obj.id + if obj.flow: + result["flow"] = obj.flow + if obj.collection: + result["collection"] = obj.collection + + if obj.triples: + result["triples"] = { + "metadata": { + "id": obj.triples.metadata.id, + "metadata": self.subgraph_translator.from_pulsar( + obj.triples.metadata.metadata + ), + "user": obj.triples.metadata.user, + "collection": obj.triples.metadata.collection, + }, + "triples": self.subgraph_translator.from_pulsar(obj.triples.triples), + } + + if obj.graph_embeddings: + result["graph-embeddings"] = { + "metadata": { + "id": obj.graph_embeddings.metadata.id, + "metadata": self.subgraph_translator.from_pulsar( + obj.graph_embeddings.metadata.metadata + ), + "user": obj.graph_embeddings.metadata.user, + "collection": obj.graph_embeddings.metadata.collection, + }, + "entities": [ + { + "vectors": entity.vectors, + "entity": self.value_translator.from_pulsar(entity.entity), + } + for entity in obj.graph_embeddings.entities + ], + } + + return result + + +class KnowledgeResponseTranslator(MessageTranslator): + """Translator for KnowledgeResponse schema objects""" + + def __init__(self): + self.value_translator = ValueTranslator() + self.subgraph_translator = SubgraphTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: KnowledgeResponse) -> Dict[str, Any]: + # Response to list operation + if obj.ids is not None: + return {"ids": obj.ids} + + # Streaming triples response + if obj.triples: + return { + "triples": { + "metadata": { + "id": obj.triples.metadata.id, + "metadata": self.subgraph_translator.from_pulsar( + obj.triples.metadata.metadata + ), + "user": obj.triples.metadata.user, + "collection": obj.triples.metadata.collection, + }, + "triples": self.subgraph_translator.from_pulsar(obj.triples.triples), + } + } + + # Streaming graph embeddings response + if obj.graph_embeddings: + return { + "graph-embeddings": { + "metadata": { + "id": obj.graph_embeddings.metadata.id, + "metadata": self.subgraph_translator.from_pulsar( + obj.graph_embeddings.metadata.metadata + ), + "user": obj.graph_embeddings.metadata.user, + "collection": obj.graph_embeddings.metadata.collection, + }, + "entities": [ + { + "vectors": entity.vectors, + "entity": self.value_translator.from_pulsar(entity.entity), + } + for entity in obj.graph_embeddings.entities + ], + } + } + + # End of stream marker + if obj.eos is True: + return {"eos": True} + + # Empty response (successful delete) + return {} + + def from_response_with_completion(self, obj: KnowledgeResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + response = self.from_pulsar(obj) + + # Check if this is a final response + is_final = ( + obj.ids is not None or # List response + obj.eos is True or # End of stream + (not obj.triples and not obj.graph_embeddings) # Empty response + ) + + return response, is_final \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/library.py b/trustgraph-base/trustgraph/messaging/translators/library.py new file mode 100644 index 00000000..fc355dda --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/library.py @@ -0,0 +1,124 @@ +from typing import Dict, Any, Tuple, Optional +from ...schema import LibrarianRequest, LibrarianResponse, DocumentMetadata, ProcessingMetadata, Criteria +from .base import MessageTranslator +from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator + + +class LibraryRequestTranslator(MessageTranslator): + """Translator for LibrarianRequest schema objects""" + + def __init__(self): + self.doc_metadata_translator = DocumentMetadataTranslator() + self.proc_metadata_translator = ProcessingMetadataTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> LibrarianRequest: + # Document metadata + doc_metadata = None + if "document-metadata" in data: + doc_metadata = self.doc_metadata_translator.to_pulsar(data["document-metadata"]) + + # Processing metadata + proc_metadata = None + if "processing-metadata" in data: + proc_metadata = self.proc_metadata_translator.to_pulsar(data["processing-metadata"]) + + # Criteria + criteria = [] + if "criteria" in data: + criteria = [ + Criteria( + key=c["key"], + value=c["value"], + operator=c["operator"] + ) + for c in data["criteria"] + ] + + # Content as bytes + content = None + if "content" in data: + if isinstance(data["content"], str): + content = data["content"].encode("utf-8") + else: + content = data["content"] + + return LibrarianRequest( + operation=data.get("operation"), + document_id=data.get("document-id"), + processing_id=data.get("processing-id"), + document_metadata=doc_metadata, + processing_metadata=proc_metadata, + content=content, + user=data.get("user"), + collection=data.get("collection"), + criteria=criteria + ) + + def from_pulsar(self, obj: LibrarianRequest) -> Dict[str, Any]: + result = {} + + if obj.operation: + result["operation"] = obj.operation + if obj.document_id: + result["document-id"] = obj.document_id + if obj.processing_id: + result["processing-id"] = obj.processing_id + if obj.document_metadata: + result["document-metadata"] = self.doc_metadata_translator.from_pulsar(obj.document_metadata) + if obj.processing_metadata: + result["processing-metadata"] = self.proc_metadata_translator.from_pulsar(obj.processing_metadata) + if obj.content: + result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content + if obj.user: + result["user"] = obj.user + if obj.collection: + result["collection"] = obj.collection + if obj.criteria is not None: + result["criteria"] = [ + { + "key": c.key, + "value": c.value, + "operator": c.operator + } + for c in obj.criteria + ] + + return result + + +class LibraryResponseTranslator(MessageTranslator): + """Translator for LibrarianResponse schema objects""" + + def __init__(self): + self.doc_metadata_translator = DocumentMetadataTranslator() + self.proc_metadata_translator = ProcessingMetadataTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> LibrarianResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: LibrarianResponse) -> Dict[str, Any]: + result = {} + + if obj.document_metadata: + result["document-metadata"] = self.doc_metadata_translator.from_pulsar(obj.document_metadata) + + if obj.content: + result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content + + if obj.document_metadatas is not None: + result["document-metadatas"] = [ + self.doc_metadata_translator.from_pulsar(dm) + for dm in obj.document_metadatas + ] + + if obj.processing_metadatas is not None: + result["processing-metadatas"] = [ + self.proc_metadata_translator.from_pulsar(pm) + for pm in obj.processing_metadatas + ] + + return result + + def from_response_with_completion(self, obj: LibrarianResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True diff --git a/trustgraph-base/trustgraph/messaging/translators/metadata.py b/trustgraph-base/trustgraph/messaging/translators/metadata.py new file mode 100644 index 00000000..006b222c --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/metadata.py @@ -0,0 +1,81 @@ +from typing import Dict, Any, Optional +from ...schema import DocumentMetadata, ProcessingMetadata +from .base import Translator +from .primitives import SubgraphTranslator + + +class DocumentMetadataTranslator(Translator): + """Translator for DocumentMetadata schema objects""" + + def __init__(self): + self.subgraph_translator = SubgraphTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> DocumentMetadata: + metadata = data.get("metadata", []) + return DocumentMetadata( + id=data.get("id"), + time=data.get("time"), + kind=data.get("kind"), + title=data.get("title"), + comments=data.get("comments"), + metadata=self.subgraph_translator.to_pulsar(metadata) if metadata is not None else [], + user=data.get("user"), + tags=data.get("tags") + ) + + def from_pulsar(self, obj: DocumentMetadata) -> Dict[str, Any]: + result = {} + + if obj.id: + result["id"] = obj.id + if obj.time: + result["time"] = obj.time + if obj.kind: + result["kind"] = obj.kind + if obj.title: + result["title"] = obj.title + if obj.comments: + result["comments"] = obj.comments + if obj.metadata is not None: + result["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata) + if obj.user: + result["user"] = obj.user + if obj.tags is not None: + result["tags"] = obj.tags + + return result + + +class ProcessingMetadataTranslator(Translator): + """Translator for ProcessingMetadata schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> ProcessingMetadata: + return ProcessingMetadata( + id=data.get("id"), + document_id=data.get("document-id"), + time=data.get("time"), + flow=data.get("flow"), + user=data.get("user"), + collection=data.get("collection"), + tags=data.get("tags") + ) + + def from_pulsar(self, obj: ProcessingMetadata) -> Dict[str, Any]: + result = {} + + if obj.id: + result["id"] = obj.id + if obj.document_id: + result["document-id"] = obj.document_id + if obj.time: + result["time"] = obj.time + if obj.flow: + result["flow"] = obj.flow + if obj.user: + result["user"] = obj.user + if obj.collection: + result["collection"] = obj.collection + if obj.tags is not None: + result["tags"] = obj.tags + + return result diff --git a/trustgraph-base/trustgraph/messaging/translators/primitives.py b/trustgraph-base/trustgraph/messaging/translators/primitives.py new file mode 100644 index 00000000..6b57aec4 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/primitives.py @@ -0,0 +1,47 @@ +from typing import Dict, Any, List +from ...schema import Value, Triple +from .base import Translator + + +class ValueTranslator(Translator): + """Translator for Value schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> Value: + return Value(value=data["v"], is_uri=data["e"]) + + def from_pulsar(self, obj: Value) -> Dict[str, Any]: + return {"v": obj.value, "e": obj.is_uri} + + +class TripleTranslator(Translator): + """Translator for Triple schema objects""" + + def __init__(self): + self.value_translator = ValueTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> Triple: + return Triple( + s=self.value_translator.to_pulsar(data["s"]), + p=self.value_translator.to_pulsar(data["p"]), + o=self.value_translator.to_pulsar(data["o"]) + ) + + def from_pulsar(self, obj: Triple) -> Dict[str, Any]: + return { + "s": self.value_translator.from_pulsar(obj.s), + "p": self.value_translator.from_pulsar(obj.p), + "o": self.value_translator.from_pulsar(obj.o) + } + + +class SubgraphTranslator(Translator): + """Translator for lists of Triple objects (subgraphs)""" + + def __init__(self): + self.triple_translator = TripleTranslator() + + def to_pulsar(self, data: List[Dict[str, Any]]) -> List[Triple]: + return [self.triple_translator.to_pulsar(t) for t in data] + + def from_pulsar(self, obj: List[Triple]) -> List[Dict[str, Any]]: + return [self.triple_translator.from_pulsar(t) for t in obj] \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/prompt.py b/trustgraph-base/trustgraph/messaging/translators/prompt.py new file mode 100644 index 00000000..b0e7351f --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/prompt.py @@ -0,0 +1,54 @@ +import json +from typing import Dict, Any, Tuple +from ...schema import PromptRequest, PromptResponse +from .base import MessageTranslator + + +class PromptRequestTranslator(MessageTranslator): + """Translator for PromptRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> PromptRequest: + # Handle both "terms" and "variables" input keys + terms = data.get("terms", {}) + if "variables" in data: + # Convert variables to JSON strings as expected by the service + terms = { + k: json.dumps(v) + for k, v in data["variables"].items() + } + + return PromptRequest( + id=data.get("id"), + terms=terms + ) + + def from_pulsar(self, obj: PromptRequest) -> Dict[str, Any]: + result = {} + + if obj.id: + result["id"] = obj.id + if obj.terms: + result["terms"] = obj.terms + + return result + + +class PromptResponseTranslator(MessageTranslator): + """Translator for PromptResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> PromptResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]: + result = {} + + if obj.text: + result["text"] = obj.text + if obj.object: + result["object"] = obj.object + + return result + + def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/retrieval.py b/trustgraph-base/trustgraph/messaging/translators/retrieval.py new file mode 100644 index 00000000..96c25ed8 --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/retrieval.py @@ -0,0 +1,81 @@ +from typing import Dict, Any, Tuple +from ...schema import DocumentRagQuery, DocumentRagResponse, GraphRagQuery, GraphRagResponse +from .base import MessageTranslator + + +class DocumentRagRequestTranslator(MessageTranslator): + """Translator for DocumentRagQuery schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagQuery: + return DocumentRagQuery( + query=data["query"], + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default"), + doc_limit=int(data.get("doc-limit", 20)) + ) + + def from_pulsar(self, obj: DocumentRagQuery) -> Dict[str, Any]: + return { + "query": obj.query, + "user": obj.user, + "collection": obj.collection, + "doc-limit": obj.doc_limit + } + + +class DocumentRagResponseTranslator(MessageTranslator): + """Translator for DocumentRagResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]: + return { + "response": obj.response + } + + def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True + + +class GraphRagRequestTranslator(MessageTranslator): + """Translator for GraphRagQuery schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> GraphRagQuery: + return GraphRagQuery( + query=data["query"], + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default"), + entity_limit=int(data.get("entity-limit", 50)), + triple_limit=int(data.get("triple-limit", 30)), + max_subgraph_size=int(data.get("max-subgraph-size", 1000)), + max_path_length=int(data.get("max-path-length", 2)) + ) + + def from_pulsar(self, obj: GraphRagQuery) -> Dict[str, Any]: + return { + "query": obj.query, + "user": obj.user, + "collection": obj.collection, + "entity-limit": obj.entity_limit, + "triple-limit": obj.triple_limit, + "max-subgraph-size": obj.max_subgraph_size, + "max-path-length": obj.max_path_length + } + + +class GraphRagResponseTranslator(MessageTranslator): + """Translator for GraphRagResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> GraphRagResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]: + return { + "response": obj.response + } + + def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/text_completion.py b/trustgraph-base/trustgraph/messaging/translators/text_completion.py new file mode 100644 index 00000000..eda3be5d --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/text_completion.py @@ -0,0 +1,42 @@ +from typing import Dict, Any, Tuple +from ...schema import TextCompletionRequest, TextCompletionResponse +from .base import MessageTranslator + + +class TextCompletionRequestTranslator(MessageTranslator): + """Translator for TextCompletionRequest schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionRequest: + return TextCompletionRequest( + system=data["system"], + prompt=data["prompt"] + ) + + def from_pulsar(self, obj: TextCompletionRequest) -> Dict[str, Any]: + return { + "system": obj.system, + "prompt": obj.prompt + } + + +class TextCompletionResponseTranslator(MessageTranslator): + """Translator for TextCompletionResponse schema objects""" + + def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: TextCompletionResponse) -> Dict[str, Any]: + result = {"response": obj.response} + + if obj.in_token: + result["in_token"] = obj.in_token + if obj.out_token: + result["out_token"] = obj.out_token + if obj.model: + result["model"] = obj.model + + return result + + def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True \ No newline at end of file diff --git a/trustgraph-base/trustgraph/messaging/translators/triples.py b/trustgraph-base/trustgraph/messaging/translators/triples.py new file mode 100644 index 00000000..1c08625b --- /dev/null +++ b/trustgraph-base/trustgraph/messaging/translators/triples.py @@ -0,0 +1,60 @@ +from typing import Dict, Any, Tuple, Optional +from ...schema import TriplesQueryRequest, TriplesQueryResponse +from .base import MessageTranslator +from .primitives import ValueTranslator, SubgraphTranslator + + +class TriplesQueryRequestTranslator(MessageTranslator): + """Translator for TriplesQueryRequest schema objects""" + + def __init__(self): + self.value_translator = ValueTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> TriplesQueryRequest: + s = self.value_translator.to_pulsar(data["s"]) if "s" in data else None + p = self.value_translator.to_pulsar(data["p"]) if "p" in data else None + o = self.value_translator.to_pulsar(data["o"]) if "o" in data else None + + return TriplesQueryRequest( + s=s, + p=p, + o=o, + limit=int(data.get("limit", 10000)), + user=data.get("user", "trustgraph"), + collection=data.get("collection", "default") + ) + + def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]: + result = { + "limit": obj.limit, + "user": obj.user, + "collection": obj.collection + } + + if obj.s: + result["s"] = self.value_translator.from_pulsar(obj.s) + if obj.p: + result["p"] = self.value_translator.from_pulsar(obj.p) + if obj.o: + result["o"] = self.value_translator.from_pulsar(obj.o) + + return result + + +class TriplesQueryResponseTranslator(MessageTranslator): + """Translator for TriplesQueryResponse schema objects""" + + def __init__(self): + self.subgraph_translator = SubgraphTranslator() + + def to_pulsar(self, data: Dict[str, Any]) -> TriplesQueryResponse: + raise NotImplementedError("Response translation to Pulsar not typically needed") + + def from_pulsar(self, obj: TriplesQueryResponse) -> Dict[str, Any]: + return { + "response": self.subgraph_translator.from_pulsar(obj.triples) + } + + def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]: + """Returns (response_dict, is_final)""" + return self.from_pulsar(obj), True \ No newline at end of file diff --git a/trustgraph-cli/scripts/tg-load-turtle b/trustgraph-cli/scripts/tg-load-turtle index ba92067e..f10fd760 100755 --- a/trustgraph-cli/scripts/tg-load-turtle +++ b/trustgraph-cli/scripts/tg-load-turtle @@ -1,80 +1,68 @@ #!/usr/bin/env python3 """ -Loads Graph embeddings into TrustGraph processing. - -FIXME: This hasn't been updated following API gateway change. +Loads triples into the knowledge graph. """ -import pulsar -from pulsar.schema import JsonSchema -from trustgraph.schema import Triples, Triple, Value, Metadata +import asyncio import argparse import os import time -import pyarrow as pa import rdflib +import json +from websockets.asyncio.client import connect from trustgraph.log_level import LogLevel +default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/') default_user = 'trustgraph' default_collection = 'default' -default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') -default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None) - -default_output_queue = triples_store_queue class Loader: def __init__( self, - pulsar_host, - output_queue, - log_level, files, + flow, user, collection, - pulsar_api_key=None, + document_id, + url = default_url, ): - if pulsar_api_key: - auth = pulsar.AuthenticationToken(pulsar_api_key) - self.client = pulsar.Client( - pulsar_host, - authentication=auth, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) - else: - self.client = pulsar.Client( - pulsar_host, - logger=pulsar.ConsoleLogger(log_level.to_pulsar()) - ) + if not url.endswith("/"): + url += "/" - self.producer = self.client.create_producer( - topic=output_queue, - schema=JsonSchema(Triples), - chunking_enabled=True, - ) + url = url + f"api/v1/flow/{flow}/import/triples" + + self.url = url self.files = files self.user = user self.collection = collection + self.document_id = document_id - def run(self): + async def run(self): try: - for file in self.files: - self.load_file(file) + async with connect(self.url) as ws: + for file in self.files: + await self.load_file(file, ws) except Exception as e: print(e, flush=True) - def load_file(self, file): + async def load_file(self, file, ws): g = rdflib.Graph() g.parse(file, format="turtle") + def Value(value, is_uri): + return { "v": value, "e": is_uri } + + triples = [] + for e in g: s = Value(value=str(e[0]), is_uri=True) p = Value(value=str(e[1]), is_uri=True) @@ -83,20 +71,23 @@ class Loader: else: o = Value(value=str(e[2]), is_uri=False) - r = Triples( - metadata=Metadata( - id=None, - metadata=[], - user=self.user, - collection=self.collection, - ), - triples=[ Triple(s=s, p=p, o=o) ] - ) + req = { + "metadata": { + "id": self.document_id, + "metadata": [], + "user": self.user, + "collection": self.collection + }, + "triples": [ + { + "s": s, + "p": p, + "o": o, + } + ] + } - self.producer.send(r) - - def __del__(self): - self.client.close() + await ws.send(json.dumps(req)) def main(): @@ -106,9 +97,15 @@ def main(): ) parser.add_argument( - '-p', '--pulsar-host', - default=default_pulsar_host, - help=f'Pulsar host (default: {default_pulsar_host})', + '-u', '--api-url', + default=default_url, + help=f'API URL (default: {default_url})', + ) + + parser.add_argument( + '-i', '--document-id', + required=True, + help=f'Document ID)', ) parser.add_argument( @@ -116,39 +113,19 @@ def main(): default="default", help=f'Flow ID (default: default)' ) - - parser.add_argument( - '--pulsar-api-key', - default=default_pulsar_api_key, - help=f'Pulsar API key', - ) parser.add_argument( - '-o', '--output-queue', - default=default_output_queue, - help=f'Output queue (default: {default_output_queue})' - ) - - parser.add_argument( - '-u', '--user', + '-U', '--user', default=default_user, help=f'User ID (default: {default_user})' ) parser.add_argument( - '-c', '--collection', + '-C', '--collection', default=default_collection, help=f'Collection ID (default: {default_collection})' ) - parser.add_argument( - '-l', '--log-level', - type=LogLevel, - default=LogLevel.ERROR, - choices=list(LogLevel), - help=f'Output queue (default: info)' - ) - parser.add_argument( 'files', nargs='+', help=f'Turtle files to load' @@ -160,16 +137,15 @@ def main(): try: p = Loader( - pulsar_host=args.pulsar_host, - pulsar_api_key=args.pulsar_api_key, - output_queue=args.output_queue, - log_level=args.log_level, - files=args.files, - user=args.user, - collection=args.collection, + document_id = args.document_id, + url = args.api_url, + flow = args.flow_id, + files = args.files, + user = args.user, + collection = args.collection, ) - p.run() + asyncio.run(p.run()) print("File loaded.") break @@ -181,6 +157,5 @@ def main(): time.sleep(10) -print("Not implemented.") -#main() +main() diff --git a/trustgraph-cli/scripts/tg-show-token-rate b/trustgraph-cli/scripts/tg-show-token-rate new file mode 100755 index 00000000..800569e5 --- /dev/null +++ b/trustgraph-cli/scripts/tg-show-token-rate @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 + +""" +Dump out a stream of token rates, input, output and total. This is averaged +across the time since tg-show-token-rate is started. +""" + +import requests +import argparse +import json +import time + +default_metrics_url = "http://localhost:8088/api/metrics" + +class Collate: + + def look(self, data): + return sum( + [ + float(x["value"][1]) + for x in data["data"]["result"] + ] + ) + + def __init__(self, data): + self.last = self.look(data) + self.total = 0 + self.time = 0 + + def record(self, data, time): + + value = self.look(data) + delta = value - self.last + self.last = value + + self.total += delta + self.time += time + + return delta/time, self.total/self.time + +def dump_status(metrics_url, number_samples, period): + + input_url = f"{metrics_url}/query?query=input_tokens_total" + output_url = f"{metrics_url}/query?query=output_tokens_total" + + resp = requests.get(input_url) + obj = resp.json() + input = Collate(obj) + + resp = requests.get(output_url) + obj = resp.json() + output = Collate(obj) + + print(f"{'Input':>10s} {'Output':>10s} {'Total':>10s}") + print(f"{'-----':>10s} {'------':>10s} {'-----':>10s}") + + for i in range(number_samples): + + time.sleep(period) + + resp = requests.get(input_url) + obj = resp.json() + inr, inl = input.record(obj, period) + + resp = requests.get(output_url) + obj = resp.json() + outr, outl = output.record(obj, period) + + print(f"{inl:10.1f} {outl:10.1f} {inl+outl:10.1f}") + +def main(): + + parser = argparse.ArgumentParser( + prog='tg-show-processor-state', + description=__doc__, + ) + + parser.add_argument( + '-m', '--metrics-url', + default=default_metrics_url, + help=f'Metrics URL (default: {default_metrics_url})', + ) + + parser.add_argument( + '-p', '--period', + type=int, + default=1, + help=f'Metrics period (default: 1)', + ) + + parser.add_argument( + '-n', '--number-samples', + type=int, + default=100, + help=f'Metrics period (default: 100)', + ) + + args = parser.parse_args() + + try: + + dump_status(**vars(args)) + + except Exception as e: + + print("Exception:", e, flush=True) + +main() + diff --git a/trustgraph-cli/setup.py b/trustgraph-cli/setup.py index ed26511e..147b1807 100644 --- a/trustgraph-cli/setup.py +++ b/trustgraph-cli/setup.py @@ -80,6 +80,7 @@ setuptools.setup( "scripts/tg-show-processor-state", "scripts/tg-show-prompts", "scripts/tg-show-token-costs", + "scripts/tg-show-token-rate", "scripts/tg-show-tools", "scripts/tg-start-flow", "scripts/tg-unload-kg-core", diff --git a/trustgraph-flow/scripts/rev-gateway b/trustgraph-flow/scripts/rev-gateway new file mode 100755 index 00000000..708c6c96 --- /dev/null +++ b/trustgraph-flow/scripts/rev-gateway @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.rev_gateway import run + +run() + diff --git a/trustgraph-flow/scripts/text-completion-vllm b/trustgraph-flow/scripts/text-completion-vllm new file mode 100755 index 00000000..e24c076a --- /dev/null +++ b/trustgraph-flow/scripts/text-completion-vllm @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from trustgraph.model.text_completion.vllm import run + +run() + diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py index a4d4f7a0..562c5389 100644 --- a/trustgraph-flow/setup.py +++ b/trustgraph-flow/setup.py @@ -71,6 +71,7 @@ setuptools.setup( scripts=[ "scripts/agent-manager-react", "scripts/api-gateway", + "scripts/rev-gateway", "scripts/chunker-recursive", "scripts/chunker-token", "scripts/config-svc", @@ -118,6 +119,7 @@ setuptools.setup( "scripts/text-completion-ollama", "scripts/text-completion-openai", "scripts/text-completion-tgi", + "scripts/text-completion-vllm", "scripts/triples-query-cassandra", "scripts/triples-query-falkordb", "scripts/triples-query-memgraph", diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index f95dadf9..66571478 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -21,16 +21,19 @@ RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) default_ident = "kg-extract-definitions" +default_concurrency = 1 class Processor(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", 1) super(Processor, self).__init__( **params | { "id": id, + "concurrency": concurrency, } ) @@ -38,7 +41,8 @@ class Processor(FlowProcessor): ConsumerSpec( name = "input", schema = Chunk, - handler = self.on_message + handler = self.on_message, + concurrency = concurrency, ) ) @@ -190,6 +194,13 @@ class Processor(FlowProcessor): @staticmethod def add_args(parser): + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + FlowProcessor.add_args(parser) def run(): diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index 63670a7d..dafee77d 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -20,16 +20,19 @@ RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True) SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True) default_ident = "kg-extract-relationships" +default_concurrency = 1 class Processor(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", 1) super(Processor, self).__init__( **params | { "id": id, + "concurrency": concurrency, } ) @@ -37,7 +40,8 @@ class Processor(FlowProcessor): ConsumerSpec( name = "input", schema = Chunk, - handler = self.on_message + handler = self.on_message, + concurrency = concurrency, ) ) @@ -192,6 +196,13 @@ class Processor(FlowProcessor): @staticmethod def add_args(parser): + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + FlowProcessor.add_args(parser) def run(): diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/agent.py b/trustgraph-flow/trustgraph/gateway/dispatch/agent.py index d0fd8537..1a5e8299 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/agent.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/agent.py @@ -1,5 +1,6 @@ from ... schema import AgentRequest, AgentResponse +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor @@ -20,24 +21,12 @@ class AgentRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("agent") + self.response_translator = TranslatorRegistry.get_response_translator("agent") + def to_request(self, body): - return AgentRequest( - question=body["question"] - ) + return self.request_translator.to_pulsar(body) def from_response(self, message): - resp = { - } - - if message.answer: - resp["answer"] = message.answer - - if message.thought: - resp["thought"] = message.thought - - if message.observation: - resp["observation"] = message.observation - - # The 2nd boolean expression indicates whether we're done responding - return resp, (message.answer is not None) + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/config.py b/trustgraph-flow/trustgraph/gateway/dispatch/config.py index 3aeedb6f..c4fac5fa 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/config.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/config.py @@ -2,6 +2,7 @@ from ... schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue from ... schema import config_request_queue from ... schema import config_response_queue +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor @@ -19,60 +20,12 @@ class ConfigRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("config") + self.response_translator = TranslatorRegistry.get_response_translator("config") + def to_request(self, body): - - if "keys" in body: - keys = [ - ConfigKey( - type = k["type"], - key = k["key"], - ) - for k in body["keys"] - ] - else: - keys = None - - if "values" in body: - values = [ - ConfigValue( - type = v["type"], - key = v["key"], - value = v["value"], - ) - for v in body["values"] - ] - else: - values = None - - return ConfigRequest( - operation = body.get("operation", None), - keys = keys, - type = body.get("type", None), - values = values - ) + return self.request_translator.to_pulsar(body) def from_response(self, message): - - response = { } - - if message.version is not None: - response["version"] = message.version - - if message.values is not None: - response["values"] = [ - { - "type": v.type, - "key": v.key, - "value": v.value, - } - for v in message.values - ] - - if message.directory is not None: - response["directory"] = message.directory - - if message.config is not None: - response["config"] = message.config - - return response, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py new file mode 100644 index 00000000..941ce5d8 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_export.py @@ -0,0 +1,96 @@ + +import asyncio +import uuid +import msgpack +from . knowledge import KnowledgeRequestor + +class CoreExport: + + def __init__(self, pulsar_client): + self.pulsar_client = pulsar_client + + async def process(self, data, error, ok, request): + + id = request.query["id"] + user = request.query["user"] + + response = await ok() + + kr = KnowledgeRequestor( + pulsar_client = self.pulsar_client, + consumer = "api-gateway-core-export-" + str(uuid.uuid4()), + subscriber = "api-gateway-core-export-" + str(uuid.uuid4()), + ) + + try: + + await kr.start() + + async def responder(resp, fin): + + if "graph-embeddings" in resp: + + data = resp["graph-embeddings"] + + msg = ( + "ge", + { + "m": { + "i": data["metadata"]["id"], + "m": data["metadata"]["metadata"], + "u": data["metadata"]["user"], + "c": data["metadata"]["collection"], + }, + "e": [ + { + "e": ent["entity"], + "v": ent["vectors"], + } + for ent in data["entities"] + ] + } + ) + + enc = msgpack.packb(msg) + await response.write(enc) + + if "triples" in resp: + + data = resp["triples"] + msg = ( + "t", + { + "m": { + "i": data["metadata"]["id"], + "m": data["metadata"]["metadata"], + "u": data["metadata"]["user"], + "c": data["metadata"]["collection"], + }, + "t": data["triples"], + } + ) + + enc = msgpack.packb(msg) + await response.write(enc) + + await kr.process( + { + "operation": "get-kg-core", + "user": user, + "id": id, + }, + responder + ) + + except Exception as e: + + print("Exception:", e) + + finally: + + await kr.stop() + + await response.write_eof() + + return response + diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py new file mode 100644 index 00000000..b819d286 --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/dispatch/core_import.py @@ -0,0 +1,94 @@ + +import asyncio +import json +import uuid +import msgpack +from . knowledge import KnowledgeRequestor + +class CoreImport: + + def __init__(self, pulsar_client): + self.pulsar_client = pulsar_client + + async def process(self, data, error, ok, request): + + id = request.query["id"] + user = request.query["user"] + + kr = KnowledgeRequestor( + pulsar_client = self.pulsar_client, + consumer = "api-gateway-core-import-" + str(uuid.uuid4()), + subscriber = "api-gateway-core-import-" + str(uuid.uuid4()), + ) + + await kr.start() + + try: + + unpacker = msgpack.Unpacker() + + while True: + buf = await data.read(128*1024) + if not buf: break + + unpacker.feed(buf) + + for unpacked in unpacker: + + if unpacked[0] == "t": + msg = unpacked[1] + msg = { + "operation": "put-kg-core", + "user": user, + "id": id, + "triples": { + "metadata": { + "id": id, + "metadata": msg["m"]["m"], + "user": user, + "collection": "default", # Not used? + }, + "triples": msg["t"], + } + } + + await kr.process(msg) + + elif unpacked[0] == "ge": + msg = unpacked[1] + msg = { + "operation": "put-kg-core", + "user": user, + "id": id, + "graph-embeddings": { + "metadata": { + "id": id, + "metadata": msg["m"]["m"], + "user": user, + "collection": "default", # Not used? + }, + "entities": [ + { + "entity": ent["e"], + "vectors": ent["v"], + } + for ent in msg["e"] + ] + } + } + + await kr.process(msg) + + except Exception as e: + print("Exception:", e) + await error(str(e)) + + finally: + + await kr.stop() + + print("All done.") + response = await ok() + await response.write_eof() + + return response diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py index e486f613..dd4fc4e1 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_embeddings_import.py @@ -6,8 +6,7 @@ from aiohttp import WSMsgType from ... schema import Metadata from ... schema import DocumentEmbeddings, ChunkEmbeddings from ... base import Publisher - -from . serialize import to_subgraph +from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator class DocumentEmbeddingsImport: @@ -17,6 +16,7 @@ class DocumentEmbeddingsImport: self.ws = ws self.running = running + self.translator = DocumentEmbeddingsTranslator() self.publisher = Publisher( pulsar_client, topic = queue, schema = DocumentEmbeddings @@ -36,23 +36,7 @@ class DocumentEmbeddingsImport: async def receive(self, msg): data = msg.json() - - elt = DocumentEmbeddings( - metadata=Metadata( - id=data["metadata"]["id"], - metadata=to_subgraph(data["metadata"]["metadata"]), - user=data["metadata"]["user"], - collection=data["metadata"]["collection"], - ), - chunks=[ - ChunkEmbeddings( - chunk=de["chunk"].encode("utf-8"), - vectors=de["vectors"], - ) - for de in data["chunks"] - ], - ) - + elt = self.translator.to_pulsar(data) await self.publisher.send(None, elt) async def run(self): diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py index f92fc34f..101e9b41 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_load.py @@ -2,9 +2,9 @@ import base64 from ... schema import Document, Metadata +from ... messaging import TranslatorRegistry from . sender import ServiceSender -from . serialize import to_subgraph class DocumentLoad(ServiceSender): def __init__(self, pulsar_client, queue): @@ -15,26 +15,9 @@ class DocumentLoad(ServiceSender): schema = Document, ) + self.translator = TranslatorRegistry.get_request_translator("document") + def to_request(self, body): - - if "metadata" in body: - metadata = to_subgraph(body["metadata"]) - else: - metadata = [] - - # Doing a base64 decoe/encode here to make sure the - # content is valid base64 - doc = base64.b64decode(body["data"]) - print("Document received") - - return Document( - metadata=Metadata( - id=body.get("id"), - metadata=metadata, - user=body.get("user", "trustgraph"), - collection=body.get("collection", "default"), - ), - data=base64.b64encode(doc).decode("utf-8") - ) + return self.translator.to_pulsar(body) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py b/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py index 29194f97..a7f3634e 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/document_rag.py @@ -1,5 +1,6 @@ from ... schema import DocumentRagQuery, DocumentRagResponse +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor @@ -20,14 +21,12 @@ class DocumentRagRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("document-rag") + self.response_translator = TranslatorRegistry.get_response_translator("document-rag") + def to_request(self, body): - return DocumentRagQuery( - query=body["query"], - user=body.get("user", "trustgraph"), - collection=body.get("collection", "default"), - doc_limit=int(body.get("doc-limit", 20)), - ) + return self.request_translator.to_pulsar(body) def from_response(self, message): - return { "response": message.response }, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py b/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py index 4549942e..47146e57 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/embeddings.py @@ -1,5 +1,6 @@ from ... schema import EmbeddingsRequest, EmbeddingsResponse +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor @@ -20,11 +21,12 @@ class EmbeddingsRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("embeddings") + self.response_translator = TranslatorRegistry.get_response_translator("embeddings") + def to_request(self, body): - return EmbeddingsRequest( - text=body["text"] - ) + return self.request_translator.to_pulsar(body) def from_response(self, message): - return { "vectors": message.vectors }, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/flow.py b/trustgraph-flow/trustgraph/gateway/dispatch/flow.py index 0b38e9be..30f8d45e 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/flow.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/flow.py @@ -2,6 +2,7 @@ from ... schema import FlowRequest, FlowResponse from ... schema import flow_request_queue from ... schema import flow_response_queue +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor @@ -19,34 +20,12 @@ class FlowRequestor(ServiceRequestor): timeout=timeout, ) - def to_request(self, body): + self.request_translator = TranslatorRegistry.get_request_translator("flow") + self.response_translator = TranslatorRegistry.get_response_translator("flow") - return FlowRequest( - operation = body.get("operation", None), - class_name = body.get("class-name", None), - class_definition = body.get("class-definition", None), - description = body.get("description", None), - flow_id = body.get("flow-id", None), - ) + def to_request(self, body): + return self.request_translator.to_pulsar(body) def from_response(self, message): - - response = { } - - if message.class_names is not None: - response["class-names"] = message.class_names - - if message.flow_ids is not None: - response["flow-ids"] = message.flow_ids - - if message.class_definition is not None: - response["class-definition"] = message.class_definition - - if message.flow is not None: - response["flow"] = message.flow - - if message.description is not None: - response["description"] = message.description - - return response, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py index 27ceb702..f5be06fb 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_embeddings_query.py @@ -1,8 +1,8 @@ from ... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor -from . serialize import serialize_value class GraphEmbeddingsQueryRequestor(ServiceRequestor): def __init__( @@ -21,22 +21,12 @@ class GraphEmbeddingsQueryRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("graph-embeddings-query") + self.response_translator = TranslatorRegistry.get_response_translator("graph-embeddings-query") + def to_request(self, body): - - limit = int(body.get("limit", 20)) - - return GraphEmbeddingsRequest( - vectors = body["vectors"], - limit = limit, - user = body.get("user", "trustgraph"), - collection = body.get("collection", "default"), - ) + return self.request_translator.to_pulsar(body) def from_response(self, message): - - return { - "entities": [ - serialize_value(ent) for ent in message.entities - ] - }, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py b/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py index a31795b9..a15a1aee 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/graph_rag.py @@ -1,5 +1,6 @@ from ... schema import GraphRagQuery, GraphRagResponse +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor @@ -20,17 +21,12 @@ class GraphRagRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("graph-rag") + self.response_translator = TranslatorRegistry.get_response_translator("graph-rag") + def to_request(self, body): - return GraphRagQuery( - query=body["query"], - user=body.get("user", "trustgraph"), - collection=body.get("collection", "default"), - entity_limit=int(body.get("entity-limit", 50)), - triple_limit=int(body.get("triple-limit", 30)), - max_subgraph_size=int(body.get("max-subgraph-size", 1000)), - max_path_length=int(body.get("max-path-length", 2)), - ) + return self.request_translator.to_pulsar(body) def from_response(self, message): - return { "response": message.response }, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py b/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py index a35ee4f0..950b3430 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/knowledge.py @@ -5,11 +5,9 @@ from ... schema import KnowledgeRequest, KnowledgeResponse, Triples from ... schema import GraphEmbeddings, Metadata, EntityEmbeddings from ... schema import knowledge_request_queue from ... schema import knowledge_response_queue +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor -from . serialize import serialize_graph_embeddings -from . serialize import serialize_triples, to_subgraph, to_value -from . serialize import to_document_metadata, to_processing_metadata class KnowledgeRequestor(ServiceRequestor): def __init__(self, pulsar_client, consumer, subscriber, timeout=120): @@ -25,73 +23,12 @@ class KnowledgeRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("knowledge") + self.response_translator = TranslatorRegistry.get_response_translator("knowledge") + def to_request(self, body): - - if "triples" in body: - triples = Triples( - metadata=Metadata( - id = body["triples"]["metadata"]["id"], - metadata = to_subgraph(body["triples"]["metadata"]["metadata"]), - user = body["triples"]["metadata"]["user"], - ), - triples = to_subgraph(body["triples"]["triples"]), - ) - else: - triples = None - - if "graph-embeddings" in body: - ge = GraphEmbeddings( - metadata = Metadata( - id = body["graph-embeddings"]["metadata"]["id"], - metadata = to_subgraph(body["graph-embeddings"]["metadata"]["metadata"]), - user = body["graph-embeddings"]["metadata"]["user"], - ), - entities=[ - EntityEmbeddings( - entity = to_value(ent["entity"]), - vectors = ent["vectors"], - ) - for ent in body["graph-embeddings"]["entities"] - ] - ) - else: - ge = None - - return KnowledgeRequest( - operation = body.get("operation", None), - user = body.get("user", None), - id = body.get("id", None), - flow = body.get("flow", None), - collection = body.get("collection", None), - triples = triples, - graph_embeddings = ge, - ) + return self.request_translator.to_pulsar(body) def from_response(self, message): - - # Response to list, - if message.ids is not None: - return { - "ids": message.ids - }, True - - if message.triples: - return { - "triples": serialize_triples(message.triples) - }, False - - if message.graph_embeddings: - return { - "graph-embeddings": serialize_graph_embeddings( - message.graph_embeddings - ) - }, False - - if message.eos is True: - return { - "eos": True - }, True - - # Empty case, return from successful delete. - return {}, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py b/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py index 364ba1c2..2155aa5d 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/librarian.py @@ -4,12 +4,9 @@ import base64 from ... schema import LibrarianRequest, LibrarianResponse from ... schema import librarian_request_queue from ... schema import librarian_response_queue +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor -from . serialize import serialize_document_metadata -from . serialize import serialize_processing_metadata -from . serialize import to_document_metadata, to_processing_metadata -from . serialize import to_criteria class LibrarianRequestor(ServiceRequestor): def __init__(self, pulsar_client, consumer, subscriber, timeout=120): @@ -25,67 +22,20 @@ class LibrarianRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("librarian") + self.response_translator = TranslatorRegistry.get_response_translator("librarian") + def to_request(self, body): - - # Content gets base64 decoded & encoded again. It at least makes - # sure payload is valid base64. - - if "document-metadata" in body: - dm = to_document_metadata(body["document-metadata"]) - else: - dm = None - - if "processing-metadata" in body: - pm = to_processing_metadata(body["processing-metadata"]) - else: - pm = None - - if "criteria" in body: - criteria = to_criteria(body["criteria"]) - else: - criteria = None - + # Handle base64 content processing if "content" in body: + # Content gets base64 decoded & encoded again to ensure valid base64 content = base64.b64decode(body["content"].encode("utf-8")) content = base64.b64encode(content).decode("utf-8") - else: - content = None - - return LibrarianRequest( - operation = body.get("operation", None), - document_id = body.get("document-id", None), - processing_id = body.get("processing-id", None), - document_metadata = dm, - processing_metadata = pm, - content = content, - user = body.get("user", None), - collection = body.get("collection", None), - criteria = criteria, - ) + body = body.copy() + body["content"] = content + + return self.request_translator.to_pulsar(body) def from_response(self, message): - - response = {} - - if message.document_metadata: - response["document-metadata"] = serialize_document_metadata( - message.document_metadata - ) - - if message.content: - response["content"] = message.content.decode("utf-8") - - if message.document_metadatas != None: - response["document-metadatas"] = [ - serialize_document_metadata(v) - for v in message.document_metadatas - ] - - if message.processing_metadatas != None: - response["processing-metadatas"] = [ - serialize_processing_metadata(v) - for v in message.processing_metadatas - ] - - return response, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py index 8223461a..0b5b26f1 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/manager.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/manager.py @@ -1,5 +1,6 @@ import asyncio +from aiohttp import web import uuid from . config import ConfigRequestor @@ -30,6 +31,9 @@ from . graph_embeddings_import import GraphEmbeddingsImport from . document_embeddings_import import DocumentEmbeddingsImport from . entity_contexts_import import EntityContextsImport +from . core_export import CoreExport +from . core_import import CoreImport + from . mux import Mux request_response_dispatchers = { @@ -77,10 +81,11 @@ class DispatcherWrapper: class DispatcherManager: - def __init__(self, pulsar_client, config_receiver): + def __init__(self, pulsar_client, config_receiver, prefix="api-gateway"): self.pulsar_client = pulsar_client self.config_receiver = config_receiver self.config_receiver.add_handler(self) + self.prefix = prefix self.flows = {} self.dispatchers = {} @@ -98,6 +103,22 @@ class DispatcherManager: def dispatch_global_service(self): return DispatcherWrapper(self.process_global_service) + def dispatch_core_export(self): + return DispatcherWrapper(self.process_core_export) + + def dispatch_core_import(self): + return DispatcherWrapper(self.process_core_import) + + async def process_core_import(self, data, error, ok, request): + + ci = CoreImport(self.pulsar_client) + return await ci.process(data, error, ok, request) + + async def process_core_export(self, data, error, ok, request): + + ce = CoreExport(self.pulsar_client) + return await ce.process(data, error, ok, request) + async def process_global_service(self, data, responder, params): kind = params.get("kind") @@ -113,8 +134,8 @@ class DispatcherManager: dispatcher = global_dispatchers[kind]( pulsar_client = self.pulsar_client, timeout = 120, - consumer = f"api-gateway-{kind}-request", - subscriber = f"api-gateway-{kind}-request", + consumer = f"{self.prefix}-{kind}-request", + subscriber = f"{self.prefix}-{kind}-request", ) await dispatcher.start() @@ -206,8 +227,8 @@ class DispatcherManager: ws = ws, running = running, queue = qconfig, - consumer = f"api-gateway-{id}", - subscriber = f"api-gateway-{id}", + consumer = f"{self.prefix}-{id}", + subscriber = f"{self.prefix}-{id}", ) return dispatcher @@ -248,8 +269,8 @@ class DispatcherManager: request_queue = qconfig["request"], response_queue = qconfig["response"], timeout = 120, - consumer = f"api-gateway-{flow}-{kind}-request", - subscriber = f"api-gateway-{flow}-{kind}-request", + consumer = f"{self.prefix}-{flow}-{kind}-request", + subscriber = f"{self.prefix}-{flow}-{kind}-request", ) elif kind in sender_dispatchers: dispatcher = sender_dispatchers[kind]( diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py b/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py index 496d01e5..5c316cf6 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/prompt.py @@ -2,6 +2,7 @@ import json from ... schema import PromptRequest, PromptResponse +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor @@ -22,22 +23,12 @@ class PromptRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("prompt") + self.response_translator = TranslatorRegistry.get_response_translator("prompt") + def to_request(self, body): - return PromptRequest( - id=body["id"], - terms={ - k: json.dumps(v) - for k, v in body["variables"].items() - } - ) + return self.request_translator.to_pulsar(body) def from_response(self, message): - if message.object: - return { - "object": message.object - }, True - else: - return { - "text": message.text - }, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py index bde3553a..653ecfd9 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/serialize.py @@ -3,6 +3,13 @@ import base64 from ... schema import Value, Triple, DocumentMetadata, ProcessingMetadata +# DEPRECATED: These functions have been moved to trustgraph.... messaging.translators +# Use the new messaging translation system instead for consistency and reusability. +# Examples: +# from trustgraph.... messaging.translators.primitives import ValueTranslator +# value_translator = ValueTranslator() +# pulsar_value = value_translator.to_pulsar({"v": "example", "e": True}) + def to_value(x): return Value(value=x["v"], is_uri=x["e"]) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py b/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py index 40ae7616..d29d1918 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/text_completion.py @@ -1,5 +1,6 @@ from ... schema import TextCompletionRequest, TextCompletionResponse +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor @@ -20,12 +21,12 @@ class TextCompletionRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("text-completion") + self.response_translator = TranslatorRegistry.get_response_translator("text-completion") + def to_request(self, body): - return TextCompletionRequest( - system=body["system"], - prompt=body["prompt"] - ) + return self.request_translator.to_pulsar(body) def from_response(self, message): - return { "response": message.response }, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py b/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py index 53ea7452..8f30c8de 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/text_load.py @@ -2,9 +2,9 @@ import base64 from ... schema import TextDocument, Metadata +from ... messaging import TranslatorRegistry from . sender import ServiceSender -from . serialize import to_subgraph class TextLoad(ServiceSender): def __init__(self, pulsar_client, queue): @@ -15,30 +15,9 @@ class TextLoad(ServiceSender): schema = TextDocument, ) + self.translator = TranslatorRegistry.get_request_translator("text-document") + def to_request(self, body): - - if "metadata" in body: - metadata = to_subgraph(body["metadata"]) - else: - metadata = [] - - if "charset" in body: - charset = body["charset"] - else: - charset = "utf-8" - - # Text is base64 encoded - text = base64.b64decode(body["text"]).decode(charset) - print("Text document received") - - return TextDocument( - metadata=Metadata( - id=body.get("id"), - metadata=metadata, - user=body.get("user", "trustgraph"), - collection=body.get("collection", "default"), - ), - text=text, - ) + return self.translator.to_pulsar(body) diff --git a/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py b/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py index 7c3a5fc9..d2def9c1 100644 --- a/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py +++ b/trustgraph-flow/trustgraph/gateway/dispatch/triples_query.py @@ -1,8 +1,8 @@ from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples +from ... messaging import TranslatorRegistry from . requestor import ServiceRequestor -from . serialize import to_value, serialize_subgraph class TriplesQueryRequestor(ServiceRequestor): def __init__( @@ -21,34 +21,12 @@ class TriplesQueryRequestor(ServiceRequestor): timeout=timeout, ) + self.request_translator = TranslatorRegistry.get_request_translator("triples-query") + self.response_translator = TranslatorRegistry.get_response_translator("triples-query") + def to_request(self, body): - - if "s" in body: - s = to_value(body["s"]) - else: - s = None - - if "p" in body: - p = to_value(body["p"]) - else: - p = None - - if "o" in body: - o = to_value(body["o"]) - else: - o = None - - limit = int(body.get("limit", 10000)) - - return TriplesQueryRequest( - s = s, p = p, o = o, - limit = limit, - user = body.get("user", "trustgraph"), - collection = body.get("collection", "default"), - ) + return self.request_translator.to_pulsar(body) def from_response(self, message): - return { - "response": serialize_subgraph(message.triples) - }, True + return self.response_translator.from_response_with_completion(message) diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py index 75a39766..f9616d9a 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/manager.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/manager.py @@ -3,7 +3,7 @@ import asyncio from aiohttp import web -from . constant_endpoint import ConstantEndpoint +from . stream_endpoint import StreamEndpoint from . variable_endpoint import VariableEndpoint from . socket import SocketEndpoint from . metrics import MetricsEndpoint @@ -52,6 +52,18 @@ class EndpointManager: auth = auth, dispatcher = dispatcher_manager.dispatch_flow_export() ), + StreamEndpoint( + endpoint_path = "/api/v1/import-core", + auth = auth, + method = "POST", + dispatcher = dispatcher_manager.dispatch_core_import(), + ), + StreamEndpoint( + endpoint_path = "/api/v1/export-core", + auth = auth, + method = "GET", + dispatcher = dispatcher_manager.dispatch_core_export(), + ), ] def add_routes(self, app): diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py new file mode 100644 index 00000000..649c043e --- /dev/null +++ b/trustgraph-flow/trustgraph/gateway/endpoint/stream_endpoint.py @@ -0,0 +1,82 @@ + +import asyncio +from aiohttp import web +import logging + +logger = logging.getLogger("endpoint") +logger.setLevel(logging.INFO) + +class StreamEndpoint: + + def __init__(self, endpoint_path, auth, dispatcher, method="POST"): + + self.path = endpoint_path + + self.auth = auth + self.operation = "service" + self.method = method + + self.dispatcher = dispatcher + + async def start(self): + pass + + def add_routes(self, app): + + if self.method == "POST": + app.add_routes([ + web.post(self.path, self.handle), + ]) + elif self.method == "GET": + app.add_routes([ + web.get(self.path, self.handle), + ]) + else: + raise RuntimeError("Bad method" + self.method) + + async def handle(self, request): + + print(request.path, "...") + + try: + ht = request.headers["Authorization"] + tokens = ht.split(" ", 2) + if tokens[0] != "Bearer": + return web.HTTPUnauthorized() + token = tokens[1] + except: + token = "" + + if not self.auth.permitted(token, self.operation): + return web.HTTPUnauthorized() + + try: + + data = request.content + + async def error(err): + return web.HTTPInternalServerError(text = err) + + async def ok( + status=200, reason="OK", type="application/octet-stream" + ): + response = web.StreamResponse( + status = status, reason = reason, + headers = {"Content-Type": type} + ) + await response.prepare(request) + return response + + resp = await self.dispatcher.process( + data, error, ok, request + ) + + return resp + + except Exception as e: + logging.error(f"Exception: {e}") + + return web.json_response( + { "error": str(e) } + ) + diff --git a/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py b/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py index 4a131d9d..ae0ae8fb 100644 --- a/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py +++ b/trustgraph-flow/trustgraph/gateway/endpoint/variable_endpoint.py @@ -1,7 +1,6 @@ import asyncio from aiohttp import web -import uuid import logging logger = logging.getLogger("endpoint") diff --git a/trustgraph-flow/trustgraph/gateway/service.py b/trustgraph-flow/trustgraph/gateway/service.py index 97406422..ee66b9d3 100755 --- a/trustgraph-flow/trustgraph/gateway/service.py +++ b/trustgraph-flow/trustgraph/gateway/service.py @@ -73,6 +73,7 @@ class Api: self.dispatcher_manager = DispatcherManager( pulsar_client = self.pulsar_client, config_receiver = self.config_receiver, + prefix = "gateway", ) self.endpoint_manager = EndpointManager( diff --git a/trustgraph-flow/trustgraph/model/prompt/template/service.py b/trustgraph-flow/trustgraph/model/prompt/template/service.py index 67590c1c..7bebf5f4 100755 --- a/trustgraph-flow/trustgraph/model/prompt/template/service.py +++ b/trustgraph-flow/trustgraph/model/prompt/template/service.py @@ -18,12 +18,14 @@ from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec from . prompt_manager import PromptConfiguration, Prompt, PromptManager default_ident = "prompt" +default_concurrency = 1 class Processor(FlowProcessor): def __init__(self, **params): id = params.get("id") + concurrency = params.get("concurrency", 1) # Config key for prompts self.config_key = params.get("config_type", "prompt") @@ -31,6 +33,7 @@ class Processor(FlowProcessor): super(Processor, self).__init__( **params | { "id": id, + "concurrency": concurrency, } ) @@ -38,7 +41,8 @@ class Processor(FlowProcessor): ConsumerSpec( name = "request", schema = PromptRequest, - handler = self.on_request + handler = self.on_request, + concurrency = concurrency, ) ) @@ -219,6 +223,13 @@ class Processor(FlowProcessor): @staticmethod def add_args(parser): + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + FlowProcessor.add_args(parser) parser.add_argument( diff --git a/trustgraph-flow/trustgraph/model/text_completion/vllm/__init__.py b/trustgraph-flow/trustgraph/model/text_completion/vllm/__init__.py new file mode 100644 index 00000000..f2017af8 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/vllm/__init__.py @@ -0,0 +1,3 @@ + +from . llm import * + diff --git a/trustgraph-flow/trustgraph/model/text_completion/vllm/__main__.py b/trustgraph-flow/trustgraph/model/text_completion/vllm/__main__.py new file mode 100755 index 00000000..91342d2d --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/vllm/__main__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 + +from . llm import run + +if __name__ == '__main__': + run() + diff --git a/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py b/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py new file mode 100755 index 00000000..96b232e8 --- /dev/null +++ b/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py @@ -0,0 +1,138 @@ + +""" +Simple LLM service, performs text prompt completion using vLLM +Input is prompt, output is response. +""" + +import os +import aiohttp + +from .... exceptions import TooManyRequests +from .... base import LlmService, LlmResult + +default_ident = "text-completion" + +default_temperature = 0.0 +default_max_output = 2048 +default_base_url = os.getenv("VLLM_BASE_URL") +default_model = "TheBloke/Mistral-7B-v0.1-AWQ" + +if default_base_url == "" or default_base_url is None: + default_base_url = "http://vllm-service:8899/v1" + +class Processor(LlmService): + + def __init__(self, **params): + + base_url = params.get("url", default_base_url) + temperature = params.get("temperature", default_temperature) + max_output = params.get("max_output", default_max_output) + model = params.get("model", default_model) + + super(Processor, self).__init__( + **params | { + "temperature": temperature, + "max_output": max_output, + "url": base_url, + "model": model, + } + ) + + self.base_url = base_url + self.temperature = temperature + self.max_output = max_output + self.model = model + + self.session = aiohttp.ClientSession() + + print("Using vLLM service at", base_url) + + print("Initialised", flush=True) + + async def generate_content(self, system, prompt): + + headers = { + "Content-Type": "application/json", + } + + request = { + "model": self.model, + "prompt": system + "\n\n" + prompt, + "max_tokens": self.max_output, + "temperature": self.temperature, + } + + try: + + url = f"{self.base_url}/completions" + + async with self.session.post( + url, + headers=headers, + json=request, + ) as response: + + if response.status != 200: + raise RuntimeError("Bad status: " + str(response.status)) + + resp = await response.json() + + inputtokens = resp["usage"]["prompt_tokens"] + outputtokens = resp["usage"]["completion_tokens"] + ans = resp["choices"][0]["text"] + print(f"Input Tokens: {inputtokens}", flush=True) + print(f"Output Tokens: {outputtokens}", flush=True) + print(ans, flush=True) + + resp = LlmResult( + text = ans, + in_token = inputtokens, + out_token = outputtokens, + model = self.model, + ) + + return resp + + # FIXME: Assuming vLLM won't produce rate limits? + + except Exception as e: + + # Apart from rate limits, treat all exceptions as unrecoverable + + print(f"Exception: {type(e)} {e}") + raise e + + @staticmethod + def add_args(parser): + + LlmService.add_args(parser) + + parser.add_argument( + '-u', '--url', + default=default_base_url, + help=f'vLLM service base URL (default: {default_base_url})' + ) + + parser.add_argument( + '-m', '--model', + default=default_model, + help=f'LLM model (default: {default_model})' + ) + + parser.add_argument( + '-t', '--temperature', + type=float, + default=default_temperature, + help=f'LLM temperature parameter (default: {default_temperature})' + ) + + parser.add_argument( + '-x', '--max-output', + type=int, + default=default_max_output, + help=f'LLM max output tokens (default: {default_max_output})' + ) + +def run(): + + Processor.launch(default_ident, __doc__) diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 5d3cc2f4..328ae3f9 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -11,12 +11,14 @@ from ... base import PromptClientSpec, EmbeddingsClientSpec from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec default_ident = "graph-rag" +default_concurrency = 1 class Processor(FlowProcessor): def __init__(self, **params): id = params.get("id", default_ident) + concurrency = params.get("concurrency", 1) entity_limit = params.get("entity_limit", 50) triple_limit = params.get("triple_limit", 30) @@ -26,6 +28,7 @@ class Processor(FlowProcessor): super(Processor, self).__init__( **params | { "id": id, + "concurrency": concurrency, "entity_limit": entity_limit, "triple_limit": triple_limit, "max_subgraph_size": max_subgraph_size, @@ -43,6 +46,7 @@ class Processor(FlowProcessor): name = "request", schema = GraphRagQuery, handler = self.on_request, + concurrency = concurrency, ) ) @@ -157,6 +161,13 @@ class Processor(FlowProcessor): @staticmethod def add_args(parser): + parser.add_argument( + '-c', '--concurrency', + type=int, + default=default_concurrency, + help=f'Concurrent processing threads (default: {default_concurrency})' + ) + FlowProcessor.add_args(parser) parser.add_argument( diff --git a/trustgraph-flow/trustgraph/rev_gateway/__init__.py b/trustgraph-flow/trustgraph/rev_gateway/__init__.py new file mode 100644 index 00000000..1be89162 --- /dev/null +++ b/trustgraph-flow/trustgraph/rev_gateway/__init__.py @@ -0,0 +1 @@ +from . service import run diff --git a/trustgraph-flow/trustgraph/rev_gateway/__main__.py b/trustgraph-flow/trustgraph/rev_gateway/__main__.py new file mode 100644 index 00000000..70262bc8 --- /dev/null +++ b/trustgraph-flow/trustgraph/rev_gateway/__main__.py @@ -0,0 +1,11 @@ +import logging +from .service import run + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +if __name__ == "__main__": + run() + diff --git a/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py new file mode 100644 index 00000000..03e79c0d --- /dev/null +++ b/trustgraph-flow/trustgraph/rev_gateway/dispatcher.py @@ -0,0 +1,130 @@ +import asyncio +import logging +import uuid +from typing import Dict, Any, Optional +from trustgraph.messaging import TranslatorRegistry +from ..gateway.dispatch.manager import DispatcherManager + +logger = logging.getLogger("dispatcher") +logger.setLevel(logging.INFO) + +class WebSocketResponder: + """Simple responder that captures response for websocket return""" + def __init__(self): + self.response = None + self.completed = False + + async def send(self, data): + """Capture the response data""" + self.response = data + self.completed = True + + async def __call__(self, data, final=False): + """Make the responder callable for compatibility with requestor""" + await self.send(data) + if final: + self.completed = True + +class MessageDispatcher: + + def __init__(self, max_workers: int = 10, config_receiver=None, pulsar_client=None): + self.max_workers = max_workers + self.semaphore = asyncio.Semaphore(max_workers) + self.active_tasks = set() + self.pulsar_client = pulsar_client + + # Use DispatcherManager for flow and service management + if pulsar_client and config_receiver: + self.dispatcher_manager = DispatcherManager(pulsar_client, config_receiver, prefix="rev-gateway") + else: + self.dispatcher_manager = None + logger.warning("No pulsar_client or config_receiver provided - using fallback mode") + + # Service name mapping from websocket protocol to translator registry + self.service_mapping = { + "text-completion": "text-completion", + "graph-rag": "graph-rag", + "agent": "agent", + "embeddings": "embeddings", + "graph-embeddings": "graph-embeddings", + "triples": "triples", + "document-load": "document", + "text-load": "text-document", + "flow": "flow", + "knowledge": "knowledge", + "config": "config", + "librarian": "librarian", + "document-rag": "document-rag" + } + + async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]: + async with self.semaphore: + task = asyncio.create_task(self._process_message(message)) + self.active_tasks.add(task) + + try: + result = await task + return result + finally: + self.active_tasks.discard(task) + + async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]: + request_id = message.get('id', str(uuid.uuid4())) + service = message.get('service') + request_data = message.get('request', {}) + flow_id = message.get('flow', 'default') # Default flow + + logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}") + + try: + if not self.dispatcher_manager: + raise RuntimeError("DispatcherManager not available - pulsar_client and config_receiver required") + + # Use DispatcherManager for flow-based processing + responder = WebSocketResponder() + + # Map websocket service name to dispatcher service name + dispatcher_service = self.service_mapping.get(service, service) + + # Check if this is a global service or flow service + from ..gateway.dispatch.manager import global_dispatchers + if dispatcher_service in global_dispatchers: + # Use global service dispatcher + await self.dispatcher_manager.invoke_global_service( + request_data, responder, dispatcher_service + ) + else: + # Use DispatcherManager to process the request through Pulsar queues + await self.dispatcher_manager.invoke_flow_service( + request_data, responder, flow_id, dispatcher_service + ) + + # Get the response from the responder + if responder.completed: + response_data = responder.response + else: + response_data = {'error': 'No response received'} + + response = { + 'id': request_id, + 'response': response_data + } + + except Exception as e: + logger.error(f"Error processing message {request_id}: {e}") + response = { + 'id': request_id, + 'response': {'error': str(e)} + } + + logger.info(f"Completed processing message {request_id}") + return response + + + async def shutdown(self): + if self.active_tasks: + logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete") + await asyncio.gather(*self.active_tasks, return_exceptions=True) + + # DispatcherManager handles its own cleanup + logger.info("Dispatcher shutdown complete") diff --git a/trustgraph-flow/trustgraph/rev_gateway/service.py b/trustgraph-flow/trustgraph/rev_gateway/service.py new file mode 100644 index 00000000..8d82f407 --- /dev/null +++ b/trustgraph-flow/trustgraph/rev_gateway/service.py @@ -0,0 +1,242 @@ +import asyncio +import argparse +import logging +import json +import sys +import os +from aiohttp import ClientSession, WSMsgType, ClientWebSocketResponse +from typing import Optional +from urllib.parse import urlparse, urlunparse +import pulsar + +from .dispatcher import MessageDispatcher +from ..gateway.config.receiver import ConfigReceiver + +logger = logging.getLogger("rev_gateway") +logger.setLevel(logging.INFO) + +default_websocket = "ws://localhost:7650/out" + +class ReverseGateway: + + def __init__(self, websocket_uri: str = None, max_workers: int = 10, + pulsar_host: str = None, pulsar_api_key: str = None, + pulsar_listener: str = None): + # Set default WebSocket URI with environment variable support + if websocket_uri is None: + websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket) + + # Parse and validate the WebSocket URI + parsed_uri = urlparse(websocket_uri) + if parsed_uri.scheme not in ('ws', 'wss'): + raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}") + if not parsed_uri.netloc: + raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}") + + # Store parsed components for debugging/logging + self.websocket_uri = websocket_uri + self.host = parsed_uri.hostname + self.port = parsed_uri.port + self.scheme = parsed_uri.scheme + self.path = parsed_uri.path or "/ws" + + # Construct the full URL (in case path was missing) + if not parsed_uri.path: + self.url = f"{self.scheme}://{parsed_uri.netloc}/ws" + else: + self.url = websocket_uri + + self.max_workers = max_workers + self.ws: Optional[ClientWebSocketResponse] = None + self.session: Optional[ClientSession] = None + self.running = False + self.reconnect_delay = 3.0 + + # Pulsar configuration + self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650") + self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None) + self.pulsar_listener = pulsar_listener + + # Initialize Pulsar client + if self.pulsar_api_key: + self.pulsar_client = pulsar.Client( + self.pulsar_host, + listener_name=self.pulsar_listener, + authentication=pulsar.AuthenticationToken(self.pulsar_api_key) + ) + else: + self.pulsar_client = pulsar.Client( + self.pulsar_host, + listener_name=self.pulsar_listener + ) + + # Initialize config receiver + self.config_receiver = ConfigReceiver(self.pulsar_client) + + # Initialize dispatcher with config_receiver and pulsar_client - must be created after config_receiver + self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.pulsar_client) + + async def connect(self) -> bool: + try: + if self.session is None: + self.session = ClientSession() + + logger.info(f"Connecting to {self.url}") + self.ws = await self.session.ws_connect(self.url) + logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}") + return True + + except Exception as e: + logger.error(f"Failed to connect to {self.url}: {e}") + return False + + async def disconnect(self): + if self.ws and not self.ws.closed: + await self.ws.close() + if self.session and not self.session.closed: + await self.session.close() + self.ws = None + self.session = None + + async def send_message(self, message: dict): + if self.ws and not self.ws.closed: + try: + await self.ws.send_str(json.dumps(message)) + except Exception as e: + logger.error(f"Failed to send message: {e}") + + async def handle_message(self, message: str): + try: + print(f"Received: {message}", flush=True) + + msg_data = json.loads(message) + response = await self.dispatcher.handle_message(msg_data) + + if response: + await self.send_message(response) + + except Exception as e: + logger.error(f"Error handling message: {e}") + + async def listen(self): + while self.running and self.ws and not self.ws.closed: + try: + msg = await self.ws.receive() + + if msg.type == WSMsgType.TEXT: + await self.handle_message(msg.data) + elif msg.type == WSMsgType.BINARY: + await self.handle_message(msg.data.decode('utf-8')) + elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR): + logger.warning("WebSocket closed or error occurred") + break + + except Exception as e: + logger.error(f"Error in listen loop: {e}") + break + + async def run(self): + self.running = True + logger.info("Starting reverse gateway") + + # Start config receiver + logger.info("Starting config receiver") + await self.config_receiver.start() + + while self.running: + try: + if await self.connect(): + await self.listen() + else: + logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds") + + await self.disconnect() + + if self.running: + await asyncio.sleep(self.reconnect_delay) + + except KeyboardInterrupt: + logger.info("Shutdown requested") + break + except Exception as e: + logger.error(f"Unexpected error: {e}") + if self.running: + await asyncio.sleep(self.reconnect_delay) + + await self.shutdown() + + async def shutdown(self): + logger.info("Shutting down reverse gateway") + self.running = False + await self.dispatcher.shutdown() + await self.disconnect() + + # Close Pulsar client + if hasattr(self, 'pulsar_client'): + self.pulsar_client.close() + + def stop(self): + self.running = False + +def parse_args(): + parser = argparse.ArgumentParser( + prog="reverse-gateway", + description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge" + ) + + parser.add_argument( + '--websocket-uri', + default=None, + help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)' + ) + + parser.add_argument( + '--max-workers', + type=int, + default=10, + help='Maximum concurrent message handlers (default: 10)' + ) + + parser.add_argument( + '-p', '--pulsar-host', + default=None, + help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)' + ) + + parser.add_argument( + '--pulsar-api-key', + default=None, + help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)' + ) + + parser.add_argument( + '--pulsar-listener', + default=None, + help='Pulsar listener name' + ) + + return parser.parse_args() + +def run(): + args = parse_args() + + gateway = ReverseGateway( + websocket_uri=args.websocket_uri, + max_workers=args.max_workers, + pulsar_host=args.pulsar_host, + pulsar_api_key=args.pulsar_api_key, + pulsar_listener=args.pulsar_listener + ) + + print(f"Starting reverse gateway:") + print(f" WebSocket URI: {gateway.url}") + print(f" Max workers: {args.max_workers}") + print(f" Pulsar host: {gateway.pulsar_host}") + + try: + asyncio.run(gateway.run()) + except KeyboardInterrupt: + print("\nShutdown requested by user") + except Exception as e: + print(f"Fatal error: {e}") + sys.exit(1)