Merge branch 'release/v1.0'

This commit is contained in:
Cyber MacGeddon 2025-06-28 12:02:20 +01:00
commit dbe78ebe46
66 changed files with 2706 additions and 528 deletions

View file

@ -9,6 +9,7 @@ FROM docker.io/fedora:42 AS base
ENV PIP_BREAK_SYSTEM_PACKAGES=1
RUN dnf install -y python3.12 && \
dnf install -y tesseract poppler-utils && \
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
python -m ensurepip --upgrade && \
pip3 install --no-cache-dir wheel aiohttp && \

View file

@ -13,7 +13,6 @@ from prometheus_client import start_http_server, Info
from .. schema import ConfigPush, config_push_queue
from .. log_level import LogLevel
from .. exceptions import TooManyRequests
from . pubsub import PulsarClient
from . producer import Producer
from . consumer import Consumer

View file

@ -1,4 +1,14 @@
# Consumer is similar to subscriber: It takes information from a queue
# and passes on to a processor function. This is the main receiving
# loop for TrustGraph processors. Incorporates retry functionality
# Note: there is a 'defect' in the system which is tolerated, although
# the processing handlers are async functions, ideally implementation
# would use all async code. In practice if the processor only implements
# one handler, and a single thread of concurrency, nothing too outrageous
# will happen if synchronous / blocking code is used
from pulsar.schema import JsonSchema
import pulsar
import _pulsar
@ -16,6 +26,7 @@ class Consumer:
start_of_messages=False,
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
reconnect_time = 5,
concurrency = 1, # Number of concurrent requests to handle
):
self.taskgroup = taskgroup
@ -34,7 +45,9 @@ class Consumer:
self.start_of_messages = start_of_messages
self.running = True
self.task = None
self.consumer_task = None
self.concurrency = concurrency
self.metrics = metrics
@ -52,7 +65,11 @@ class Consumer:
async def stop(self):
self.running = False
await self.task
if self.consumer_task:
await self.consumer_task
self.consumer_task = None
async def start(self):
@ -62,9 +79,9 @@ class Consumer:
if self.metrics:
self.metrics.state("stopped")
self.task = self.taskgroup.create_task(self.run())
self.consumer_task = self.taskgroup.create_task(self.consumer_run())
async def run(self):
async def consumer_run(self):
while self.running:
@ -102,7 +119,19 @@ class Consumer:
try:
await self.consume()
print(
"Starting", self.concurrency, "receiver threads",
flush=True
)
async with asyncio.TaskGroup() as tg:
tasks = []
for i in range(0, self.concurrency):
tasks.append(
tg.create_task(self.consume_from_queue())
)
if self.metrics:
self.metrics.state("stopped")
@ -120,7 +149,7 @@ class Consumer:
self.consumer.unsubscribe()
self.consumer.close()
async def consume(self):
async def consume_from_queue(self):
while self.running:
@ -134,71 +163,75 @@ class Consumer:
except Exception as e:
raise e
expiry = time.time() + self.rate_limit_timeout
await self.handle_one_from_queue(msg)
# This loop is for retry on rate-limit / resource limits
while self.running:
async def handle_one_from_queue(self, msg):
if time.time() > expiry:
expiry = time.time() + self.rate_limit_timeout
print("Gave up waiting for rate-limit retry", flush=True)
# This loop is for retry on rate-limit / resource limits
while self.running:
# Message failed to be processed, this causes it to
# be retried
self.consumer.negative_acknowledge(msg)
if time.time() > expiry:
if self.metrics:
self.metrics.process("error")
print("Gave up waiting for rate-limit retry", flush=True)
# Break out of retry loop, processes next message
break
# Message failed to be processed, this causes it to
# be retried
self.consumer.negative_acknowledge(msg)
try:
if self.metrics:
self.metrics.process("error")
print("Handle...", flush=True)
# Break out of retry loop, processes next message
break
if self.metrics:
try:
with self.metrics.record_time():
await self.handler(msg, self, self.flow)
print("Handle...", flush=True)
else:
if self.metrics:
with self.metrics.record_time():
await self.handler(msg, self, self.flow)
print("Handled.", flush=True)
else:
await self.handler(msg, self, self.flow)
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
print("Handled.", flush=True)
if self.metrics:
self.metrics.process("success")
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
# Break out of retry loop
break
if self.metrics:
self.metrics.process("success")
except TooManyRequests:
# Break out of retry loop
break
print("TooManyRequests: will retry...", flush=True)
except TooManyRequests:
if self.metrics:
self.metrics.rate_limit()
print("TooManyRequests: will retry...", flush=True)
# Sleep
await asyncio.sleep(self.rate_limit_retry_time)
if self.metrics:
self.metrics.rate_limit()
# Continue from retry loop, just causes a reprocessing
continue
# Sleep
await asyncio.sleep(self.rate_limit_retry_time)
except Exception as e:
# Continue from retry loop, just causes a reprocessing
continue
print("consume exception:", e, flush=True)
except Exception as e:
# Message failed to be processed, this causes it to
# be retried
self.consumer.negative_acknowledge(msg)
print("consume exception:", e, flush=True)
if self.metrics:
self.metrics.process("error")
# Message failed to be processed, this causes it to
# be retried
self.consumer.negative_acknowledge(msg)
# Break out of retry loop, processes next message
break
if self.metrics:
self.metrics.process("error")
# Break out of retry loop, processes next message
break

View file

@ -4,10 +4,11 @@ from . consumer import Consumer
from . spec import Spec
class ConsumerSpec(Spec):
def __init__(self, name, schema, handler):
def __init__(self, name, schema, handler, concurrency = 1):
self.name = name
self.schema = schema
self.handler = handler
self.concurrency = concurrency
def add(self, flow, processor, definition):
@ -24,6 +25,7 @@ class ConsumerSpec(Spec):
schema = self.schema,
handler = self.handler,
metrics = consumer_metrics,
concurrency = self.concurrency
)
# Consumer handle gets access to producers and other

View file

@ -11,20 +11,26 @@ from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
default_ident = "embeddings"
default_concurrency = 1
class EmbeddingsService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
concurrency = params.get("concurrency", 1)
super(EmbeddingsService, self).__init__(**params | { "id": id })
super(EmbeddingsService, self).__init__(**params | {
"id": id,
"concurrency": concurrency,
})
self.register_specification(
ConsumerSpec(
name = "request",
schema = EmbeddingsRequest,
handler = self.on_request
handler = self.on_request,
concurrency = concurrency,
)
)
@ -84,6 +90,13 @@ class EmbeddingsService(FlowProcessor):
@staticmethod
def add_args(parser):
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)

View file

@ -11,9 +11,13 @@ from .. exceptions import TooManyRequests
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
default_ident = "text-completion"
default_concurrency = 1
class LlmResult:
def __init__(self, text=None, in_token=None, out_token=None, model=None):
def __init__(
self, text = None, in_token = None, out_token = None,
model = None,
):
self.text = text
self.in_token = in_token
self.out_token = out_token
@ -25,14 +29,19 @@ class LlmService(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
concurrency = params.get("concurrency", 1)
super(LlmService, self).__init__(**params | { "id": id })
super(LlmService, self).__init__(**params | {
"id": id,
"concurrency": concurrency,
})
self.register_specification(
ConsumerSpec(
name = "request",
schema = TextCompletionRequest,
handler = self.on_request
handler = self.on_request,
concurrency = concurrency,
)
)
@ -115,5 +124,12 @@ class LlmService(FlowProcessor):
@staticmethod
def add_args(parser):
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)

View file

@ -1,4 +1,8 @@
# Subscriber is similar to consumer: It provides a service to take stuff
# off of a queue and make it available using an internal broker system,
# so suitable for when multiple recipients are reading from the same queue
from pulsar.schema import JsonSchema
import asyncio
import _pulsar

View file

@ -5,6 +5,7 @@ Triples store base class
from .. schema import Triples
from .. base import FlowProcessor, ConsumerSpec
from .. exceptions import TooManyRequests
default_ident = "triples-write"

View file

@ -0,0 +1,105 @@
from .registry import TranslatorRegistry
from .translators import *
# Auto-register all translators
from .translators.agent import AgentRequestTranslator, AgentResponseTranslator
from .translators.embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator
from .translators.text_completion import TextCompletionRequestTranslator, TextCompletionResponseTranslator
from .translators.retrieval import (
DocumentRagRequestTranslator, DocumentRagResponseTranslator,
GraphRagRequestTranslator, GraphRagResponseTranslator
)
from .translators.triples import TriplesQueryRequestTranslator, TriplesQueryResponseTranslator
from .translators.knowledge import KnowledgeRequestTranslator, KnowledgeResponseTranslator
from .translators.library import LibraryRequestTranslator, LibraryResponseTranslator
from .translators.document_loading import DocumentTranslator, TextDocumentTranslator
from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator
from .translators.flow import FlowRequestTranslator, FlowResponseTranslator
from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator
from .translators.embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
)
# Register all service translators
TranslatorRegistry.register_service(
"agent",
AgentRequestTranslator(),
AgentResponseTranslator()
)
TranslatorRegistry.register_service(
"embeddings",
EmbeddingsRequestTranslator(),
EmbeddingsResponseTranslator()
)
TranslatorRegistry.register_service(
"text-completion",
TextCompletionRequestTranslator(),
TextCompletionResponseTranslator()
)
TranslatorRegistry.register_service(
"document-rag",
DocumentRagRequestTranslator(),
DocumentRagResponseTranslator()
)
TranslatorRegistry.register_service(
"graph-rag",
GraphRagRequestTranslator(),
GraphRagResponseTranslator()
)
TranslatorRegistry.register_service(
"triples-query",
TriplesQueryRequestTranslator(),
TriplesQueryResponseTranslator()
)
TranslatorRegistry.register_service(
"knowledge",
KnowledgeRequestTranslator(),
KnowledgeResponseTranslator()
)
TranslatorRegistry.register_service(
"librarian",
LibraryRequestTranslator(),
LibraryResponseTranslator()
)
TranslatorRegistry.register_service(
"config",
ConfigRequestTranslator(),
ConfigResponseTranslator()
)
TranslatorRegistry.register_service(
"flow",
FlowRequestTranslator(),
FlowResponseTranslator()
)
TranslatorRegistry.register_service(
"prompt",
PromptRequestTranslator(),
PromptResponseTranslator()
)
TranslatorRegistry.register_service(
"document-embeddings-query",
DocumentEmbeddingsRequestTranslator(),
DocumentEmbeddingsResponseTranslator()
)
TranslatorRegistry.register_service(
"graph-embeddings-query",
GraphEmbeddingsRequestTranslator(),
GraphEmbeddingsResponseTranslator()
)
# Register single-direction translators for document loading
TranslatorRegistry.register_request("document", DocumentTranslator())
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())

View file

@ -0,0 +1,51 @@
from typing import Dict, List, Union
from .translators.base import MessageTranslator
class TranslatorRegistry:
    """Class-level registry mapping service names to message translators.

    Request- and response-direction translators are kept in separate maps
    so a service may register only one direction (e.g. fire-and-forget
    document loading registers requests only).
    """

    _request_translators: Dict[str, MessageTranslator] = {}
    _response_translators: Dict[str, MessageTranslator] = {}

    @classmethod
    def register_request(cls, service_name: str, translator: MessageTranslator):
        """Register the request-direction translator for *service_name*."""
        cls._request_translators[service_name] = translator

    @classmethod
    def register_response(cls, service_name: str, translator: MessageTranslator):
        """Register the response-direction translator for *service_name*."""
        cls._response_translators[service_name] = translator

    @classmethod
    def register_service(cls, service_name: str, request_translator: MessageTranslator,
                         response_translator: MessageTranslator):
        """Register both directions for a service in one call."""
        cls.register_request(service_name, request_translator)
        cls.register_response(service_name, response_translator)

    @classmethod
    def get_request_translator(cls, service_name: str) -> MessageTranslator:
        """Return the request translator; KeyError if none is registered."""
        try:
            return cls._request_translators[service_name]
        except KeyError:
            raise KeyError(
                f"No request translator registered for service: {service_name}"
            ) from None

    @classmethod
    def get_response_translator(cls, service_name: str) -> MessageTranslator:
        """Return the response translator; KeyError if none is registered."""
        try:
            return cls._response_translators[service_name]
        except KeyError:
            raise KeyError(
                f"No response translator registered for service: {service_name}"
            ) from None

    @classmethod
    def list_services(cls) -> List[str]:
        """Return the sorted names of all services known in either direction."""
        names = set(cls._request_translators)
        names.update(cls._response_translators)
        return sorted(names)

    @classmethod
    def has_service(cls, service_name: str) -> bool:
        """True if the service is registered in at least one direction."""
        if service_name in cls._request_translators:
            return True
        return service_name in cls._response_translators

View file

@ -0,0 +1,19 @@
from .base import Translator, MessageTranslator
from .primitives import ValueTranslator, TripleTranslator, SubgraphTranslator
from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator
from .agent import AgentRequestTranslator, AgentResponseTranslator
from .embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator
from .text_completion import TextCompletionRequestTranslator, TextCompletionResponseTranslator
from .retrieval import DocumentRagRequestTranslator, DocumentRagResponseTranslator
from .retrieval import GraphRagRequestTranslator, GraphRagResponseTranslator
from .triples import TriplesQueryRequestTranslator, TriplesQueryResponseTranslator
from .knowledge import KnowledgeRequestTranslator, KnowledgeResponseTranslator
from .library import LibraryRequestTranslator, LibraryResponseTranslator
from .document_loading import DocumentTranslator, TextDocumentTranslator, ChunkTranslator, DocumentEmbeddingsTranslator
from .config import ConfigRequestTranslator, ConfigResponseTranslator
from .flow import FlowRequestTranslator, FlowResponseTranslator
from .prompt import PromptRequestTranslator, PromptResponseTranslator
from .embeddings_query import (
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
)

View file

@ -0,0 +1,44 @@
from typing import Dict, Any, Tuple
from ...schema import AgentRequest, AgentResponse
from .base import MessageTranslator
class AgentRequestTranslator(MessageTranslator):
    """Converts agent request dicts to/from AgentRequest Pulsar records."""

    def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest:
        """Build an AgentRequest; "question" is required, the rest default
        to empty plan/state/history."""
        plan = data.get("plan", "")
        state = data.get("state", "")
        history = data.get("history", [])
        return AgentRequest(
            question=data["question"],
            plan=plan,
            state=state,
            history=history,
        )

    def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
        """Flatten an AgentRequest record into a plain dict."""
        return {
            name: getattr(obj, name)
            for name in ("question", "plan", "state", "history")
        }
class AgentResponseTranslator(MessageTranslator):
    """Translator for AgentResponse schema objects.

    Only the Pulsar -> dict direction is implemented; agent responses are
    produced by the service, not submitted by clients.
    """

    def to_pulsar(self, data: Dict[str, Any]) -> AgentResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: AgentResponse) -> Dict[str, Any]:
        """Return only the fields present on the response.

        Fix: test `is not None` rather than truthiness, so an empty-string
        answer — which from_response_with_completion already treats as the
        final answer — is not silently dropped from the payload.
        """
        result = {}
        if obj.answer is not None:
            result["answer"] = obj.answer
        if obj.thought is not None:
            result["thought"] = obj.thought
        if obj.observation is not None:
            result["observation"] = obj.observation
        return result

    def from_response_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); final iff an answer is present."""
        return self.from_pulsar(obj), (obj.answer is not None)

View file

@ -0,0 +1,43 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple
from pulsar.schema import Record
class Translator(ABC):
    """Base class for bidirectional Pulsar ↔ dict translation.

    Subclasses implement both directions; translators that only support
    one direction raise NotImplementedError for the other.
    """

    @abstractmethod
    def to_pulsar(self, data: Dict[str, Any]) -> Record:
        """Convert dict to Pulsar schema object"""
        pass

    @abstractmethod
    def from_pulsar(self, obj: Record) -> Dict[str, Any]:
        """Convert Pulsar schema object to dict"""
        pass
class MessageTranslator(Translator):
    """For complete request/response message translation"""

    def from_response_with_completion(self, obj: Record) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final) - for streaming responses"""
        # Default: single-shot response, always final; streaming
        # translators override this to signal intermediate messages.
        return self.from_pulsar(obj), True
class SendTranslator(Translator):
    """For fire-and-forget send operations (like ServiceSender)"""

    def from_pulsar(self, obj: Record) -> Dict[str, Any]:
        """Usually not needed for send-only operations"""
        # Subclasses that do support reading back (e.g. document
        # translators) override this.
        raise NotImplementedError("Send translators typically don't need from_pulsar")
def handle_optional_fields(obj: Record, fields: list) -> Dict[str, Any]:
    """Collect the named attributes of *obj* that are set (not None).

    Missing attributes are treated the same as None and omitted.
    """
    return {
        name: attr
        for name in fields
        if (attr := getattr(obj, name, None)) is not None
    }

View file

@ -0,0 +1,100 @@
from typing import Dict, Any, Tuple
from ...schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
from .base import MessageTranslator
class ConfigRequestTranslator(MessageTranslator):
    """Translator for ConfigRequest schema objects.

    Maps a request dict (operation / type / keys / values) onto a
    ConfigRequest record and back; absent fields stay None on the record
    and are omitted from the dict.
    """

    def to_pulsar(self, data: Dict[str, Any]) -> ConfigRequest:
        """Build a ConfigRequest; "keys" and "values" are optional lists."""
        keys = None
        if "keys" in data:
            # Each entry is {"type": ..., "key": ...}
            keys = [
                ConfigKey(
                    type=k["type"],
                    key=k["key"]
                )
                for k in data["keys"]
            ]

        values = None
        if "values" in data:
            # Values additionally carry the stored "value" payload
            values = [
                ConfigValue(
                    type=v["type"],
                    key=v["key"],
                    value=v["value"]
                )
                for v in data["values"]
            ]

        return ConfigRequest(
            operation=data.get("operation"),
            keys=keys,
            type=data.get("type"),
            values=values
        )

    def from_pulsar(self, obj: ConfigRequest) -> Dict[str, Any]:
        """Convert a ConfigRequest record to a dict, omitting unset fields."""
        result = {}
        if obj.operation:
            result["operation"] = obj.operation
        if obj.type:
            result["type"] = obj.type
        if obj.keys:
            result["keys"] = [
                {
                    "type": k.type,
                    "key": k.key
                }
                for k in obj.keys
            ]
        if obj.values:
            result["values"] = [
                {
                    "type": v.type,
                    "key": v.key,
                    "value": v.value
                }
                for v in obj.values
            ]
        return result
class ConfigResponseTranslator(MessageTranslator):
    """Converts ConfigResponse records to plain dicts (one direction only)."""

    def to_pulsar(self, data: Dict[str, Any]) -> ConfigResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: ConfigResponse) -> Dict[str, Any]:
        """Copy the populated response fields into a dict.

        "version" is kept even when 0 (checked against None); the other
        fields are included only when non-empty.
        """
        out: Dict[str, Any] = {}
        if obj.version is not None:
            out["version"] = obj.version
        if obj.values:
            out["values"] = [
                {"type": item.type, "key": item.key, "value": item.value}
                for item in obj.values
            ]
        if obj.directory:
            out["directory"] = obj.directory
        if obj.config:
            out["config"] = obj.config
        return out

    def from_response_with_completion(self, obj: ConfigResponse) -> Tuple[Dict[str, Any], bool]:
        """Config responses are single-shot, so always final."""
        return self.from_pulsar(obj), True

View file

@ -0,0 +1,191 @@
import base64
from typing import Dict, Any
from ...schema import Document, TextDocument, Chunk, DocumentEmbeddings, ChunkEmbeddings
from .base import SendTranslator
from .metadata import DocumentMetadataTranslator
from .primitives import SubgraphTranslator
class DocumentTranslator(SendTranslator):
    """Translator for Document schema objects (PDF docs etc.)"""

    def __init__(self):
        # Graph metadata (list of triples) is delegated to SubgraphTranslator
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> Document:
        """Build a Document record from a load-request dict.

        "data" must be base64-encoded content; decoding then re-encoding
        validates the base64 before the message goes on the queue.
        user/collection default to "trustgraph"/"default".
        """
        metadata = data.get("metadata", [])
        # Handle base64 content validation
        doc = base64.b64decode(data["data"])
        from ...schema import Metadata
        return Document(
            metadata=Metadata(
                id=data.get("id"),
                metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [],
                user=data.get("user", "trustgraph"),
                collection=data.get("collection", "default"),
            ),
            data=base64.b64encode(doc).decode("utf-8")
        )

    def from_pulsar(self, obj: Document) -> Dict[str, Any]:
        """Convert a Document record to a dict.

        NOTE(review): id/user/collection are nested under "metadata" here,
        while to_pulsar reads them from the top level — the round trip is
        asymmetric; confirm this shape is what callers expect.
        """
        result = {
            "data": obj.data
        }
        if obj.metadata:
            metadata_dict = {}
            if obj.metadata.id:
                metadata_dict["id"] = obj.metadata.id
            if obj.metadata.user:
                metadata_dict["user"] = obj.metadata.user
            if obj.metadata.collection:
                metadata_dict["collection"] = obj.metadata.collection
            if obj.metadata.metadata:
                metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
            result["metadata"] = metadata_dict
        return result
class TextDocumentTranslator(SendTranslator):
    """Translator for TextDocument schema objects"""

    def __init__(self):
        # Graph metadata (list of triples) is delegated to SubgraphTranslator
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> TextDocument:
        """Build a TextDocument record.

        "text" arrives base64-encoded in the source charset (default
        utf-8); it is decoded and normalised to UTF-8 bytes on the record.
        """
        metadata = data.get("metadata", [])
        charset = data.get("charset", "utf-8")
        # Text is base64 encoded in input
        text = base64.b64decode(data["text"]).decode(charset)
        from ...schema import Metadata
        return TextDocument(
            metadata=Metadata(
                id=data.get("id"),
                metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [],
                user=data.get("user", "trustgraph"),
                collection=data.get("collection", "default"),
            ),
            text=text.encode("utf-8")
        )

    def from_pulsar(self, obj: TextDocument) -> Dict[str, Any]:
        """Convert a TextDocument record to a dict; text is decoded to str.

        NOTE(review): unlike to_pulsar, the returned "text" is NOT base64 —
        asymmetric round trip; confirm intended.
        """
        result = {
            "text": obj.text.decode("utf-8") if isinstance(obj.text, bytes) else obj.text
        }
        if obj.metadata:
            metadata_dict = {}
            if obj.metadata.id:
                metadata_dict["id"] = obj.metadata.id
            if obj.metadata.user:
                metadata_dict["user"] = obj.metadata.user
            if obj.metadata.collection:
                metadata_dict["collection"] = obj.metadata.collection
            if obj.metadata.metadata:
                metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
            result["metadata"] = metadata_dict
        return result
class ChunkTranslator(SendTranslator):
    """Translator for Chunk schema objects"""

    def __init__(self):
        # Graph metadata (list of triples) is delegated to SubgraphTranslator
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> Chunk:
        """Build a Chunk record; str chunks are encoded to UTF-8 bytes,
        bytes chunks pass through unchanged."""
        metadata = data.get("metadata", [])
        from ...schema import Metadata
        return Chunk(
            metadata=Metadata(
                id=data.get("id"),
                metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [],
                user=data.get("user", "trustgraph"),
                collection=data.get("collection", "default"),
            ),
            chunk=data["chunk"].encode("utf-8") if isinstance(data["chunk"], str) else data["chunk"]
        )

    def from_pulsar(self, obj: Chunk) -> Dict[str, Any]:
        """Convert a Chunk record to a dict; bytes chunks decoded to str."""
        result = {
            "chunk": obj.chunk.decode("utf-8") if isinstance(obj.chunk, bytes) else obj.chunk
        }
        if obj.metadata:
            metadata_dict = {}
            if obj.metadata.id:
                metadata_dict["id"] = obj.metadata.id
            if obj.metadata.user:
                metadata_dict["user"] = obj.metadata.user
            if obj.metadata.collection:
                metadata_dict["collection"] = obj.metadata.collection
            if obj.metadata.metadata:
                metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
            result["metadata"] = metadata_dict
        return result
class DocumentEmbeddingsTranslator(SendTranslator):
    """Translator for DocumentEmbeddings schema objects"""

    def __init__(self):
        # Graph metadata (list of triples) is delegated to SubgraphTranslator
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddings:
        """Build a DocumentEmbeddings record from a dict.

        NOTE(review): unlike the other document translators, metadata here
        is expected as a nested dict (default {}), and its "metadata" list
        is passed to the subgraph translator without the empty-list guard
        used elsewhere — confirm SubgraphTranslator.to_pulsar accepts [].
        """
        metadata = data.get("metadata", {})
        chunks = [
            ChunkEmbeddings(
                # Normalise str chunks to UTF-8 bytes; bytes pass through
                chunk=chunk["chunk"].encode("utf-8") if isinstance(chunk["chunk"], str) else chunk["chunk"],
                vectors=chunk["vectors"]
            )
            for chunk in data.get("chunks", [])
        ]
        from ...schema import Metadata
        return DocumentEmbeddings(
            metadata=Metadata(
                id=metadata.get("id"),
                metadata=self.subgraph_translator.to_pulsar(metadata.get("metadata", [])),
                user=metadata.get("user", "trustgraph"),
                collection=metadata.get("collection", "default"),
            ),
            chunks=chunks
        )

    def from_pulsar(self, obj: DocumentEmbeddings) -> Dict[str, Any]:
        """Convert a DocumentEmbeddings record to a dict; chunk bytes are
        decoded back to str."""
        result = {
            "chunks": [
                {
                    "chunk": chunk.chunk.decode("utf-8") if isinstance(chunk.chunk, bytes) else chunk.chunk,
                    "vectors": chunk.vectors
                }
                for chunk in obj.chunks
            ]
        }
        if obj.metadata:
            metadata_dict = {}
            if obj.metadata.id:
                metadata_dict["id"] = obj.metadata.id
            if obj.metadata.user:
                metadata_dict["user"] = obj.metadata.user
            if obj.metadata.collection:
                metadata_dict["collection"] = obj.metadata.collection
            if obj.metadata.metadata:
                metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
            result["metadata"] = metadata_dict
        return result

View file

@ -0,0 +1,33 @@
from typing import Dict, Any, Tuple
from ...schema import EmbeddingsRequest, EmbeddingsResponse
from .base import MessageTranslator
class EmbeddingsRequestTranslator(MessageTranslator):
    """Converts embeddings request dicts to/from EmbeddingsRequest records."""

    def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsRequest:
        """The request carries just the text to embed ("text" is required)."""
        return EmbeddingsRequest(text=data["text"])

    def from_pulsar(self, obj: EmbeddingsRequest) -> Dict[str, Any]:
        """Inverse of to_pulsar."""
        return dict(text=obj.text)
class EmbeddingsResponseTranslator(MessageTranslator):
    """Converts EmbeddingsResponse records to plain dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: EmbeddingsResponse) -> Dict[str, Any]:
        """Expose the computed vectors."""
        return dict(vectors=obj.vectors)

    def from_response_with_completion(self, obj: EmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
        """Embeddings responses are single-shot: always final."""
        return self.from_pulsar(obj), True

View file

@ -0,0 +1,94 @@
from typing import Dict, Any, Tuple
from ...schema import (
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
GraphEmbeddingsRequest, GraphEmbeddingsResponse
)
from .base import MessageTranslator
from .primitives import ValueTranslator
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
    """Converts document-embeddings query dicts to/from Pulsar records."""

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
        """Build a request; limit defaults to 10, user/collection to
        "trustgraph"/"default"."""
        limit = int(data.get("limit", 10))
        return DocumentEmbeddingsRequest(
            vectors=data["vectors"],
            limit=limit,
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
        )

    def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
        """Expose every request field; none are optional on the record."""
        return {
            name: getattr(obj, name)
            for name in ("vectors", "limit", "user", "collection")
        }
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
    """Converts DocumentEmbeddingsResponse records to dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
        """Decode byte documents to UTF-8 strings; omit the key entirely
        when there are no matches."""
        if not obj.documents:
            return {}
        docs = []
        for doc in obj.documents:
            docs.append(doc.decode("utf-8") if isinstance(doc, bytes) else doc)
        return {"documents": docs}

    def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
        """Single-shot response: always final."""
        return self.from_pulsar(obj), True
class GraphEmbeddingsRequestTranslator(MessageTranslator):
    """Converts graph-embeddings query dicts to/from Pulsar records."""

    def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
        """Build a request; limit defaults to 10, user/collection to
        "trustgraph"/"default"."""
        limit = int(data.get("limit", 10))
        return GraphEmbeddingsRequest(
            vectors=data["vectors"],
            limit=limit,
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
        )

    def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
        """Expose every request field; none are optional on the record."""
        return {
            name: getattr(obj, name)
            for name in ("vectors", "limit", "user", "collection")
        }
class GraphEmbeddingsResponseTranslator(MessageTranslator):
    """Translator for GraphEmbeddingsResponse schema objects"""

    def __init__(self):
        # Matched entities are Value records; delegate their conversion
        self.value_translator = ValueTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
        """Convert matched entities to dicts; key omitted when no matches."""
        result = {}
        if obj.entities:
            result["entities"] = [
                self.value_translator.from_pulsar(entity)
                for entity in obj.entities
            ]
        return result

    def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); single-shot, always final."""
        return self.from_pulsar(obj), True

View file

@ -0,0 +1,59 @@
from typing import Dict, Any, Tuple
from ...schema import FlowRequest, FlowResponse
from .base import MessageTranslator
class FlowRequestTranslator(MessageTranslator):
    """Maps flow API dicts (kebab-case keys) to FlowRequest records."""

    # (wire key, record attribute): kebab-case on the wire, snake_case
    # on the Pulsar record.
    _FIELDS = (
        ("operation", "operation"),
        ("class-name", "class_name"),
        ("class-definition", "class_definition"),
        ("description", "description"),
        ("flow-id", "flow_id"),
    )

    def to_pulsar(self, data: Dict[str, Any]) -> FlowRequest:
        """Build a FlowRequest; absent keys become None on the record."""
        kwargs = {attr: data.get(key) for key, attr in self._FIELDS}
        return FlowRequest(**kwargs)

    def from_pulsar(self, obj: FlowRequest) -> Dict[str, Any]:
        """Return only the populated fields, using wire (kebab-case) keys."""
        result = {}
        for key, attr in self._FIELDS:
            value = getattr(obj, attr)
            if value:
                result[key] = value
        return result
class FlowResponseTranslator(MessageTranslator):
    """Translator for FlowResponse schema objects"""

    def to_pulsar(self, data: Dict[str, Any]) -> FlowResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: FlowResponse) -> Dict[str, Any]:
        """Convert a FlowResponse record to a dict.

        Snake_case record fields are exposed under kebab-case wire keys;
        empty/None fields are omitted.
        """
        result = {}
        if obj.class_names:
            result["class-names"] = obj.class_names
        if obj.flow_ids:
            result["flow-ids"] = obj.flow_ids
        if obj.class_definition:
            result["class-definition"] = obj.class_definition
        if obj.flow:
            result["flow"] = obj.flow
        if obj.description:
            result["description"] = obj.description
        return result

    def from_response_with_completion(self, obj: FlowResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); single-shot, always final."""
        return self.from_pulsar(obj), True

View file

@ -0,0 +1,183 @@
from typing import Dict, Any, Tuple, Optional
from ...schema import (
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
Metadata, EntityEmbeddings
)
from .base import MessageTranslator
from .primitives import ValueTranslator, SubgraphTranslator
from .metadata import DocumentMetadataTranslator
class KnowledgeRequestTranslator(MessageTranslator):
    """Translator for KnowledgeRequest schema objects.

    Knowledge requests optionally embed a full Triples batch and/or a
    GraphEmbeddings batch (used for knowledge-core load operations).
    """

    def __init__(self):
        # Entity values and triple lists are delegated to the primitive
        # translators.
        self.value_translator = ValueTranslator()
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeRequest:
        """Build a KnowledgeRequest from a dict.

        NOTE(review): when "triples" / "graph-embeddings" are present,
        their nested "metadata" dict must contain id/metadata/user/
        collection — missing keys raise KeyError; confirm callers always
        supply the full shape.
        """
        triples = None
        if "triples" in data:
            triples = Triples(
                metadata=Metadata(
                    id=data["triples"]["metadata"]["id"],
                    metadata=self.subgraph_translator.to_pulsar(
                        data["triples"]["metadata"]["metadata"]
                    ),
                    user=data["triples"]["metadata"]["user"],
                    collection=data["triples"]["metadata"]["collection"]
                ),
                triples=self.subgraph_translator.to_pulsar(data["triples"]["triples"]),
            )

        graph_embeddings = None
        if "graph-embeddings" in data:
            graph_embeddings = GraphEmbeddings(
                metadata=Metadata(
                    id=data["graph-embeddings"]["metadata"]["id"],
                    metadata=self.subgraph_translator.to_pulsar(
                        data["graph-embeddings"]["metadata"]["metadata"]
                    ),
                    user=data["graph-embeddings"]["metadata"]["user"],
                    collection=data["graph-embeddings"]["metadata"]["collection"]
                ),
                entities=[
                    EntityEmbeddings(
                        entity=self.value_translator.to_pulsar(ent["entity"]),
                        vectors=ent["vectors"],
                    )
                    for ent in data["graph-embeddings"]["entities"]
                ]
            )

        return KnowledgeRequest(
            operation=data.get("operation"),
            user=data.get("user"),
            id=data.get("id"),
            flow=data.get("flow"),
            collection=data.get("collection"),
            triples=triples,
            graph_embeddings=graph_embeddings,
        )

    def from_pulsar(self, obj: KnowledgeRequest) -> Dict[str, Any]:
        """Convert a KnowledgeRequest record to a dict, omitting unset
        fields; embedded batches are expanded recursively."""
        result = {}
        if obj.operation:
            result["operation"] = obj.operation
        if obj.user:
            result["user"] = obj.user
        if obj.id:
            result["id"] = obj.id
        if obj.flow:
            result["flow"] = obj.flow
        if obj.collection:
            result["collection"] = obj.collection
        if obj.triples:
            result["triples"] = {
                "metadata": {
                    "id": obj.triples.metadata.id,
                    "metadata": self.subgraph_translator.from_pulsar(
                        obj.triples.metadata.metadata
                    ),
                    "user": obj.triples.metadata.user,
                    "collection": obj.triples.metadata.collection,
                },
                "triples": self.subgraph_translator.from_pulsar(obj.triples.triples),
            }
        if obj.graph_embeddings:
            result["graph-embeddings"] = {
                "metadata": {
                    "id": obj.graph_embeddings.metadata.id,
                    "metadata": self.subgraph_translator.from_pulsar(
                        obj.graph_embeddings.metadata.metadata
                    ),
                    "user": obj.graph_embeddings.metadata.user,
                    "collection": obj.graph_embeddings.metadata.collection,
                },
                "entities": [
                    {
                        "vectors": entity.vectors,
                        "entity": self.value_translator.from_pulsar(entity.entity),
                    }
                    for entity in obj.graph_embeddings.entities
                ],
            }
        return result
class KnowledgeResponseTranslator(MessageTranslator):
    """Translator for KnowledgeResponse schema objects"""
    def __init__(self):
        self.value_translator = ValueTranslator()
        self.subgraph_translator = SubgraphTranslator()
    def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")
    def from_pulsar(self, obj: KnowledgeResponse) -> Dict[str, Any]:
        """Serialise one response message to a dict.

        A KnowledgeResponse is one of several mutually-exclusive shapes;
        the branches below are checked in priority order and the first
        populated field wins.
        """
        # Response to list operation
        if obj.ids is not None:
            return {"ids": obj.ids}
        # Streaming triples response
        if obj.triples:
            return {
                "triples": {
                    "metadata": {
                        "id": obj.triples.metadata.id,
                        "metadata": self.subgraph_translator.from_pulsar(
                            obj.triples.metadata.metadata
                        ),
                        "user": obj.triples.metadata.user,
                        "collection": obj.triples.metadata.collection,
                    },
                    "triples": self.subgraph_translator.from_pulsar(obj.triples.triples),
                }
            }
        # Streaming graph embeddings response
        if obj.graph_embeddings:
            return {
                "graph-embeddings": {
                    "metadata": {
                        "id": obj.graph_embeddings.metadata.id,
                        "metadata": self.subgraph_translator.from_pulsar(
                            obj.graph_embeddings.metadata.metadata
                        ),
                        "user": obj.graph_embeddings.metadata.user,
                        "collection": obj.graph_embeddings.metadata.collection,
                    },
                    "entities": [
                        {
                            "vectors": entity.vectors,
                            "entity": self.value_translator.from_pulsar(entity.entity),
                        }
                        for entity in obj.graph_embeddings.entities
                    ],
                }
            }
        # End of stream marker
        if obj.eos is True:
            return {"eos": True}
        # Empty response (successful delete)
        return {}
    def from_response_with_completion(self, obj: KnowledgeResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final)"""
        response = self.from_pulsar(obj)
        # Check if this is a final response: a list result, an explicit
        # end-of-stream, or an empty (non-streaming) acknowledgement all
        # terminate the exchange; streaming triples/embeddings do not.
        is_final = (
            obj.ids is not None or  # List response
            obj.eos is True or  # End of stream
            (not obj.triples and not obj.graph_embeddings)  # Empty response
        )
        return response, is_final

View file

@ -0,0 +1,124 @@
from typing import Dict, Any, Tuple, Optional
from ...schema import LibrarianRequest, LibrarianResponse, DocumentMetadata, ProcessingMetadata, Criteria
from .base import MessageTranslator
from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator
class LibraryRequestTranslator(MessageTranslator):
    """Translator for LibrarianRequest schema objects"""
    def __init__(self):
        self.doc_metadata_translator = DocumentMetadataTranslator()
        self.proc_metadata_translator = ProcessingMetadataTranslator()
    def to_pulsar(self, data: Dict[str, Any]) -> LibrarianRequest:
        """Build a LibrarianRequest from a REST-style dict.

        Hyphenated dict keys map onto underscored schema fields; string
        content is UTF-8 encoded to bytes; absent optional sections
        become None (criteria defaults to an empty list).
        """
        # Document metadata
        doc_metadata = None
        if "document-metadata" in data:
            doc_metadata = self.doc_metadata_translator.to_pulsar(data["document-metadata"])
        # Processing metadata
        proc_metadata = None
        if "processing-metadata" in data:
            proc_metadata = self.proc_metadata_translator.to_pulsar(data["processing-metadata"])
        # Criteria
        criteria = []
        if "criteria" in data:
            criteria = [
                Criteria(
                    key=c["key"],
                    value=c["value"],
                    operator=c["operator"]
                )
                for c in data["criteria"]
            ]
        # Content as bytes; non-string content is passed through as-is
        # (presumably already bytes — TODO confirm against callers).
        content = None
        if "content" in data:
            if isinstance(data["content"], str):
                content = data["content"].encode("utf-8")
            else:
                content = data["content"]
        return LibrarianRequest(
            operation=data.get("operation"),
            document_id=data.get("document-id"),
            processing_id=data.get("processing-id"),
            document_metadata=doc_metadata,
            processing_metadata=proc_metadata,
            content=content,
            user=data.get("user"),
            collection=data.get("collection"),
            criteria=criteria
        )
    def from_pulsar(self, obj: LibrarianRequest) -> Dict[str, Any]:
        """Serialise a LibrarianRequest back to a dict.

        Falsy scalars are omitted; bytes content is decoded to UTF-8
        text; criteria (possibly empty) is emitted whenever not None.
        """
        result = {}
        if obj.operation:
            result["operation"] = obj.operation
        if obj.document_id:
            result["document-id"] = obj.document_id
        if obj.processing_id:
            result["processing-id"] = obj.processing_id
        if obj.document_metadata:
            result["document-metadata"] = self.doc_metadata_translator.from_pulsar(obj.document_metadata)
        if obj.processing_metadata:
            result["processing-metadata"] = self.proc_metadata_translator.from_pulsar(obj.processing_metadata)
        if obj.content:
            result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content
        if obj.user:
            result["user"] = obj.user
        if obj.collection:
            result["collection"] = obj.collection
        if obj.criteria is not None:
            result["criteria"] = [
                {
                    "key": c.key,
                    "value": c.value,
                    "operator": c.operator
                }
                for c in obj.criteria
            ]
        return result
class LibraryResponseTranslator(MessageTranslator):
    """Maps LibrarianResponse Pulsar objects onto REST-style dicts."""

    def __init__(self):
        self.doc_metadata_translator = DocumentMetadataTranslator()
        self.proc_metadata_translator = ProcessingMetadataTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> LibrarianResponse:
        # Librarian responses originate service-side only.
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: LibrarianResponse) -> Dict[str, Any]:
        """Build a dict from whichever response fields are populated."""
        result: Dict[str, Any] = {}
        if obj.document_metadata:
            result["document-metadata"] = self.doc_metadata_translator.from_pulsar(
                obj.document_metadata
            )
        if obj.content:
            payload = obj.content
            # Byte payloads are surfaced to clients as UTF-8 text.
            if isinstance(payload, bytes):
                payload = payload.decode("utf-8")
            result["content"] = payload
        if obj.document_metadatas is not None:
            result["document-metadatas"] = [
                self.doc_metadata_translator.from_pulsar(entry)
                for entry in obj.document_metadatas
            ]
        if obj.processing_metadatas is not None:
            result["processing-metadatas"] = [
                self.proc_metadata_translator.from_pulsar(entry)
                for entry in obj.processing_metadatas
            ]
        return result

    def from_response_with_completion(self, obj: LibrarianResponse) -> Tuple[Dict[str, Any], bool]:
        """Librarian responses are single-shot: (dict, True)."""
        return self.from_pulsar(obj), True

View file

@ -0,0 +1,81 @@
from typing import Dict, Any, Optional
from ...schema import DocumentMetadata, ProcessingMetadata
from .base import Translator
from .primitives import SubgraphTranslator
class DocumentMetadataTranslator(Translator):
    """Converts DocumentMetadata records to and from plain dicts."""

    def __init__(self):
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentMetadata:
        """Build a DocumentMetadata; absent keys become None."""
        raw_metadata = data.get("metadata", [])
        # An explicitly-null metadata field is normalised to an empty subgraph.
        if raw_metadata is not None:
            subgraph = self.subgraph_translator.to_pulsar(raw_metadata)
        else:
            subgraph = []
        return DocumentMetadata(
            id=data.get("id"),
            time=data.get("time"),
            kind=data.get("kind"),
            title=data.get("title"),
            comments=data.get("comments"),
            metadata=subgraph,
            user=data.get("user"),
            tags=data.get("tags")
        )

    def from_pulsar(self, obj: DocumentMetadata) -> Dict[str, Any]:
        """Serialise to a dict, omitting falsy scalar fields."""
        result: Dict[str, Any] = {}
        # Simple scalars: only truthy values appear in the output.
        for field in ("id", "time", "kind", "title", "comments"):
            value = getattr(obj, field)
            if value:
                result[field] = value
        if obj.metadata is not None:
            result["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata)
        if obj.user:
            result["user"] = obj.user
        # tags may legitimately be an empty list, so compare against None.
        if obj.tags is not None:
            result["tags"] = obj.tags
        return result
class ProcessingMetadataTranslator(Translator):
    """Converts ProcessingMetadata records to and from plain dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> ProcessingMetadata:
        """Build a ProcessingMetadata; hyphenated dict keys map onto
        underscored schema fields, absent keys become None."""
        get = data.get
        return ProcessingMetadata(
            id=get("id"),
            document_id=get("document-id"),
            time=get("time"),
            flow=get("flow"),
            user=get("user"),
            collection=get("collection"),
            tags=get("tags")
        )

    def from_pulsar(self, obj: ProcessingMetadata) -> Dict[str, Any]:
        """Serialise to a dict, skipping unset (falsy) scalar fields."""
        pairs = (
            ("id", obj.id),
            ("document-id", obj.document_id),
            ("time", obj.time),
            ("flow", obj.flow),
            ("user", obj.user),
            ("collection", obj.collection),
        )
        result: Dict[str, Any] = {key: value for key, value in pairs if value}
        # tags may legitimately be an empty list, so compare against None.
        if obj.tags is not None:
            result["tags"] = obj.tags
        return result

View file

@ -0,0 +1,47 @@
from typing import Dict, Any, List
from ...schema import Value, Triple
from .base import Translator
class ValueTranslator(Translator):
    """Maps Value objects to/from the wire form {"v": text, "e": is_uri}."""

    def to_pulsar(self, data: Dict[str, Any]) -> Value:
        """Construct a Value; "v" is the term text, "e" the URI flag."""
        return Value(value=data["v"], is_uri=data["e"])

    def from_pulsar(self, obj: Value) -> Dict[str, Any]:
        """Flatten a Value back into its two-key dict form."""
        return dict(v=obj.value, e=obj.is_uri)
class TripleTranslator(Translator):
    """Maps Triple objects to/from dicts keyed by "s", "p", "o"."""

    def __init__(self):
        self.value_translator = ValueTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> Triple:
        """Build a Triple by translating each of its three terms."""
        convert = self.value_translator.to_pulsar
        return Triple(
            s=convert(data["s"]),
            p=convert(data["p"]),
            o=convert(data["o"]),
        )

    def from_pulsar(self, obj: Triple) -> Dict[str, Any]:
        """Flatten a Triple into its subject/predicate/object dict."""
        convert = self.value_translator.from_pulsar
        return {
            key: convert(term)
            for key, term in (("s", obj.s), ("p", obj.p), ("o", obj.o))
        }
class SubgraphTranslator(Translator):
    """Translates whole subgraphs: lists of Triple objects."""

    def __init__(self):
        self.triple_translator = TripleTranslator()

    def to_pulsar(self, data: List[Dict[str, Any]]) -> List[Triple]:
        """Convert every triple dict in the list."""
        return list(map(self.triple_translator.to_pulsar, data))

    def from_pulsar(self, obj: List[Triple]) -> List[Dict[str, Any]]:
        """Convert every Triple object in the list."""
        return list(map(self.triple_translator.from_pulsar, obj))

View file

@ -0,0 +1,54 @@
import json
from typing import Dict, Any, Tuple
from ...schema import PromptRequest, PromptResponse
from .base import MessageTranslator
class PromptRequestTranslator(MessageTranslator):
    """Converts prompt-service requests between dicts and PromptRequest."""

    def to_pulsar(self, data: Dict[str, Any]) -> PromptRequest:
        """Build a PromptRequest, accepting either "terms" or "variables"."""
        if "variables" in data:
            # The service expects each variable serialised as a JSON string;
            # "variables" therefore takes precedence over any "terms" key.
            terms = {
                name: json.dumps(value)
                for name, value in data["variables"].items()
            }
        else:
            terms = data.get("terms", {})
        return PromptRequest(id=data.get("id"), terms=terms)

    def from_pulsar(self, obj: PromptRequest) -> Dict[str, Any]:
        """Map a PromptRequest back to a dict, dropping empty fields."""
        result: Dict[str, Any] = {}
        for key, value in (("id", obj.id), ("terms", obj.terms)):
            if value:
                result[key] = value
        return result
class PromptResponseTranslator(MessageTranslator):
    """Maps PromptResponse objects onto REST-style dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> PromptResponse:
        # Prompt responses are only ever produced by the service.
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]:
        """Return whichever of the text / object payloads is populated."""
        result: Dict[str, Any] = {}
        for key, value in (("text", obj.text), ("object", obj.object)):
            if value:
                result[key] = value
        return result

    def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
        """Prompt responses are single-shot: (dict, True)."""
        return self.from_pulsar(obj), True

View file

@ -0,0 +1,81 @@
from typing import Dict, Any, Tuple
from ...schema import DocumentRagQuery, DocumentRagResponse, GraphRagQuery, GraphRagResponse
from .base import MessageTranslator
class DocumentRagRequestTranslator(MessageTranslator):
    """Maps document-RAG query dicts onto DocumentRagQuery objects."""

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagQuery:
        """Build a query; user, collection and doc-limit use defaults
        when absent."""
        # doc-limit may arrive as a string; normalise to int.
        limit = int(data.get("doc-limit", 20))
        return DocumentRagQuery(
            query=data["query"],
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
            doc_limit=limit
        )

    def from_pulsar(self, obj: DocumentRagQuery) -> Dict[str, Any]:
        """Serialise a query back to its hyphen-keyed dict form."""
        return dict([
            ("query", obj.query),
            ("user", obj.user),
            ("collection", obj.collection),
            ("doc-limit", obj.doc_limit),
        ])
class DocumentRagResponseTranslator(MessageTranslator):
    """Maps DocumentRagResponse objects onto REST-style dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagResponse:
        # RAG responses are only ever produced by the service.
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
        """Expose only the answer text."""
        return {"response": obj.response}

    def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
        """RAG responses are single-shot: (dict, True)."""
        return self.from_pulsar(obj), True
class GraphRagRequestTranslator(MessageTranslator):
    """Maps graph-RAG query dicts onto GraphRagQuery objects."""

    def to_pulsar(self, data: Dict[str, Any]) -> GraphRagQuery:
        """Build a query, filling in the default for every absent
        tuning parameter."""
        def int_opt(key, fallback):
            # Numeric knobs may arrive as strings; normalise to int.
            return int(data.get(key, fallback))
        return GraphRagQuery(
            query=data["query"],
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
            entity_limit=int_opt("entity-limit", 50),
            triple_limit=int_opt("triple-limit", 30),
            max_subgraph_size=int_opt("max-subgraph-size", 1000),
            max_path_length=int_opt("max-path-length", 2)
        )

    def from_pulsar(self, obj: GraphRagQuery) -> Dict[str, Any]:
        """Serialise a query back to its hyphen-keyed dict form."""
        return {
            "query": obj.query,
            "user": obj.user,
            "collection": obj.collection,
            "entity-limit": obj.entity_limit,
            "triple-limit": obj.triple_limit,
            "max-subgraph-size": obj.max_subgraph_size,
            "max-path-length": obj.max_path_length,
        }
class GraphRagResponseTranslator(MessageTranslator):
    """Maps GraphRagResponse objects onto REST-style dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> GraphRagResponse:
        # RAG responses are only ever produced by the service.
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
        """Expose only the answer text."""
        return {"response": obj.response}

    def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
        """RAG responses are single-shot: (dict, True)."""
        return self.from_pulsar(obj), True

View file

@ -0,0 +1,42 @@
from typing import Dict, Any, Tuple
from ...schema import TextCompletionRequest, TextCompletionResponse
from .base import MessageTranslator
class TextCompletionRequestTranslator(MessageTranslator):
    """Maps LLM completion request dicts onto TextCompletionRequest."""

    def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionRequest:
        """Both the system prompt and the user prompt are mandatory."""
        return TextCompletionRequest(
            system=data["system"],
            prompt=data["prompt"],
        )

    def from_pulsar(self, obj: TextCompletionRequest) -> Dict[str, Any]:
        """Serialise the request back to its two-key dict form."""
        return {"system": obj.system, "prompt": obj.prompt}
class TextCompletionResponseTranslator(MessageTranslator):
    """Maps TextCompletionResponse objects onto REST-style dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionResponse:
        # Completion responses are only ever produced by the service.
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: TextCompletionResponse) -> Dict[str, Any]:
        """Always include the completion text; token counts and model
        name are attached only when reported (truthy)."""
        result: Dict[str, Any] = {"response": obj.response}
        extras = (
            ("in_token", obj.in_token),
            ("out_token", obj.out_token),
            ("model", obj.model),
        )
        for key, value in extras:
            if value:
                result[key] = value
        return result

    def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]:
        """Completion responses are single-shot: (dict, True)."""
        return self.from_pulsar(obj), True

View file

@ -0,0 +1,60 @@
from typing import Dict, Any, Tuple, Optional
from ...schema import TriplesQueryRequest, TriplesQueryResponse
from .base import MessageTranslator
from .primitives import ValueTranslator, SubgraphTranslator
class TriplesQueryRequestTranslator(MessageTranslator):
    """Maps triple-pattern query dicts onto TriplesQueryRequest objects."""

    def __init__(self):
        self.value_translator = ValueTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> TriplesQueryRequest:
        """Build a query; each of s/p/o is optional (None = wildcard)."""
        def term(key):
            # An absent pattern position translates to None, i.e. match any.
            if key in data:
                return self.value_translator.to_pulsar(data[key])
            return None
        return TriplesQueryRequest(
            s=term("s"),
            p=term("p"),
            o=term("o"),
            limit=int(data.get("limit", 10000)),
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default")
        )

    def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]:
        """Serialise a query; only populated pattern positions appear."""
        result: Dict[str, Any] = {
            "limit": obj.limit,
            "user": obj.user,
            "collection": obj.collection
        }
        for key, term in (("s", obj.s), ("p", obj.p), ("o", obj.o)):
            if term:
                result[key] = self.value_translator.from_pulsar(term)
        return result
class TriplesQueryResponseTranslator(MessageTranslator):
    """Maps TriplesQueryResponse objects onto REST-style dicts."""

    def __init__(self):
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> TriplesQueryResponse:
        # Query responses are only ever produced by the service.
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: TriplesQueryResponse) -> Dict[str, Any]:
        """The matched subgraph is returned under the "response" key."""
        matched = self.subgraph_translator.from_pulsar(obj.triples)
        return {"response": matched}

    def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]:
        """Triples query responses are single-shot: (dict, True)."""
        return self.from_pulsar(obj), True

View file

@ -1,80 +1,68 @@
#!/usr/bin/env python3
"""
Loads Graph embeddings into TrustGraph processing.
FIXME: This hasn't been updated following API gateway change.
Loads triples into the knowledge graph.
"""
import pulsar
from pulsar.schema import JsonSchema
from trustgraph.schema import Triples, Triple, Value, Metadata
import asyncio
import argparse
import os
import time
import pyarrow as pa
import rdflib
import json
from websockets.asyncio.client import connect
from trustgraph.log_level import LogLevel
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
default_user = 'trustgraph'
default_collection = 'default'
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
default_output_queue = triples_store_queue
class Loader:
def __init__(
self,
pulsar_host,
output_queue,
log_level,
files,
flow,
user,
collection,
pulsar_api_key=None,
document_id,
url = default_url,
):
if pulsar_api_key:
auth = pulsar.AuthenticationToken(pulsar_api_key)
self.client = pulsar.Client(
pulsar_host,
authentication=auth,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
else:
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
if not url.endswith("/"):
url += "/"
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(Triples),
chunking_enabled=True,
)
url = url + f"api/v1/flow/{flow}/import/triples"
self.url = url
self.files = files
self.user = user
self.collection = collection
self.document_id = document_id
def run(self):
async def run(self):
try:
for file in self.files:
self.load_file(file)
async with connect(self.url) as ws:
for file in self.files:
await self.load_file(file, ws)
except Exception as e:
print(e, flush=True)
def load_file(self, file):
async def load_file(self, file, ws):
g = rdflib.Graph()
g.parse(file, format="turtle")
def Value(value, is_uri):
return { "v": value, "e": is_uri }
triples = []
for e in g:
s = Value(value=str(e[0]), is_uri=True)
p = Value(value=str(e[1]), is_uri=True)
@ -83,20 +71,23 @@ class Loader:
else:
o = Value(value=str(e[2]), is_uri=False)
r = Triples(
metadata=Metadata(
id=None,
metadata=[],
user=self.user,
collection=self.collection,
),
triples=[ Triple(s=s, p=p, o=o) ]
)
req = {
"metadata": {
"id": self.document_id,
"metadata": [],
"user": self.user,
"collection": self.collection
},
"triples": [
{
"s": s,
"p": p,
"o": o,
}
]
}
self.producer.send(r)
def __del__(self):
self.client.close()
await ws.send(json.dumps(req))
def main():
@ -106,9 +97,15 @@ def main():
)
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
'-u', '--api-url',
default=default_url,
help=f'API URL (default: {default_url})',
)
parser.add_argument(
'-i', '--document-id',
required=True,
help=f'Document ID)',
)
parser.add_argument(
@ -116,39 +113,19 @@ def main():
default="default",
help=f'Flow ID (default: default)'
)
parser.add_argument(
'--pulsar-api-key',
default=default_pulsar_api_key,
help=f'Pulsar API key',
)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)
parser.add_argument(
'-u', '--user',
'-U', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-c', '--collection',
'-C', '--collection',
default=default_collection,
help=f'Collection ID (default: {default_collection})'
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.ERROR,
choices=list(LogLevel),
help=f'Output queue (default: info)'
)
parser.add_argument(
'files', nargs='+',
help=f'Turtle files to load'
@ -160,16 +137,15 @@ def main():
try:
p = Loader(
pulsar_host=args.pulsar_host,
pulsar_api_key=args.pulsar_api_key,
output_queue=args.output_queue,
log_level=args.log_level,
files=args.files,
user=args.user,
collection=args.collection,
document_id = args.document_id,
url = args.api_url,
flow = args.flow_id,
files = args.files,
user = args.user,
collection = args.collection,
)
p.run()
asyncio.run(p.run())
print("File loaded.")
break
@ -181,6 +157,5 @@ def main():
time.sleep(10)
print("Not implemented.")
#main()
main()

View file

@ -0,0 +1,109 @@
#!/usr/bin/env python3
"""
Dump out a stream of token rates, input, output and total. This is averaged
across the time since tg-show-token-rate is started.
"""
import requests
import argparse
import json
import time
default_metrics_url = "http://localhost:8088/api/metrics"
class Collate:
    """Tracks a Prometheus counter across samples, yielding instantaneous
    and running-average rates."""

    def look(self, data):
        """Sum the counter across every series in a Prometheus
        instant-query response."""
        series = data["data"]["result"]
        return sum(float(item["value"][1]) for item in series)

    def __init__(self, data):
        # Baseline reading; subsequent deltas are measured from here.
        self.last = self.look(data)
        self.total = 0
        self.time = 0

    def record(self, data, time):
        """Fold in a sample taken `time` seconds after the previous one.

        Returns (rate_since_last_sample, average_rate_since_start).
        """
        current = self.look(data)
        step = current - self.last
        self.last = current
        self.total += step
        self.time += time
        return step / time, self.total / self.time
def dump_status(metrics_url, number_samples, period):
    """Poll the metrics endpoint and print averaged input/output/total
    token rates, one row per sample.

    metrics_url: base URL of the Prometheus-style metrics API.
    number_samples: number of rows to emit before returning.
    period: seconds to sleep between samples.
    """

    input_url = f"{metrics_url}/query?query=input_tokens_total"
    output_url = f"{metrics_url}/query?query=output_tokens_total"

    def fetch(url):
        # One instant-query round trip, decoded from JSON.
        return requests.get(url).json()

    # Baseline readings: rates are averaged from this point onwards.
    # (Renamed from `input`, which shadowed the builtin.)
    in_collate = Collate(fetch(input_url))
    out_collate = Collate(fetch(output_url))

    print(f"{'Input':>10s} {'Output':>10s} {'Total':>10s}")
    print(f"{'-----':>10s} {'------':>10s} {'-----':>10s}")

    for _ in range(number_samples):
        time.sleep(period)
        # record() returns (instantaneous, average-since-start); only the
        # averaged figure is reported, matching the module docstring.
        _, in_avg = in_collate.record(fetch(input_url), period)
        _, out_avg = out_collate.record(fetch(output_url), period)
        print(f"{in_avg:10.1f} {out_avg:10.1f} {in_avg+out_avg:10.1f}")
def main():
    """Parse command-line options and stream token-rate samples."""
    parser = argparse.ArgumentParser(
        # Fixed: this tool is tg-show-token-rate, not tg-show-processor-state.
        prog='tg-show-token-rate',
        description=__doc__,
    )
    parser.add_argument(
        '-m', '--metrics-url',
        default=default_metrics_url,
        help=f'Metrics URL (default: {default_metrics_url})',
    )
    parser.add_argument(
        '-p', '--period',
        type=int,
        default=1,
        help='Metrics period (default: 1)',
    )
    parser.add_argument(
        '-n', '--number-samples',
        type=int,
        default=100,
        # Fixed: help text previously said "Metrics period".
        help='Number of samples (default: 100)',
    )
    args = parser.parse_args()
    try:
        dump_status(**vars(args))
    except Exception as e:
        # Top-level guard: report the failure rather than traceback.
        print("Exception:", e, flush=True)

main()

View file

@ -80,6 +80,7 @@ setuptools.setup(
"scripts/tg-show-processor-state",
"scripts/tg-show-prompts",
"scripts/tg-show-token-costs",
"scripts/tg-show-token-rate",
"scripts/tg-show-tools",
"scripts/tg-start-flow",
"scripts/tg-unload-kg-core",

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.rev_gateway import run
run()

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3
from trustgraph.model.text_completion.vllm import run
run()

View file

@ -71,6 +71,7 @@ setuptools.setup(
scripts=[
"scripts/agent-manager-react",
"scripts/api-gateway",
"scripts/rev-gateway",
"scripts/chunker-recursive",
"scripts/chunker-token",
"scripts/config-svc",
@ -118,6 +119,7 @@ setuptools.setup(
"scripts/text-completion-ollama",
"scripts/text-completion-openai",
"scripts/text-completion-tgi",
"scripts/text-completion-vllm",
"scripts/triples-query-cassandra",
"scripts/triples-query-falkordb",
"scripts/triples-query-memgraph",

View file

@ -21,16 +21,19 @@ RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
default_ident = "kg-extract-definitions"
default_concurrency = 1
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
concurrency = params.get("concurrency", 1)
super(Processor, self).__init__(
**params | {
"id": id,
"concurrency": concurrency,
}
)
@ -38,7 +41,8 @@ class Processor(FlowProcessor):
ConsumerSpec(
name = "input",
schema = Chunk,
handler = self.on_message
handler = self.on_message,
concurrency = concurrency,
)
)
@ -190,6 +194,13 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)
def run():

View file

@ -20,16 +20,19 @@ RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
default_ident = "kg-extract-relationships"
default_concurrency = 1
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
concurrency = params.get("concurrency", 1)
super(Processor, self).__init__(
**params | {
"id": id,
"concurrency": concurrency,
}
)
@ -37,7 +40,8 @@ class Processor(FlowProcessor):
ConsumerSpec(
name = "input",
schema = Chunk,
handler = self.on_message
handler = self.on_message,
concurrency = concurrency,
)
)
@ -192,6 +196,13 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)
def run():

View file

@ -1,5 +1,6 @@
from ... schema import AgentRequest, AgentResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
@ -20,24 +21,12 @@ class AgentRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("agent")
self.response_translator = TranslatorRegistry.get_response_translator("agent")
def to_request(self, body):
return AgentRequest(
question=body["question"]
)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
resp = {
}
if message.answer:
resp["answer"] = message.answer
if message.thought:
resp["thought"] = message.thought
if message.observation:
resp["observation"] = message.observation
# The 2nd boolean expression indicates whether we're done responding
return resp, (message.answer is not None)
return self.response_translator.from_response_with_completion(message)

View file

@ -2,6 +2,7 @@
from ... schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
from ... schema import config_request_queue
from ... schema import config_response_queue
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
@ -19,60 +20,12 @@ class ConfigRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("config")
self.response_translator = TranslatorRegistry.get_response_translator("config")
def to_request(self, body):
if "keys" in body:
keys = [
ConfigKey(
type = k["type"],
key = k["key"],
)
for k in body["keys"]
]
else:
keys = None
if "values" in body:
values = [
ConfigValue(
type = v["type"],
key = v["key"],
value = v["value"],
)
for v in body["values"]
]
else:
values = None
return ConfigRequest(
operation = body.get("operation", None),
keys = keys,
type = body.get("type", None),
values = values
)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
response = { }
if message.version is not None:
response["version"] = message.version
if message.values is not None:
response["values"] = [
{
"type": v.type,
"key": v.key,
"value": v.value,
}
for v in message.values
]
if message.directory is not None:
response["directory"] = message.directory
if message.config is not None:
response["config"] = message.config
return response, True
return self.response_translator.from_response_with_completion(message)

View file

@ -0,0 +1,96 @@
import asyncio
import uuid
import msgpack
from . knowledge import KnowledgeRequestor
class CoreExport:
    """Streams a stored knowledge core (triples plus graph embeddings)
    out of the knowledge service as msgpack-encoded records over HTTP."""
    def __init__(self, pulsar_client):
        # Shared Pulsar client; a fresh requestor is made per export.
        self.pulsar_client = pulsar_client
    async def process(self, data, error, ok, request):
        """Handle an export request.

        Reads "id" and "user" from the query string, issues a
        get-kg-core operation, and relays each streamed message to the
        response body as a msgpack 2-tuple: ("ge", {...}) for graph
        embeddings, ("t", {...}) for triples, using abbreviated keys
        (m=metadata with i/m/u/c, e=entities with e/v, t=triples).
        """
        id = request.query["id"]
        user = request.query["user"]
        # Open the streaming HTTP response before any data arrives.
        response = await ok()
        # Unique consumer/subscriber names so concurrent exports do not
        # share subscriptions.
        kr = KnowledgeRequestor(
            pulsar_client = self.pulsar_client,
            consumer = "api-gateway-core-export-" + str(uuid.uuid4()),
            subscriber = "api-gateway-core-export-" + str(uuid.uuid4()),
        )
        try:
            await kr.start()
            async def responder(resp, fin):
                # Invoked for every message streamed back by the service.
                if "graph-embeddings" in resp:
                    data = resp["graph-embeddings"]
                    msg = (
                        "ge",
                        {
                            "m": {
                                "i": data["metadata"]["id"],
                                "m": data["metadata"]["metadata"],
                                "u": data["metadata"]["user"],
                                "c": data["metadata"]["collection"],
                            },
                            "e": [
                                {
                                    "e": ent["entity"],
                                    "v": ent["vectors"],
                                }
                                for ent in data["entities"]
                            ]
                        }
                    )
                    enc = msgpack.packb(msg)
                    await response.write(enc)
                if "triples" in resp:
                    data = resp["triples"]
                    msg = (
                        "t",
                        {
                            "m": {
                                "i": data["metadata"]["id"],
                                "m": data["metadata"]["metadata"],
                                "u": data["metadata"]["user"],
                                "c": data["metadata"]["collection"],
                            },
                            "t": data["triples"],
                        }
                    )
                    enc = msgpack.packb(msg)
                    await response.write(enc)
            await kr.process(
                {
                    "operation": "get-kg-core",
                    "user": user,
                    "id": id,
                },
                responder
            )
        except Exception as e:
            # Best-effort: report and fall through to close the stream.
            print("Exception:", e)
        finally:
            await kr.stop()
        # Terminate the chunked response whether or not an error occurred.
        await response.write_eof()
        return response

View file

@ -0,0 +1,94 @@
import asyncio
import json
import uuid
import msgpack
from . knowledge import KnowledgeRequestor
class CoreImport:
    """Receives a msgpack stream of knowledge-core records over HTTP and
    loads them into the knowledge service via put-kg-core operations."""
    def __init__(self, pulsar_client):
        # Shared Pulsar client; a fresh requestor is made per import.
        self.pulsar_client = pulsar_client
    async def process(self, data, error, ok, request):
        """Handle an import request.

        Reads "id" and "user" from the query string, then incrementally
        unpacks ("t", {...}) triples and ("ge", {...}) graph-embeddings
        records from the request body (the inverse of CoreExport's
        abbreviated wire format) and forwards each as a put-kg-core
        message.
        """
        id = request.query["id"]
        user = request.query["user"]
        # Unique consumer/subscriber names so concurrent imports do not
        # share subscriptions.
        kr = KnowledgeRequestor(
            pulsar_client = self.pulsar_client,
            consumer = "api-gateway-core-import-" + str(uuid.uuid4()),
            subscriber = "api-gateway-core-import-" + str(uuid.uuid4()),
        )
        await kr.start()
        try:
            # Streaming decode: feed request-body chunks to the unpacker
            # and process each complete record as soon as it is available.
            unpacker = msgpack.Unpacker()
            while True:
                buf = await data.read(128*1024)
                if not buf: break
                unpacker.feed(buf)
                for unpacked in unpacker:
                    if unpacked[0] == "t":
                        # Triples record: expand abbreviated keys back
                        # into the full put-kg-core message shape.
                        msg = unpacked[1]
                        msg = {
                            "operation": "put-kg-core",
                            "user": user,
                            "id": id,
                            "triples": {
                                "metadata": {
                                    "id": id,
                                    "metadata": msg["m"]["m"],
                                    "user": user,
                                    "collection": "default", # Not used?
                                },
                                "triples": msg["t"],
                            }
                        }
                        await kr.process(msg)
                    elif unpacked[0] == "ge":
                        # Graph-embeddings record: entity/vector pairs.
                        msg = unpacked[1]
                        msg = {
                            "operation": "put-kg-core",
                            "user": user,
                            "id": id,
                            "graph-embeddings": {
                                "metadata": {
                                    "id": id,
                                    "metadata": msg["m"]["m"],
                                    "user": user,
                                    "collection": "default", # Not used?
                                },
                                "entities": [
                                    {
                                        "entity": ent["e"],
                                        "vectors": ent["v"],
                                    }
                                    for ent in msg["e"]
                                ]
                            }
                        }
                        await kr.process(msg)
        except Exception as e:
            # Report the failure to the client before closing.
            print("Exception:", e)
            await error(str(e))
        finally:
            await kr.stop()
        print("All done.")
        # Success path: empty OK response once the stream is consumed.
        response = await ok()
        await response.write_eof()
        return response

View file

@ -6,8 +6,7 @@ from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import DocumentEmbeddings, ChunkEmbeddings
from ... base import Publisher
from . serialize import to_subgraph
from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator
class DocumentEmbeddingsImport:
@ -17,6 +16,7 @@ class DocumentEmbeddingsImport:
self.ws = ws
self.running = running
self.translator = DocumentEmbeddingsTranslator()
self.publisher = Publisher(
pulsar_client, topic = queue, schema = DocumentEmbeddings
@ -36,23 +36,7 @@ class DocumentEmbeddingsImport:
async def receive(self, msg):
data = msg.json()
elt = DocumentEmbeddings(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
chunks=[
ChunkEmbeddings(
chunk=de["chunk"].encode("utf-8"),
vectors=de["vectors"],
)
for de in data["chunks"]
],
)
elt = self.translator.to_pulsar(data)
await self.publisher.send(None, elt)
async def run(self):

View file

@ -2,9 +2,9 @@
import base64
from ... schema import Document, Metadata
from ... messaging import TranslatorRegistry
from . sender import ServiceSender
from . serialize import to_subgraph
class DocumentLoad(ServiceSender):
def __init__(self, pulsar_client, queue):
@ -15,26 +15,9 @@ class DocumentLoad(ServiceSender):
schema = Document,
)
self.translator = TranslatorRegistry.get_request_translator("document")
def to_request(self, body):
if "metadata" in body:
metadata = to_subgraph(body["metadata"])
else:
metadata = []
# Doing a base64 decoe/encode here to make sure the
# content is valid base64
doc = base64.b64decode(body["data"])
print("Document received")
return Document(
metadata=Metadata(
id=body.get("id"),
metadata=metadata,
user=body.get("user", "trustgraph"),
collection=body.get("collection", "default"),
),
data=base64.b64encode(doc).decode("utf-8")
)
return self.translator.to_pulsar(body)

View file

@ -1,5 +1,6 @@
from ... schema import DocumentRagQuery, DocumentRagResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
@ -20,14 +21,12 @@ class DocumentRagRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("document-rag")
self.response_translator = TranslatorRegistry.get_response_translator("document-rag")
def to_request(self, body):
return DocumentRagQuery(
query=body["query"],
user=body.get("user", "trustgraph"),
collection=body.get("collection", "default"),
doc_limit=int(body.get("doc-limit", 20)),
)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return { "response": message.response }, True
return self.response_translator.from_response_with_completion(message)

View file

@ -1,5 +1,6 @@
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
@ -20,11 +21,12 @@ class EmbeddingsRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("embeddings")
self.response_translator = TranslatorRegistry.get_response_translator("embeddings")
def to_request(self, body):
return EmbeddingsRequest(
text=body["text"]
)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return { "vectors": message.vectors }, True
return self.response_translator.from_response_with_completion(message)

View file

@ -2,6 +2,7 @@
from ... schema import FlowRequest, FlowResponse
from ... schema import flow_request_queue
from ... schema import flow_response_queue
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
@ -19,34 +20,12 @@ class FlowRequestor(ServiceRequestor):
timeout=timeout,
)
def to_request(self, body):
self.request_translator = TranslatorRegistry.get_request_translator("flow")
self.response_translator = TranslatorRegistry.get_response_translator("flow")
return FlowRequest(
operation = body.get("operation", None),
class_name = body.get("class-name", None),
class_definition = body.get("class-definition", None),
description = body.get("description", None),
flow_id = body.get("flow-id", None),
)
def to_request(self, body):
return self.request_translator.to_pulsar(body)
def from_response(self, message):
response = { }
if message.class_names is not None:
response["class-names"] = message.class_names
if message.flow_ids is not None:
response["flow-ids"] = message.flow_ids
if message.class_definition is not None:
response["class-definition"] = message.class_definition
if message.flow is not None:
response["flow"] = message.flow
if message.description is not None:
response["description"] = message.description
return response, True
return self.response_translator.from_response_with_completion(message)

View file

@ -1,8 +1,8 @@
from ... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
from . serialize import serialize_value
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
def __init__(
@ -21,22 +21,12 @@ class GraphEmbeddingsQueryRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("graph-embeddings-query")
self.response_translator = TranslatorRegistry.get_response_translator("graph-embeddings-query")
def to_request(self, body):
limit = int(body.get("limit", 20))
return GraphEmbeddingsRequest(
vectors = body["vectors"],
limit = limit,
user = body.get("user", "trustgraph"),
collection = body.get("collection", "default"),
)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return {
"entities": [
serialize_value(ent) for ent in message.entities
]
}, True
return self.response_translator.from_response_with_completion(message)

View file

@ -1,5 +1,6 @@
from ... schema import GraphRagQuery, GraphRagResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
@ -20,17 +21,12 @@ class GraphRagRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("graph-rag")
self.response_translator = TranslatorRegistry.get_response_translator("graph-rag")
def to_request(self, body):
return GraphRagQuery(
query=body["query"],
user=body.get("user", "trustgraph"),
collection=body.get("collection", "default"),
entity_limit=int(body.get("entity-limit", 50)),
triple_limit=int(body.get("triple-limit", 30)),
max_subgraph_size=int(body.get("max-subgraph-size", 1000)),
max_path_length=int(body.get("max-path-length", 2)),
)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return { "response": message.response }, True
return self.response_translator.from_response_with_completion(message)

View file

@ -5,11 +5,9 @@ from ... schema import KnowledgeRequest, KnowledgeResponse, Triples
from ... schema import GraphEmbeddings, Metadata, EntityEmbeddings
from ... schema import knowledge_request_queue
from ... schema import knowledge_response_queue
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
from . serialize import serialize_graph_embeddings
from . serialize import serialize_triples, to_subgraph, to_value
from . serialize import to_document_metadata, to_processing_metadata
class KnowledgeRequestor(ServiceRequestor):
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
@ -25,73 +23,12 @@ class KnowledgeRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("knowledge")
self.response_translator = TranslatorRegistry.get_response_translator("knowledge")
def to_request(self, body):
if "triples" in body:
triples = Triples(
metadata=Metadata(
id = body["triples"]["metadata"]["id"],
metadata = to_subgraph(body["triples"]["metadata"]["metadata"]),
user = body["triples"]["metadata"]["user"],
),
triples = to_subgraph(body["triples"]["triples"]),
)
else:
triples = None
if "graph-embeddings" in body:
ge = GraphEmbeddings(
metadata = Metadata(
id = body["graph-embeddings"]["metadata"]["id"],
metadata = to_subgraph(body["graph-embeddings"]["metadata"]["metadata"]),
user = body["graph-embeddings"]["metadata"]["user"],
),
entities=[
EntityEmbeddings(
entity = to_value(ent["entity"]),
vectors = ent["vectors"],
)
for ent in body["graph-embeddings"]["entities"]
]
)
else:
ge = None
return KnowledgeRequest(
operation = body.get("operation", None),
user = body.get("user", None),
id = body.get("id", None),
flow = body.get("flow", None),
collection = body.get("collection", None),
triples = triples,
graph_embeddings = ge,
)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
# Response to list,
if message.ids is not None:
return {
"ids": message.ids
}, True
if message.triples:
return {
"triples": serialize_triples(message.triples)
}, False
if message.graph_embeddings:
return {
"graph-embeddings": serialize_graph_embeddings(
message.graph_embeddings
)
}, False
if message.eos is True:
return {
"eos": True
}, True
# Empty case, return from successful delete.
return {}, True
return self.response_translator.from_response_with_completion(message)

View file

@ -4,12 +4,9 @@ import base64
from ... schema import LibrarianRequest, LibrarianResponse
from ... schema import librarian_request_queue
from ... schema import librarian_response_queue
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
from . serialize import serialize_document_metadata
from . serialize import serialize_processing_metadata
from . serialize import to_document_metadata, to_processing_metadata
from . serialize import to_criteria
class LibrarianRequestor(ServiceRequestor):
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
@ -25,67 +22,20 @@ class LibrarianRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("librarian")
self.response_translator = TranslatorRegistry.get_response_translator("librarian")
def to_request(self, body):
# Content gets base64 decoded & encoded again. It at least makes
# sure payload is valid base64.
if "document-metadata" in body:
dm = to_document_metadata(body["document-metadata"])
else:
dm = None
if "processing-metadata" in body:
pm = to_processing_metadata(body["processing-metadata"])
else:
pm = None
if "criteria" in body:
criteria = to_criteria(body["criteria"])
else:
criteria = None
# Handle base64 content processing
if "content" in body:
# Content gets base64 decoded & encoded again to ensure valid base64
content = base64.b64decode(body["content"].encode("utf-8"))
content = base64.b64encode(content).decode("utf-8")
else:
content = None
return LibrarianRequest(
operation = body.get("operation", None),
document_id = body.get("document-id", None),
processing_id = body.get("processing-id", None),
document_metadata = dm,
processing_metadata = pm,
content = content,
user = body.get("user", None),
collection = body.get("collection", None),
criteria = criteria,
)
body = body.copy()
body["content"] = content
return self.request_translator.to_pulsar(body)
def from_response(self, message):
response = {}
if message.document_metadata:
response["document-metadata"] = serialize_document_metadata(
message.document_metadata
)
if message.content:
response["content"] = message.content.decode("utf-8")
if message.document_metadatas != None:
response["document-metadatas"] = [
serialize_document_metadata(v)
for v in message.document_metadatas
]
if message.processing_metadatas != None:
response["processing-metadatas"] = [
serialize_processing_metadata(v)
for v in message.processing_metadatas
]
return response, True
return self.response_translator.from_response_with_completion(message)

View file

@ -1,5 +1,6 @@
import asyncio
from aiohttp import web
import uuid
from . config import ConfigRequestor
@ -30,6 +31,9 @@ from . graph_embeddings_import import GraphEmbeddingsImport
from . document_embeddings_import import DocumentEmbeddingsImport
from . entity_contexts_import import EntityContextsImport
from . core_export import CoreExport
from . core_import import CoreImport
from . mux import Mux
request_response_dispatchers = {
@ -77,10 +81,11 @@ class DispatcherWrapper:
class DispatcherManager:
def __init__(self, pulsar_client, config_receiver):
def __init__(self, pulsar_client, config_receiver, prefix="api-gateway"):
self.pulsar_client = pulsar_client
self.config_receiver = config_receiver
self.config_receiver.add_handler(self)
self.prefix = prefix
self.flows = {}
self.dispatchers = {}
@ -98,6 +103,22 @@ class DispatcherManager:
def dispatch_global_service(self):
return DispatcherWrapper(self.process_global_service)
def dispatch_core_export(self):
return DispatcherWrapper(self.process_core_export)
def dispatch_core_import(self):
return DispatcherWrapper(self.process_core_import)
async def process_core_import(self, data, error, ok, request):
ci = CoreImport(self.pulsar_client)
return await ci.process(data, error, ok, request)
async def process_core_export(self, data, error, ok, request):
ce = CoreExport(self.pulsar_client)
return await ce.process(data, error, ok, request)
async def process_global_service(self, data, responder, params):
kind = params.get("kind")
@ -113,8 +134,8 @@ class DispatcherManager:
dispatcher = global_dispatchers[kind](
pulsar_client = self.pulsar_client,
timeout = 120,
consumer = f"api-gateway-{kind}-request",
subscriber = f"api-gateway-{kind}-request",
consumer = f"{self.prefix}-{kind}-request",
subscriber = f"{self.prefix}-{kind}-request",
)
await dispatcher.start()
@ -206,8 +227,8 @@ class DispatcherManager:
ws = ws,
running = running,
queue = qconfig,
consumer = f"api-gateway-{id}",
subscriber = f"api-gateway-{id}",
consumer = f"{self.prefix}-{id}",
subscriber = f"{self.prefix}-{id}",
)
return dispatcher
@ -248,8 +269,8 @@ class DispatcherManager:
request_queue = qconfig["request"],
response_queue = qconfig["response"],
timeout = 120,
consumer = f"api-gateway-{flow}-{kind}-request",
subscriber = f"api-gateway-{flow}-{kind}-request",
consumer = f"{self.prefix}-{flow}-{kind}-request",
subscriber = f"{self.prefix}-{flow}-{kind}-request",
)
elif kind in sender_dispatchers:
dispatcher = sender_dispatchers[kind](

View file

@ -2,6 +2,7 @@
import json
from ... schema import PromptRequest, PromptResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
@ -22,22 +23,12 @@ class PromptRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("prompt")
self.response_translator = TranslatorRegistry.get_response_translator("prompt")
def to_request(self, body):
return PromptRequest(
id=body["id"],
terms={
k: json.dumps(v)
for k, v in body["variables"].items()
}
)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
if message.object:
return {
"object": message.object
}, True
else:
return {
"text": message.text
}, True
return self.response_translator.from_response_with_completion(message)

View file

@ -3,6 +3,13 @@ import base64
from ... schema import Value, Triple, DocumentMetadata, ProcessingMetadata
# DEPRECATED: These functions have been moved to trustgraph.messaging.translators
# Use the new messaging translation system instead for consistency and reusability.
# Examples:
#   from trustgraph.messaging.translators.primitives import ValueTranslator
#   value_translator = ValueTranslator()
#   pulsar_value = value_translator.to_pulsar({"v": "example", "e": True})
def to_value(x):
    # Build a Value schema object from its wire-dict form:
    # "v" carries the raw value, "e" flags whether it is a URI.
    return Value(value=x["v"], is_uri=x["e"])

View file

@ -1,5 +1,6 @@
from ... schema import TextCompletionRequest, TextCompletionResponse
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
@ -20,12 +21,12 @@ class TextCompletionRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("text-completion")
self.response_translator = TranslatorRegistry.get_response_translator("text-completion")
def to_request(self, body):
return TextCompletionRequest(
system=body["system"],
prompt=body["prompt"]
)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return { "response": message.response }, True
return self.response_translator.from_response_with_completion(message)

View file

@ -2,9 +2,9 @@
import base64
from ... schema import TextDocument, Metadata
from ... messaging import TranslatorRegistry
from . sender import ServiceSender
from . serialize import to_subgraph
class TextLoad(ServiceSender):
def __init__(self, pulsar_client, queue):
@ -15,30 +15,9 @@ class TextLoad(ServiceSender):
schema = TextDocument,
)
self.translator = TranslatorRegistry.get_request_translator("text-document")
def to_request(self, body):
if "metadata" in body:
metadata = to_subgraph(body["metadata"])
else:
metadata = []
if "charset" in body:
charset = body["charset"]
else:
charset = "utf-8"
# Text is base64 encoded
text = base64.b64decode(body["text"]).decode(charset)
print("Text document received")
return TextDocument(
metadata=Metadata(
id=body.get("id"),
metadata=metadata,
user=body.get("user", "trustgraph"),
collection=body.get("collection", "default"),
),
text=text,
)
return self.translator.to_pulsar(body)

View file

@ -1,8 +1,8 @@
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
from ... messaging import TranslatorRegistry
from . requestor import ServiceRequestor
from . serialize import to_value, serialize_subgraph
class TriplesQueryRequestor(ServiceRequestor):
def __init__(
@ -21,34 +21,12 @@ class TriplesQueryRequestor(ServiceRequestor):
timeout=timeout,
)
self.request_translator = TranslatorRegistry.get_request_translator("triples-query")
self.response_translator = TranslatorRegistry.get_response_translator("triples-query")
def to_request(self, body):
if "s" in body:
s = to_value(body["s"])
else:
s = None
if "p" in body:
p = to_value(body["p"])
else:
p = None
if "o" in body:
o = to_value(body["o"])
else:
o = None
limit = int(body.get("limit", 10000))
return TriplesQueryRequest(
s = s, p = p, o = o,
limit = limit,
user = body.get("user", "trustgraph"),
collection = body.get("collection", "default"),
)
return self.request_translator.to_pulsar(body)
def from_response(self, message):
return {
"response": serialize_subgraph(message.triples)
}, True
return self.response_translator.from_response_with_completion(message)

View file

@ -3,7 +3,7 @@ import asyncio
from aiohttp import web
from . constant_endpoint import ConstantEndpoint
from . stream_endpoint import StreamEndpoint
from . variable_endpoint import VariableEndpoint
from . socket import SocketEndpoint
from . metrics import MetricsEndpoint
@ -52,6 +52,18 @@ class EndpointManager:
auth = auth,
dispatcher = dispatcher_manager.dispatch_flow_export()
),
StreamEndpoint(
endpoint_path = "/api/v1/import-core",
auth = auth,
method = "POST",
dispatcher = dispatcher_manager.dispatch_core_import(),
),
StreamEndpoint(
endpoint_path = "/api/v1/export-core",
auth = auth,
method = "GET",
dispatcher = dispatcher_manager.dispatch_core_export(),
),
]
def add_routes(self, app):

View file

@ -0,0 +1,82 @@
import asyncio
from aiohttp import web
import logging
logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO)
class StreamEndpoint:
    """
    HTTP endpoint that streams request/response bodies straight through to
    a dispatcher (used for bulk core import/export).  Unlike the JSON
    endpoints, the dispatcher is handed the raw aiohttp content stream.
    """

    def __init__(self, endpoint_path, auth, dispatcher, method="POST"):
        self.path = endpoint_path
        self.auth = auth
        # Stream endpoints are all authorised under the generic "service" op
        self.operation = "service"
        self.method = method
        self.dispatcher = dispatcher

    async def start(self):
        # Nothing to start; present for endpoint-interface symmetry
        pass

    def add_routes(self, app):
        # Register the handler for the configured verb; fail loudly on
        # anything other than POST/GET so misconfiguration is caught early.
        if self.method == "POST":
            app.add_routes([
                web.post(self.path, self.handle),
            ])
        elif self.method == "GET":
            app.add_routes([
                web.get(self.path, self.handle),
            ])
        else:
            raise RuntimeError("Bad method: " + self.method)

    async def handle(self, request):

        logger.info("%s ...", request.path)

        # Extract the bearer token.  A missing or malformed Authorization
        # header yields an empty token, which auth rejects if auth is on.
        try:
            ht = request.headers["Authorization"]
            tokens = ht.split(" ", 2)
            if tokens[0] != "Bearer":
                return web.HTTPUnauthorized()
            token = tokens[1]
        except (KeyError, IndexError):
            token = ""

        if not self.auth.permitted(token, self.operation):
            return web.HTTPUnauthorized()

        try:

            data = request.content

            async def error(err):
                return web.HTTPInternalServerError(text = err)

            async def ok(
                status=200, reason="OK", type="application/octet-stream"
            ):
                # Prepare a streaming response the dispatcher writes into
                response = web.StreamResponse(
                    status = status, reason = reason,
                    headers = {"Content-Type": type}
                )
                await response.prepare(request)
                return response

            resp = await self.dispatcher.process(
                data, error, ok, request
            )

            return resp

        except Exception as e:
            # Use the module logger (not the root logger) for consistency
            logger.error(f"Exception: {e}")
            return web.json_response(
                { "error": str(e) }
            )

View file

@ -1,7 +1,6 @@
import asyncio
from aiohttp import web
import uuid
import logging
logger = logging.getLogger("endpoint")

View file

@ -73,6 +73,7 @@ class Api:
self.dispatcher_manager = DispatcherManager(
pulsar_client = self.pulsar_client,
config_receiver = self.config_receiver,
prefix = "gateway",
)
self.endpoint_manager = EndpointManager(

View file

@ -18,12 +18,14 @@ from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
default_ident = "prompt"
default_concurrency = 1
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id")
concurrency = params.get("concurrency", 1)
# Config key for prompts
self.config_key = params.get("config_type", "prompt")
@ -31,6 +33,7 @@ class Processor(FlowProcessor):
super(Processor, self).__init__(
**params | {
"id": id,
"concurrency": concurrency,
}
)
@ -38,7 +41,8 @@ class Processor(FlowProcessor):
ConsumerSpec(
name = "request",
schema = PromptRequest,
handler = self.on_request
handler = self.on_request,
concurrency = concurrency,
)
)
@ -219,6 +223,13 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)
parser.add_argument(

View file

@ -0,0 +1,3 @@
from . llm import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . llm import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,138 @@
"""
Simple LLM service, performs text prompt completion using vLLM
Input is prompt, output is response.
"""
import os
import aiohttp
from .... exceptions import TooManyRequests
from .... base import LlmService, LlmResult
# Service identity and tunable defaults.  VLLM_BASE_URL overrides the
# in-cluster default endpoint when set to a non-empty value; an empty
# string and an unset variable are both treated as "not configured".
default_ident = "text-completion"

default_temperature = 0.0
default_max_output = 2048
default_model = "TheBloke/Mistral-7B-v0.1-AWQ"

default_base_url = os.getenv("VLLM_BASE_URL") or "http://vllm-service:8899/v1"
class Processor(LlmService):
    """
    Text completion service backed by an OpenAI-compatible vLLM endpoint.
    Concatenates system + prompt into one completions request and returns
    an LlmResult carrying the answer text and token usage.
    """

    def __init__(self, **params):

        base_url = params.get("url", default_base_url)
        temperature = params.get("temperature", default_temperature)
        max_output = params.get("max_output", default_max_output)
        model = params.get("model", default_model)

        super(Processor, self).__init__(
            **params | {
                "temperature": temperature,
                "max_output": max_output,
                "url": base_url,
                "model": model,
            }
        )

        self.base_url = base_url
        self.temperature = temperature
        self.max_output = max_output
        self.model = model

        # NOTE(review): session is created here, outside async code; aiohttp
        # prefers ClientSession construction inside a running event loop —
        # confirm the launcher initialises processors from async context.
        self.session = aiohttp.ClientSession()

        print("Using vLLM service at", base_url)
        print("Initialised", flush=True)

    async def generate_content(self, system, prompt):
        """
        Perform one completion round-trip against the vLLM server.
        Returns an LlmResult on success; any network/protocol error is
        logged and re-raised as unrecoverable.
        """

        headers = {
            "Content-Type": "application/json",
        }

        # vLLM exposes the OpenAI "completions" API: system and user text
        # are joined into a single prompt string.
        request = {
            "model": self.model,
            "prompt": system + "\n\n" + prompt,
            "max_tokens": self.max_output,
            "temperature": self.temperature,
        }

        try:

            url = f"{self.base_url}/completions"

            async with self.session.post(
                url,
                headers=headers,
                json=request,
            ) as response:

                if response.status != 200:
                    raise RuntimeError("Bad status: " + str(response.status))

                resp = await response.json()

                inputtokens = resp["usage"]["prompt_tokens"]
                outputtokens = resp["usage"]["completion_tokens"]
                ans = resp["choices"][0]["text"]

                print(f"Input Tokens: {inputtokens}", flush=True)
                print(f"Output Tokens: {outputtokens}", flush=True)
                print(ans, flush=True)

                resp = LlmResult(
                    text = ans,
                    in_token = inputtokens,
                    out_token = outputtokens,
                    model = self.model,
                )

                return resp

        # FIXME: Assuming vLLM won't produce rate limits?

        except Exception as e:
            # Apart from rate limits, treat all exceptions as unrecoverable.
            # Bare raise preserves the original traceback.
            print(f"Exception: {type(e)} {e}")
            raise

    @staticmethod
    def add_args(parser):

        LlmService.add_args(parser)

        parser.add_argument(
            '-u', '--url',
            default=default_base_url,
            help=f'vLLM service base URL (default: {default_base_url})'
        )

        parser.add_argument(
            '-m', '--model',
            default=default_model,
            help=f'LLM model (default: {default_model})'
        )

        parser.add_argument(
            '-t', '--temperature',
            type=float,
            default=default_temperature,
            help=f'LLM temperature parameter (default: {default_temperature})'
        )

        parser.add_argument(
            '-x', '--max-output',
            type=int,
            default=default_max_output,
            help=f'LLM max output tokens (default: {default_max_output})'
        )
def run():
    # Module entry point: launch the processor under its default identity,
    # using the module docstring as the service description.
    Processor.launch(default_ident, __doc__)

View file

@ -11,12 +11,14 @@ from ... base import PromptClientSpec, EmbeddingsClientSpec
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
default_ident = "graph-rag"
default_concurrency = 1
class Processor(FlowProcessor):
def __init__(self, **params):
id = params.get("id", default_ident)
concurrency = params.get("concurrency", 1)
entity_limit = params.get("entity_limit", 50)
triple_limit = params.get("triple_limit", 30)
@ -26,6 +28,7 @@ class Processor(FlowProcessor):
super(Processor, self).__init__(
**params | {
"id": id,
"concurrency": concurrency,
"entity_limit": entity_limit,
"triple_limit": triple_limit,
"max_subgraph_size": max_subgraph_size,
@ -43,6 +46,7 @@ class Processor(FlowProcessor):
name = "request",
schema = GraphRagQuery,
handler = self.on_request,
concurrency = concurrency,
)
)
@ -157,6 +161,13 @@ class Processor(FlowProcessor):
@staticmethod
def add_args(parser):
parser.add_argument(
'-c', '--concurrency',
type=int,
default=default_concurrency,
help=f'Concurrent processing threads (default: {default_concurrency})'
)
FlowProcessor.add_args(parser)
parser.add_argument(

View file

@ -0,0 +1 @@
from . service import run

View file

@ -0,0 +1,11 @@
import logging
from .service import run
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
if __name__ == "__main__":
run()

View file

@ -0,0 +1,130 @@
import asyncio
import logging
import uuid
from typing import Dict, Any, Optional
from trustgraph.messaging import TranslatorRegistry
from ..gateway.dispatch.manager import DispatcherManager
logger = logging.getLogger("dispatcher")
logger.setLevel(logging.INFO)
class WebSocketResponder:
    """Captures a single response payload so it can be returned over a websocket."""

    def __init__(self):
        self.response = None
        self.completed = False

    async def send(self, data):
        """Record *data* as the captured response and mark completion."""
        self.response, self.completed = data, True

    async def __call__(self, data, final=False):
        """Callable form for requestor compatibility; delegates to send()."""
        await self.send(data)
        self.completed = self.completed or final
class MessageDispatcher:
    """
    Bridges inbound websocket messages onto TrustGraph services.  Concurrency
    is bounded by a semaphore sized by max_workers; the actual routing is
    delegated to a DispatcherManager when a pulsar client and config receiver
    are available, otherwise requests fail with an in-band error.
    """

    def __init__(self, max_workers: int = 10, config_receiver=None, pulsar_client=None):
        self.max_workers = max_workers
        # Caps the number of messages being processed at any one time
        self.semaphore = asyncio.Semaphore(max_workers)
        # Tracks in-flight tasks so shutdown() can drain them
        self.active_tasks = set()
        self.pulsar_client = pulsar_client

        # Use DispatcherManager for flow and service management
        if pulsar_client and config_receiver:
            self.dispatcher_manager = DispatcherManager(pulsar_client, config_receiver, prefix="rev-gateway")
        else:
            self.dispatcher_manager = None
            logger.warning("No pulsar_client or config_receiver provided - using fallback mode")

        # Service name mapping from websocket protocol to translator registry
        self.service_mapping = {
            "text-completion": "text-completion",
            "graph-rag": "graph-rag",
            "agent": "agent",
            "embeddings": "embeddings",
            "graph-embeddings": "graph-embeddings",
            "triples": "triples",
            "document-load": "document",
            "text-load": "text-document",
            "flow": "flow",
            "knowledge": "knowledge",
            "config": "config",
            "librarian": "librarian",
            "document-rag": "document-rag"
        }

    async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
        """Process one websocket message under the concurrency limit."""
        async with self.semaphore:
            task = asyncio.create_task(self._process_message(message))
            self.active_tasks.add(task)
            try:
                result = await task
                return result
            finally:
                # Drop the task whether it succeeded or raised
                self.active_tasks.discard(task)

    async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]:
        """
        Route a single message to the appropriate dispatcher and wrap the
        outcome as {"id": ..., "response": ...}; errors are reported in-band
        rather than raised, so the websocket always gets a reply.
        """
        request_id = message.get('id', str(uuid.uuid4()))
        service = message.get('service')
        request_data = message.get('request', {})
        flow_id = message.get('flow', 'default')  # Default flow

        logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")

        try:
            if not self.dispatcher_manager:
                raise RuntimeError("DispatcherManager not available - pulsar_client and config_receiver required")

            # Use DispatcherManager for flow-based processing
            responder = WebSocketResponder()

            # Map websocket service name to dispatcher service name
            dispatcher_service = self.service_mapping.get(service, service)

            # Check if this is a global service or flow service
            # (imported here to avoid a circular import at module load)
            from ..gateway.dispatch.manager import global_dispatchers
            if dispatcher_service in global_dispatchers:
                # Use global service dispatcher
                await self.dispatcher_manager.invoke_global_service(
                    request_data, responder, dispatcher_service
                )
            else:
                # Use DispatcherManager to process the request through Pulsar queues
                await self.dispatcher_manager.invoke_flow_service(
                    request_data, responder, flow_id, dispatcher_service
                )

            # Get the response from the responder.  NOTE(review): only the
            # last payload is captured — streaming multi-part responses
            # would be collapsed to their final message; confirm intended.
            if responder.completed:
                response_data = responder.response
            else:
                response_data = {'error': 'No response received'}

            response = {
                'id': request_id,
                'response': response_data
            }

        except Exception as e:
            logger.error(f"Error processing message {request_id}: {e}")
            response = {
                'id': request_id,
                'response': {'error': str(e)}
            }

        logger.info(f"Completed processing message {request_id}")
        return response

    async def shutdown(self):
        """Wait for all in-flight message tasks to finish."""
        if self.active_tasks:
            logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete")
            await asyncio.gather(*self.active_tasks, return_exceptions=True)

        # DispatcherManager handles its own cleanup
        logger.info("Dispatcher shutdown complete")

View file

@ -0,0 +1,242 @@
import asyncio
import argparse
import logging
import json
import sys
import os
from aiohttp import ClientSession, WSMsgType, ClientWebSocketResponse
from typing import Optional
from urllib.parse import urlparse, urlunparse
import pulsar
from .dispatcher import MessageDispatcher
from ..gateway.config.receiver import ConfigReceiver
logger = logging.getLogger("rev_gateway")
logger.setLevel(logging.INFO)
default_websocket = "ws://localhost:7650/out"
class ReverseGateway:
def __init__(self, websocket_uri: str = None, max_workers: int = 10,
             pulsar_host: str = None, pulsar_api_key: str = None,
             pulsar_listener: str = None):
    """
    Configure the reverse gateway: validate the websocket URI, create the
    Pulsar client, the config receiver and the message dispatcher.
    Environment fallbacks: WEBSOCKET_URI, PULSAR_HOST, PULSAR_API_KEY.
    Raises ValueError for a URI without ws/wss scheme or hostname.
    """
    # Set default WebSocket URI with environment variable support
    if websocket_uri is None:
        websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)

    # Parse and validate the WebSocket URI
    parsed_uri = urlparse(websocket_uri)
    if parsed_uri.scheme not in ('ws', 'wss'):
        raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}")
    if not parsed_uri.netloc:
        raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}")

    # Store parsed components for debugging/logging
    self.websocket_uri = websocket_uri
    self.host = parsed_uri.hostname
    self.port = parsed_uri.port
    self.scheme = parsed_uri.scheme
    self.path = parsed_uri.path or "/ws"

    # Construct the full URL (in case path was missing)
    if not parsed_uri.path:
        self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
    else:
        self.url = websocket_uri

    self.max_workers = max_workers
    self.ws: Optional[ClientWebSocketResponse] = None
    self.session: Optional[ClientSession] = None
    self.running = False
    # Seconds to wait between reconnection attempts
    self.reconnect_delay = 3.0

    # Pulsar configuration
    self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
    self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
    self.pulsar_listener = pulsar_listener

    # Initialize Pulsar client; token auth only when an API key is present
    if self.pulsar_api_key:
        self.pulsar_client = pulsar.Client(
            self.pulsar_host,
            listener_name=self.pulsar_listener,
            authentication=pulsar.AuthenticationToken(self.pulsar_api_key)
        )
    else:
        self.pulsar_client = pulsar.Client(
            self.pulsar_host,
            listener_name=self.pulsar_listener
        )

    # Initialize config receiver
    self.config_receiver = ConfigReceiver(self.pulsar_client)

    # Initialize dispatcher with config_receiver and pulsar_client - must be created after config_receiver
    self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.pulsar_client)
async def connect(self) -> bool:
    """
    Open (or reuse) the HTTP session and establish the websocket.
    Returns True on success, False on any failure (caller retries).
    """
    try:
        # Lazily create the session so it lives inside the event loop
        if self.session is None:
            self.session = ClientSession()

        logger.info(f"Connecting to {self.url}")
        self.ws = await self.session.ws_connect(self.url)
        logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}")
        return True

    except Exception as e:
        logger.error(f"Failed to connect to {self.url}: {e}")
        return False
async def disconnect(self):
if self.ws and not self.ws.closed:
await self.ws.close()
if self.session and not self.session.closed:
await self.session.close()
self.ws = None
self.session = None
async def send_message(self, message: dict):
if self.ws and not self.ws.closed:
try:
await self.ws.send_str(json.dumps(message))
except Exception as e:
logger.error(f"Failed to send message: {e}")
async def handle_message(self, message: str):
try:
print(f"Received: {message}", flush=True)
msg_data = json.loads(message)
response = await self.dispatcher.handle_message(msg_data)
if response:
await self.send_message(response)
except Exception as e:
logger.error(f"Error handling message: {e}")
async def listen(self):
while self.running and self.ws and not self.ws.closed:
try:
msg = await self.ws.receive()
if msg.type == WSMsgType.TEXT:
await self.handle_message(msg.data)
elif msg.type == WSMsgType.BINARY:
await self.handle_message(msg.data.decode('utf-8'))
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
logger.warning("WebSocket closed or error occurred")
break
except Exception as e:
logger.error(f"Error in listen loop: {e}")
break
async def run(self):
self.running = True
logger.info("Starting reverse gateway")
# Start config receiver
logger.info("Starting config receiver")
await self.config_receiver.start()
while self.running:
try:
if await self.connect():
await self.listen()
else:
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")
await self.disconnect()
if self.running:
await asyncio.sleep(self.reconnect_delay)
except KeyboardInterrupt:
logger.info("Shutdown requested")
break
except Exception as e:
logger.error(f"Unexpected error: {e}")
if self.running:
await asyncio.sleep(self.reconnect_delay)
await self.shutdown()
async def shutdown(self):
logger.info("Shutting down reverse gateway")
self.running = False
await self.dispatcher.shutdown()
await self.disconnect()
# Close Pulsar client
if hasattr(self, 'pulsar_client'):
self.pulsar_client.close()
def stop(self):
self.running = False
def parse_args():
    """Build and parse the command line for the reverse gateway CLI."""
    p = argparse.ArgumentParser(
        prog="reverse-gateway",
        description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge",
    )
    p.add_argument(
        '--websocket-uri', default=None,
        help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)',
    )
    p.add_argument(
        '--max-workers', type=int, default=10,
        help='Maximum concurrent message handlers (default: 10)',
    )
    p.add_argument(
        '-p', '--pulsar-host', default=None,
        help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)',
    )
    p.add_argument(
        '--pulsar-api-key', default=None,
        help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)',
    )
    p.add_argument(
        '--pulsar-listener', default=None,
        help='Pulsar listener name',
    )
    return p.parse_args()
def run():
    """CLI entry point: parse arguments, build the gateway, and run it.

    Exits with status 1 on any fatal error; Ctrl-C exits cleanly.
    """
    args = parse_args()
    gateway = ReverseGateway(
        websocket_uri=args.websocket_uri,
        max_workers=args.max_workers,
        pulsar_host=args.pulsar_host,
        pulsar_api_key=args.pulsar_api_key,
        pulsar_listener=args.pulsar_listener
    )
    # Echo the effective configuration before entering the event loop.
    print("Starting reverse gateway:")
    print(f"  WebSocket URI: {gateway.url}")
    print(f"  Max workers: {args.max_workers}")
    print(f"  Pulsar host: {gateway.pulsar_host}")
    try:
        asyncio.run(gateway.run())
    except KeyboardInterrupt:
        print("\nShutdown requested by user")
    except Exception as e:
        # Report fatal errors on stderr so they are not lost when stdout
        # is redirected, then signal failure to the caller.
        print(f"Fatal error: {e}", file=sys.stderr)
        sys.exit(1)