mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Merge branch 'release/v1.0'
This commit is contained in:
commit
dbe78ebe46
66 changed files with 2706 additions and 528 deletions
|
|
@ -9,6 +9,7 @@ FROM docker.io/fedora:42 AS base
|
|||
ENV PIP_BREAK_SYSTEM_PACKAGES=1
|
||||
|
||||
RUN dnf install -y python3.12 && \
|
||||
dnf install -y tesseract poppler-utils && \
|
||||
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
|
||||
python -m ensurepip --upgrade && \
|
||||
pip3 install --no-cache-dir wheel aiohttp && \
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from prometheus_client import start_http_server, Info
|
|||
|
||||
from .. schema import ConfigPush, config_push_queue
|
||||
from .. log_level import LogLevel
|
||||
from .. exceptions import TooManyRequests
|
||||
from . pubsub import PulsarClient
|
||||
from . producer import Producer
|
||||
from . consumer import Consumer
|
||||
|
|
|
|||
|
|
@ -1,4 +1,14 @@
|
|||
|
||||
# Consumer is similar to subscriber: It takes information from a queue
|
||||
# and passes on to a processor function. This is the main receiving
|
||||
# loop for TrustGraph processors. Incorporates retry functionality
|
||||
|
||||
# Note: there is a 'defect' in the system which is tolerated, althought
|
||||
# the processing handlers are async functions, ideally implementation
|
||||
# would use all async code. In practice if the processor only implements
|
||||
# one handler, and a single thread of concurrency, nothing too outrageous
|
||||
# will happen if synchronous / blocking code is used
|
||||
|
||||
from pulsar.schema import JsonSchema
|
||||
import pulsar
|
||||
import _pulsar
|
||||
|
|
@ -16,6 +26,7 @@ class Consumer:
|
|||
start_of_messages=False,
|
||||
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
|
||||
reconnect_time = 5,
|
||||
concurrency = 1, # Number of concurrent requests to handle
|
||||
):
|
||||
|
||||
self.taskgroup = taskgroup
|
||||
|
|
@ -34,7 +45,9 @@ class Consumer:
|
|||
self.start_of_messages = start_of_messages
|
||||
|
||||
self.running = True
|
||||
self.task = None
|
||||
self.consumer_task = None
|
||||
|
||||
self.concurrency = concurrency
|
||||
|
||||
self.metrics = metrics
|
||||
|
||||
|
|
@ -52,7 +65,11 @@ class Consumer:
|
|||
async def stop(self):
|
||||
|
||||
self.running = False
|
||||
await self.task
|
||||
|
||||
if self.consumer_task:
|
||||
await self.consumer_task
|
||||
|
||||
self.consumer_task = None
|
||||
|
||||
async def start(self):
|
||||
|
||||
|
|
@ -62,9 +79,9 @@ class Consumer:
|
|||
if self.metrics:
|
||||
self.metrics.state("stopped")
|
||||
|
||||
self.task = self.taskgroup.create_task(self.run())
|
||||
self.consumer_task = self.taskgroup.create_task(self.consumer_run())
|
||||
|
||||
async def run(self):
|
||||
async def consumer_run(self):
|
||||
|
||||
while self.running:
|
||||
|
||||
|
|
@ -102,7 +119,19 @@ class Consumer:
|
|||
|
||||
try:
|
||||
|
||||
await self.consume()
|
||||
print(
|
||||
"Starting", self.concurrency, "receiver threads",
|
||||
flush=True
|
||||
)
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
|
||||
tasks = []
|
||||
|
||||
for i in range(0, self.concurrency):
|
||||
tasks.append(
|
||||
tg.create_task(self.consume_from_queue())
|
||||
)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.state("stopped")
|
||||
|
|
@ -120,7 +149,7 @@ class Consumer:
|
|||
self.consumer.unsubscribe()
|
||||
self.consumer.close()
|
||||
|
||||
async def consume(self):
|
||||
async def consume_from_queue(self):
|
||||
|
||||
while self.running:
|
||||
|
||||
|
|
@ -134,71 +163,75 @@ class Consumer:
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
expiry = time.time() + self.rate_limit_timeout
|
||||
await self.handle_one_from_queue(msg)
|
||||
|
||||
# This loop is for retry on rate-limit / resource limits
|
||||
while self.running:
|
||||
async def handle_one_from_queue(self, msg):
|
||||
|
||||
if time.time() > expiry:
|
||||
expiry = time.time() + self.rate_limit_timeout
|
||||
|
||||
print("Gave up waiting for rate-limit retry", flush=True)
|
||||
# This loop is for retry on rate-limit / resource limits
|
||||
while self.running:
|
||||
|
||||
# Message failed to be processed, this causes it to
|
||||
# be retried
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
if time.time() > expiry:
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.process("error")
|
||||
print("Gave up waiting for rate-limit retry", flush=True)
|
||||
|
||||
# Break out of retry loop, processes next message
|
||||
break
|
||||
# Message failed to be processed, this causes it to
|
||||
# be retried
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
|
||||
try:
|
||||
if self.metrics:
|
||||
self.metrics.process("error")
|
||||
|
||||
print("Handle...", flush=True)
|
||||
# Break out of retry loop, processes next message
|
||||
break
|
||||
|
||||
if self.metrics:
|
||||
try:
|
||||
|
||||
with self.metrics.record_time():
|
||||
await self.handler(msg, self, self.flow)
|
||||
print("Handle...", flush=True)
|
||||
|
||||
else:
|
||||
if self.metrics:
|
||||
|
||||
with self.metrics.record_time():
|
||||
await self.handler(msg, self, self.flow)
|
||||
|
||||
print("Handled.", flush=True)
|
||||
else:
|
||||
await self.handler(msg, self, self.flow)
|
||||
|
||||
# Acknowledge successful processing of the message
|
||||
self.consumer.acknowledge(msg)
|
||||
print("Handled.", flush=True)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.process("success")
|
||||
# Acknowledge successful processing of the message
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
# Break out of retry loop
|
||||
break
|
||||
if self.metrics:
|
||||
self.metrics.process("success")
|
||||
|
||||
except TooManyRequests:
|
||||
# Break out of retry loop
|
||||
break
|
||||
|
||||
print("TooManyRequests: will retry...", flush=True)
|
||||
except TooManyRequests:
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.rate_limit()
|
||||
print("TooManyRequests: will retry...", flush=True)
|
||||
|
||||
# Sleep
|
||||
await asyncio.sleep(self.rate_limit_retry_time)
|
||||
if self.metrics:
|
||||
self.metrics.rate_limit()
|
||||
|
||||
# Contine from retry loop, just causes a reprocessing
|
||||
continue
|
||||
# Sleep
|
||||
await asyncio.sleep(self.rate_limit_retry_time)
|
||||
|
||||
except Exception as e:
|
||||
# Contine from retry loop, just causes a reprocessing
|
||||
continue
|
||||
|
||||
print("consume exception:", e, flush=True)
|
||||
except Exception as e:
|
||||
|
||||
# Message failed to be processed, this causes it to
|
||||
# be retried
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
print("consume exception:", e, flush=True)
|
||||
|
||||
if self.metrics:
|
||||
self.metrics.process("error")
|
||||
# Message failed to be processed, this causes it to
|
||||
# be retried
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
|
||||
# Break out of retry loop, processes next message
|
||||
break
|
||||
if self.metrics:
|
||||
self.metrics.process("error")
|
||||
|
||||
# Break out of retry loop, processes next message
|
||||
break
|
||||
|
|
|
|||
|
|
@ -4,10 +4,11 @@ from . consumer import Consumer
|
|||
from . spec import Spec
|
||||
|
||||
class ConsumerSpec(Spec):
|
||||
def __init__(self, name, schema, handler):
|
||||
def __init__(self, name, schema, handler, concurrency = 1):
|
||||
self.name = name
|
||||
self.schema = schema
|
||||
self.handler = handler
|
||||
self.concurrency = concurrency
|
||||
|
||||
def add(self, flow, processor, definition):
|
||||
|
||||
|
|
@ -24,6 +25,7 @@ class ConsumerSpec(Spec):
|
|||
schema = self.schema,
|
||||
handler = self.handler,
|
||||
metrics = consumer_metrics,
|
||||
concurrency = self.concurrency
|
||||
)
|
||||
|
||||
# Consumer handle gets access to producers and other
|
||||
|
|
|
|||
|
|
@ -11,20 +11,26 @@ from .. exceptions import TooManyRequests
|
|||
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
|
||||
default_ident = "embeddings"
|
||||
default_concurrency = 1
|
||||
|
||||
class EmbeddingsService(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
|
||||
super(EmbeddingsService, self).__init__(**params | { "id": id })
|
||||
super(EmbeddingsService, self).__init__(**params | {
|
||||
"id": id,
|
||||
"concurrency": concurrency,
|
||||
})
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = EmbeddingsRequest,
|
||||
handler = self.on_request
|
||||
handler = self.on_request,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -84,6 +90,13 @@ class EmbeddingsService(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,9 +11,13 @@ from .. exceptions import TooManyRequests
|
|||
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||
|
||||
default_ident = "text-completion"
|
||||
default_concurrency = 1
|
||||
|
||||
class LlmResult:
|
||||
def __init__(self, text=None, in_token=None, out_token=None, model=None):
|
||||
def __init__(
|
||||
self, text = None, in_token = None, out_token = None,
|
||||
model = None,
|
||||
):
|
||||
self.text = text
|
||||
self.in_token = in_token
|
||||
self.out_token = out_token
|
||||
|
|
@ -25,14 +29,19 @@ class LlmService(FlowProcessor):
|
|||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
|
||||
super(LlmService, self).__init__(**params | { "id": id })
|
||||
super(LlmService, self).__init__(**params | {
|
||||
"id": id,
|
||||
"concurrency": concurrency,
|
||||
})
|
||||
|
||||
self.register_specification(
|
||||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = TextCompletionRequest,
|
||||
handler = self.on_request
|
||||
handler = self.on_request,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -115,5 +124,12 @@ class LlmService(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
|
||||
# Subscriber is similar to consumer: It provides a service to take stuff
|
||||
# off of a queue and make it available using an internal broker system,
|
||||
# so suitable for when multiple recipients are reading from the same queue
|
||||
|
||||
from pulsar.schema import JsonSchema
|
||||
import asyncio
|
||||
import _pulsar
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Triples store base class
|
|||
|
||||
from .. schema import Triples
|
||||
from .. base import FlowProcessor, ConsumerSpec
|
||||
from .. exceptions import TooManyRequests
|
||||
|
||||
default_ident = "triples-write"
|
||||
|
||||
|
|
|
|||
105
trustgraph-base/trustgraph/messaging/__init__.py
Normal file
105
trustgraph-base/trustgraph/messaging/__init__.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
from .registry import TranslatorRegistry
|
||||
from .translators import *
|
||||
|
||||
# Auto-register all translators
|
||||
from .translators.agent import AgentRequestTranslator, AgentResponseTranslator
|
||||
from .translators.embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator
|
||||
from .translators.text_completion import TextCompletionRequestTranslator, TextCompletionResponseTranslator
|
||||
from .translators.retrieval import (
|
||||
DocumentRagRequestTranslator, DocumentRagResponseTranslator,
|
||||
GraphRagRequestTranslator, GraphRagResponseTranslator
|
||||
)
|
||||
from .translators.triples import TriplesQueryRequestTranslator, TriplesQueryResponseTranslator
|
||||
from .translators.knowledge import KnowledgeRequestTranslator, KnowledgeResponseTranslator
|
||||
from .translators.library import LibraryRequestTranslator, LibraryResponseTranslator
|
||||
from .translators.document_loading import DocumentTranslator, TextDocumentTranslator
|
||||
from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator
|
||||
from .translators.flow import FlowRequestTranslator, FlowResponseTranslator
|
||||
from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||
from .translators.embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
)
|
||||
|
||||
# Register all service translators
|
||||
TranslatorRegistry.register_service(
|
||||
"agent",
|
||||
AgentRequestTranslator(),
|
||||
AgentResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"embeddings",
|
||||
EmbeddingsRequestTranslator(),
|
||||
EmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"text-completion",
|
||||
TextCompletionRequestTranslator(),
|
||||
TextCompletionResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"document-rag",
|
||||
DocumentRagRequestTranslator(),
|
||||
DocumentRagResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"graph-rag",
|
||||
GraphRagRequestTranslator(),
|
||||
GraphRagResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"triples-query",
|
||||
TriplesQueryRequestTranslator(),
|
||||
TriplesQueryResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"knowledge",
|
||||
KnowledgeRequestTranslator(),
|
||||
KnowledgeResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"librarian",
|
||||
LibraryRequestTranslator(),
|
||||
LibraryResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"config",
|
||||
ConfigRequestTranslator(),
|
||||
ConfigResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"flow",
|
||||
FlowRequestTranslator(),
|
||||
FlowResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"prompt",
|
||||
PromptRequestTranslator(),
|
||||
PromptResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"document-embeddings-query",
|
||||
DocumentEmbeddingsRequestTranslator(),
|
||||
DocumentEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
TranslatorRegistry.register_service(
|
||||
"graph-embeddings-query",
|
||||
GraphEmbeddingsRequestTranslator(),
|
||||
GraphEmbeddingsResponseTranslator()
|
||||
)
|
||||
|
||||
# Register single-direction translators for document loading
|
||||
TranslatorRegistry.register_request("document", DocumentTranslator())
|
||||
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())
|
||||
51
trustgraph-base/trustgraph/messaging/registry.py
Normal file
51
trustgraph-base/trustgraph/messaging/registry.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from typing import Dict, List, Union
|
||||
from .translators.base import MessageTranslator
|
||||
|
||||
|
||||
class TranslatorRegistry:
|
||||
"""Registry for service translators"""
|
||||
|
||||
_request_translators: Dict[str, MessageTranslator] = {}
|
||||
_response_translators: Dict[str, MessageTranslator] = {}
|
||||
|
||||
@classmethod
|
||||
def register_request(cls, service_name: str, translator: MessageTranslator):
|
||||
"""Register a request translator for a service"""
|
||||
cls._request_translators[service_name] = translator
|
||||
|
||||
@classmethod
|
||||
def register_response(cls, service_name: str, translator: MessageTranslator):
|
||||
"""Register a response translator for a service"""
|
||||
cls._response_translators[service_name] = translator
|
||||
|
||||
@classmethod
|
||||
def register_service(cls, service_name: str, request_translator: MessageTranslator,
|
||||
response_translator: MessageTranslator):
|
||||
"""Register both request and response translators for a service"""
|
||||
cls.register_request(service_name, request_translator)
|
||||
cls.register_response(service_name, response_translator)
|
||||
|
||||
@classmethod
|
||||
def get_request_translator(cls, service_name: str) -> MessageTranslator:
|
||||
"""Get request translator for a service"""
|
||||
if service_name not in cls._request_translators:
|
||||
raise KeyError(f"No request translator registered for service: {service_name}")
|
||||
return cls._request_translators[service_name]
|
||||
|
||||
@classmethod
|
||||
def get_response_translator(cls, service_name: str) -> MessageTranslator:
|
||||
"""Get response translator for a service"""
|
||||
if service_name not in cls._response_translators:
|
||||
raise KeyError(f"No response translator registered for service: {service_name}")
|
||||
return cls._response_translators[service_name]
|
||||
|
||||
@classmethod
|
||||
def list_services(cls) -> List[str]:
|
||||
"""List all registered services"""
|
||||
return sorted(set(cls._request_translators.keys()) | set(cls._response_translators.keys()))
|
||||
|
||||
@classmethod
|
||||
def has_service(cls, service_name: str) -> bool:
|
||||
"""Check if a service is registered"""
|
||||
return (service_name in cls._request_translators or
|
||||
service_name in cls._response_translators)
|
||||
19
trustgraph-base/trustgraph/messaging/translators/__init__.py
Normal file
19
trustgraph-base/trustgraph/messaging/translators/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from .base import Translator, MessageTranslator
|
||||
from .primitives import ValueTranslator, TripleTranslator, SubgraphTranslator
|
||||
from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator
|
||||
from .agent import AgentRequestTranslator, AgentResponseTranslator
|
||||
from .embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator
|
||||
from .text_completion import TextCompletionRequestTranslator, TextCompletionResponseTranslator
|
||||
from .retrieval import DocumentRagRequestTranslator, DocumentRagResponseTranslator
|
||||
from .retrieval import GraphRagRequestTranslator, GraphRagResponseTranslator
|
||||
from .triples import TriplesQueryRequestTranslator, TriplesQueryResponseTranslator
|
||||
from .knowledge import KnowledgeRequestTranslator, KnowledgeResponseTranslator
|
||||
from .library import LibraryRequestTranslator, LibraryResponseTranslator
|
||||
from .document_loading import DocumentTranslator, TextDocumentTranslator, ChunkTranslator, DocumentEmbeddingsTranslator
|
||||
from .config import ConfigRequestTranslator, ConfigResponseTranslator
|
||||
from .flow import FlowRequestTranslator, FlowResponseTranslator
|
||||
from .prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||
from .embeddings_query import (
|
||||
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||
)
|
||||
44
trustgraph-base/trustgraph/messaging/translators/agent.py
Normal file
44
trustgraph-base/trustgraph/messaging/translators/agent.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import AgentRequest, AgentResponse
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class AgentRequestTranslator(MessageTranslator):
|
||||
"""Translator for AgentRequest schema objects"""
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest:
|
||||
return AgentRequest(
|
||||
question=data["question"],
|
||||
plan=data.get("plan", ""),
|
||||
state=data.get("state", ""),
|
||||
history=data.get("history", [])
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
|
||||
return {
|
||||
"question": obj.question,
|
||||
"plan": obj.plan,
|
||||
"state": obj.state,
|
||||
"history": obj.history
|
||||
}
|
||||
|
||||
|
||||
class AgentResponseTranslator(MessageTranslator):
|
||||
"""Translator for AgentResponse schema objects"""
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> AgentResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
def from_pulsar(self, obj: AgentResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
if obj.answer:
|
||||
result["answer"] = obj.answer
|
||||
if obj.thought:
|
||||
result["thought"] = obj.thought
|
||||
if obj.observation:
|
||||
result["observation"] = obj.observation
|
||||
return result
|
||||
|
||||
def from_response_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), (obj.answer is not None)
|
||||
43
trustgraph-base/trustgraph/messaging/translators/base.py
Normal file
43
trustgraph-base/trustgraph/messaging/translators/base.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Tuple
|
||||
from pulsar.schema import Record
|
||||
|
||||
|
||||
class Translator(ABC):
|
||||
"""Base class for bidirectional Pulsar ↔ dict translation"""
|
||||
|
||||
@abstractmethod
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> Record:
|
||||
"""Convert dict to Pulsar schema object"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def from_pulsar(self, obj: Record) -> Dict[str, Any]:
|
||||
"""Convert Pulsar schema object to dict"""
|
||||
pass
|
||||
|
||||
|
||||
class MessageTranslator(Translator):
|
||||
"""For complete request/response message translation"""
|
||||
|
||||
def from_response_with_completion(self, obj: Record) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final) - for streaming responses"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
||||
|
||||
class SendTranslator(Translator):
|
||||
"""For fire-and-forget send operations (like ServiceSender)"""
|
||||
|
||||
def from_pulsar(self, obj: Record) -> Dict[str, Any]:
|
||||
"""Usually not needed for send-only operations"""
|
||||
raise NotImplementedError("Send translators typically don't need from_pulsar")
|
||||
|
||||
|
||||
def handle_optional_fields(obj: Record, fields: list) -> Dict[str, Any]:
|
||||
"""Helper to extract optional fields from Pulsar object"""
|
||||
result = {}
|
||||
for field in fields:
|
||||
value = getattr(obj, field, None)
|
||||
if value is not None:
|
||||
result[field] = value
|
||||
return result
|
||||
100
trustgraph-base/trustgraph/messaging/translators/config.py
Normal file
100
trustgraph-base/trustgraph/messaging/translators/config.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class ConfigRequestTranslator(MessageTranslator):
|
||||
"""Translator for ConfigRequest schema objects"""
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> ConfigRequest:
|
||||
keys = None
|
||||
if "keys" in data:
|
||||
keys = [
|
||||
ConfigKey(
|
||||
type=k["type"],
|
||||
key=k["key"]
|
||||
)
|
||||
for k in data["keys"]
|
||||
]
|
||||
|
||||
values = None
|
||||
if "values" in data:
|
||||
values = [
|
||||
ConfigValue(
|
||||
type=v["type"],
|
||||
key=v["key"],
|
||||
value=v["value"]
|
||||
)
|
||||
for v in data["values"]
|
||||
]
|
||||
|
||||
return ConfigRequest(
|
||||
operation=data.get("operation"),
|
||||
keys=keys,
|
||||
type=data.get("type"),
|
||||
values=values
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: ConfigRequest) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
if obj.operation:
|
||||
result["operation"] = obj.operation
|
||||
if obj.type:
|
||||
result["type"] = obj.type
|
||||
|
||||
if obj.keys:
|
||||
result["keys"] = [
|
||||
{
|
||||
"type": k.type,
|
||||
"key": k.key
|
||||
}
|
||||
for k in obj.keys
|
||||
]
|
||||
|
||||
if obj.values:
|
||||
result["values"] = [
|
||||
{
|
||||
"type": v.type,
|
||||
"key": v.key,
|
||||
"value": v.value
|
||||
}
|
||||
for v in obj.values
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ConfigResponseTranslator(MessageTranslator):
|
||||
"""Translator for ConfigResponse schema objects"""
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> ConfigResponse:
|
||||
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||
|
||||
def from_pulsar(self, obj: ConfigResponse) -> Dict[str, Any]:
|
||||
result = {}
|
||||
|
||||
if obj.version is not None:
|
||||
result["version"] = obj.version
|
||||
|
||||
if obj.values:
|
||||
result["values"] = [
|
||||
{
|
||||
"type": v.type,
|
||||
"key": v.key,
|
||||
"value": v.value
|
||||
}
|
||||
for v in obj.values
|
||||
]
|
||||
|
||||
if obj.directory:
|
||||
result["directory"] = obj.directory
|
||||
|
||||
if obj.config:
|
||||
result["config"] = obj.config
|
||||
|
||||
return result
|
||||
|
||||
def from_response_with_completion(self, obj: ConfigResponse) -> Tuple[Dict[str, Any], bool]:
|
||||
"""Returns (response_dict, is_final)"""
|
||||
return self.from_pulsar(obj), True
|
||||
|
|
@ -0,0 +1,191 @@
|
|||
import base64
|
||||
from typing import Dict, Any
|
||||
from ...schema import Document, TextDocument, Chunk, DocumentEmbeddings, ChunkEmbeddings
|
||||
from .base import SendTranslator
|
||||
from .metadata import DocumentMetadataTranslator
|
||||
from .primitives import SubgraphTranslator
|
||||
|
||||
|
||||
class DocumentTranslator(SendTranslator):
|
||||
"""Translator for Document schema objects (PDF docs etc.)"""
|
||||
|
||||
def __init__(self):
|
||||
self.subgraph_translator = SubgraphTranslator()
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> Document:
|
||||
metadata = data.get("metadata", [])
|
||||
|
||||
# Handle base64 content validation
|
||||
doc = base64.b64decode(data["data"])
|
||||
|
||||
from ...schema import Metadata
|
||||
return Document(
|
||||
metadata=Metadata(
|
||||
id=data.get("id"),
|
||||
metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [],
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
),
|
||||
data=base64.b64encode(doc).decode("utf-8")
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: Document) -> Dict[str, Any]:
|
||||
result = {
|
||||
"data": obj.data
|
||||
}
|
||||
|
||||
if obj.metadata:
|
||||
metadata_dict = {}
|
||||
if obj.metadata.id:
|
||||
metadata_dict["id"] = obj.metadata.id
|
||||
if obj.metadata.user:
|
||||
metadata_dict["user"] = obj.metadata.user
|
||||
if obj.metadata.collection:
|
||||
metadata_dict["collection"] = obj.metadata.collection
|
||||
if obj.metadata.metadata:
|
||||
metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
|
||||
|
||||
result["metadata"] = metadata_dict
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TextDocumentTranslator(SendTranslator):
|
||||
"""Translator for TextDocument schema objects"""
|
||||
|
||||
def __init__(self):
|
||||
self.subgraph_translator = SubgraphTranslator()
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> TextDocument:
|
||||
metadata = data.get("metadata", [])
|
||||
charset = data.get("charset", "utf-8")
|
||||
|
||||
# Text is base64 encoded in input
|
||||
text = base64.b64decode(data["text"]).decode(charset)
|
||||
|
||||
from ...schema import Metadata
|
||||
return TextDocument(
|
||||
metadata=Metadata(
|
||||
id=data.get("id"),
|
||||
metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [],
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
),
|
||||
text=text.encode("utf-8")
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: TextDocument) -> Dict[str, Any]:
|
||||
result = {
|
||||
"text": obj.text.decode("utf-8") if isinstance(obj.text, bytes) else obj.text
|
||||
}
|
||||
|
||||
if obj.metadata:
|
||||
metadata_dict = {}
|
||||
if obj.metadata.id:
|
||||
metadata_dict["id"] = obj.metadata.id
|
||||
if obj.metadata.user:
|
||||
metadata_dict["user"] = obj.metadata.user
|
||||
if obj.metadata.collection:
|
||||
metadata_dict["collection"] = obj.metadata.collection
|
||||
if obj.metadata.metadata:
|
||||
metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
|
||||
|
||||
result["metadata"] = metadata_dict
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ChunkTranslator(SendTranslator):
|
||||
"""Translator for Chunk schema objects"""
|
||||
|
||||
def __init__(self):
|
||||
self.subgraph_translator = SubgraphTranslator()
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> Chunk:
|
||||
metadata = data.get("metadata", [])
|
||||
|
||||
from ...schema import Metadata
|
||||
return Chunk(
|
||||
metadata=Metadata(
|
||||
id=data.get("id"),
|
||||
metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [],
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
),
|
||||
chunk=data["chunk"].encode("utf-8") if isinstance(data["chunk"], str) else data["chunk"]
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: Chunk) -> Dict[str, Any]:
|
||||
result = {
|
||||
"chunk": obj.chunk.decode("utf-8") if isinstance(obj.chunk, bytes) else obj.chunk
|
||||
}
|
||||
|
||||
if obj.metadata:
|
||||
metadata_dict = {}
|
||||
if obj.metadata.id:
|
||||
metadata_dict["id"] = obj.metadata.id
|
||||
if obj.metadata.user:
|
||||
metadata_dict["user"] = obj.metadata.user
|
||||
if obj.metadata.collection:
|
||||
metadata_dict["collection"] = obj.metadata.collection
|
||||
if obj.metadata.metadata:
|
||||
metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
|
||||
|
||||
result["metadata"] = metadata_dict
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class DocumentEmbeddingsTranslator(SendTranslator):
|
||||
"""Translator for DocumentEmbeddings schema objects"""
|
||||
|
||||
def __init__(self):
|
||||
self.subgraph_translator = SubgraphTranslator()
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddings:
|
||||
metadata = data.get("metadata", {})
|
||||
|
||||
chunks = [
|
||||
ChunkEmbeddings(
|
||||
chunk=chunk["chunk"].encode("utf-8") if isinstance(chunk["chunk"], str) else chunk["chunk"],
|
||||
vectors=chunk["vectors"]
|
||||
)
|
||||
for chunk in data.get("chunks", [])
|
||||
]
|
||||
|
||||
from ...schema import Metadata
|
||||
return DocumentEmbeddings(
|
||||
metadata=Metadata(
|
||||
id=metadata.get("id"),
|
||||
metadata=self.subgraph_translator.to_pulsar(metadata.get("metadata", [])),
|
||||
user=metadata.get("user", "trustgraph"),
|
||||
collection=metadata.get("collection", "default"),
|
||||
),
|
||||
chunks=chunks
|
||||
)
|
||||
|
||||
def from_pulsar(self, obj: DocumentEmbeddings) -> Dict[str, Any]:
|
||||
result = {
|
||||
"chunks": [
|
||||
{
|
||||
"chunk": chunk.chunk.decode("utf-8") if isinstance(chunk.chunk, bytes) else chunk.chunk,
|
||||
"vectors": chunk.vectors
|
||||
}
|
||||
for chunk in obj.chunks
|
||||
]
|
||||
}
|
||||
|
||||
if obj.metadata:
|
||||
metadata_dict = {}
|
||||
if obj.metadata.id:
|
||||
metadata_dict["id"] = obj.metadata.id
|
||||
if obj.metadata.user:
|
||||
metadata_dict["user"] = obj.metadata.user
|
||||
if obj.metadata.collection:
|
||||
metadata_dict["collection"] = obj.metadata.collection
|
||||
if obj.metadata.metadata:
|
||||
metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
|
||||
|
||||
result["metadata"] = metadata_dict
|
||||
|
||||
return result
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class EmbeddingsRequestTranslator(MessageTranslator):
    """Converts embeddings-request dicts to/from EmbeddingsRequest messages."""

    def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsRequest:
        # Only the text to embed is carried on the request.
        return EmbeddingsRequest(text=data["text"])

    def from_pulsar(self, obj: EmbeddingsRequest) -> Dict[str, Any]:
        return {"text": obj.text}
|
||||
|
||||
|
||||
class EmbeddingsResponseTranslator(MessageTranslator):
    """Converts EmbeddingsResponse messages into plain dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsResponse:
        # Responses only flow from the service to the client.
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: EmbeddingsResponse) -> Dict[str, Any]:
        return {"vectors": obj.vectors}

    def from_response_with_completion(self, obj: EmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); embeddings replies are single-shot."""
        return self.from_pulsar(obj), True
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import (
|
||||
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
|
||||
GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
)
|
||||
from .base import MessageTranslator
|
||||
from .primitives import ValueTranslator
|
||||
|
||||
|
||||
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
    """Converts document-embeddings query dicts to/from DocumentEmbeddingsRequest."""

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
        return DocumentEmbeddingsRequest(
            vectors=data["vectors"],
            limit=int(data.get("limit", 10)),
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
        )

    def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
        fields = ("vectors", "limit", "user", "collection")
        return {name: getattr(obj, name) for name in fields}
|
||||
|
||||
|
||||
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
    """Converts DocumentEmbeddingsResponse messages into plain dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
        result: Dict[str, Any] = {}
        if obj.documents:
            # Document chunks may arrive as raw bytes; normalise to text.
            decoded = []
            for doc in obj.documents:
                decoded.append(doc.decode("utf-8") if isinstance(doc, bytes) else doc)
            result["documents"] = decoded
        return result

    def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); always a single final reply."""
        return self.from_pulsar(obj), True
|
||||
|
||||
|
||||
class GraphEmbeddingsRequestTranslator(MessageTranslator):
    """Converts graph-embeddings query dicts to/from GraphEmbeddingsRequest."""

    def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
        return GraphEmbeddingsRequest(
            vectors=data["vectors"],
            limit=int(data.get("limit", 10)),
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
        )

    def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
        fields = ("vectors", "limit", "user", "collection")
        return {name: getattr(obj, name) for name in fields}
|
||||
|
||||
|
||||
class GraphEmbeddingsResponseTranslator(MessageTranslator):
    """Converts GraphEmbeddingsResponse messages into plain dicts."""

    def __init__(self):
        self.value_translator = ValueTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
        if not obj.entities:
            return {}
        translate = self.value_translator.from_pulsar
        return {"entities": [translate(ent) for ent in obj.entities]}

    def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); always a single final reply."""
        return self.from_pulsar(obj), True
|
||||
59
trustgraph-base/trustgraph/messaging/translators/flow.py
Normal file
59
trustgraph-base/trustgraph/messaging/translators/flow.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import FlowRequest, FlowResponse
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class FlowRequestTranslator(MessageTranslator):
    """Converts flow-management request dicts to/from FlowRequest messages."""

    # (wire-dict key, schema attribute) pairs used by from_pulsar.
    _FIELDS = (
        ("operation", "operation"),
        ("class-name", "class_name"),
        ("class-definition", "class_definition"),
        ("description", "description"),
        ("flow-id", "flow_id"),
    )

    def to_pulsar(self, data: Dict[str, Any]) -> FlowRequest:
        return FlowRequest(
            operation=data.get("operation"),
            class_name=data.get("class-name"),
            class_definition=data.get("class-definition"),
            description=data.get("description"),
            flow_id=data.get("flow-id"),
        )

    def from_pulsar(self, obj: FlowRequest) -> Dict[str, Any]:
        # Only include fields that carry a truthy value.
        return {
            key: getattr(obj, attr)
            for key, attr in self._FIELDS
            if getattr(obj, attr)
        }
|
||||
|
||||
|
||||
class FlowResponseTranslator(MessageTranslator):
    """Converts FlowResponse messages into plain dicts."""

    # (wire-dict key, schema attribute) pairs used by from_pulsar.
    _FIELDS = (
        ("class-names", "class_names"),
        ("flow-ids", "flow_ids"),
        ("class-definition", "class_definition"),
        ("flow", "flow"),
        ("description", "description"),
    )

    def to_pulsar(self, data: Dict[str, Any]) -> FlowResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: FlowResponse) -> Dict[str, Any]:
        # Skip unset (falsy) fields so the wire dict stays minimal.
        return {
            key: getattr(obj, attr)
            for key, attr in self._FIELDS
            if getattr(obj, attr)
        }

    def from_response_with_completion(self, obj: FlowResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); flow replies are single-shot."""
        return self.from_pulsar(obj), True
|
||||
183
trustgraph-base/trustgraph/messaging/translators/knowledge.py
Normal file
183
trustgraph-base/trustgraph/messaging/translators/knowledge.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
from typing import Dict, Any, Tuple, Optional
|
||||
from ...schema import (
|
||||
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
|
||||
Metadata, EntityEmbeddings
|
||||
)
|
||||
from .base import MessageTranslator
|
||||
from .primitives import ValueTranslator, SubgraphTranslator
|
||||
from .metadata import DocumentMetadataTranslator
|
||||
|
||||
|
||||
class KnowledgeRequestTranslator(MessageTranslator):
    """Translator for KnowledgeRequest schema objects.

    Converts between the wire dict form (hyphenated keys) and the Pulsar
    KnowledgeRequest message, including the optional embedded Triples
    and GraphEmbeddings payloads carried on put-style operations.
    """

    def __init__(self):
        # Helpers for nested Value objects and triple lists.
        self.value_translator = ValueTranslator()
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeRequest:
        """Build a KnowledgeRequest from a wire dict.

        The "triples" and "graph-embeddings" sections are optional; when
        absent the corresponding message fields are left as None.
        """
        triples = None
        if "triples" in data:
            # Required keys inside "triples": metadata{id, metadata, user,
            # collection} and the triple list itself.
            triples = Triples(
                metadata=Metadata(
                    id=data["triples"]["metadata"]["id"],
                    metadata=self.subgraph_translator.to_pulsar(
                        data["triples"]["metadata"]["metadata"]
                    ),
                    user=data["triples"]["metadata"]["user"],
                    collection=data["triples"]["metadata"]["collection"]
                ),
                triples=self.subgraph_translator.to_pulsar(data["triples"]["triples"]),
            )

        graph_embeddings = None
        if "graph-embeddings" in data:
            graph_embeddings = GraphEmbeddings(
                metadata=Metadata(
                    id=data["graph-embeddings"]["metadata"]["id"],
                    metadata=self.subgraph_translator.to_pulsar(
                        data["graph-embeddings"]["metadata"]["metadata"]
                    ),
                    user=data["graph-embeddings"]["metadata"]["user"],
                    collection=data["graph-embeddings"]["metadata"]["collection"]
                ),
                entities=[
                    EntityEmbeddings(
                        entity=self.value_translator.to_pulsar(ent["entity"]),
                        vectors=ent["vectors"],
                    )
                    for ent in data["graph-embeddings"]["entities"]
                ]
            )

        # Scalar fields are optional and default to None via .get().
        return KnowledgeRequest(
            operation=data.get("operation"),
            user=data.get("user"),
            id=data.get("id"),
            flow=data.get("flow"),
            collection=data.get("collection"),
            triples=triples,
            graph_embeddings=graph_embeddings,
        )

    def from_pulsar(self, obj: KnowledgeRequest) -> Dict[str, Any]:
        """Convert a KnowledgeRequest back to its wire dict form.

        Falsy scalar fields are omitted; nested payloads are emitted under
        the same hyphenated keys that to_pulsar consumes.
        """
        result = {}

        if obj.operation:
            result["operation"] = obj.operation
        if obj.user:
            result["user"] = obj.user
        if obj.id:
            result["id"] = obj.id
        if obj.flow:
            result["flow"] = obj.flow
        if obj.collection:
            result["collection"] = obj.collection

        if obj.triples:
            result["triples"] = {
                "metadata": {
                    "id": obj.triples.metadata.id,
                    "metadata": self.subgraph_translator.from_pulsar(
                        obj.triples.metadata.metadata
                    ),
                    "user": obj.triples.metadata.user,
                    "collection": obj.triples.metadata.collection,
                },
                "triples": self.subgraph_translator.from_pulsar(obj.triples.triples),
            }

        if obj.graph_embeddings:
            result["graph-embeddings"] = {
                "metadata": {
                    "id": obj.graph_embeddings.metadata.id,
                    "metadata": self.subgraph_translator.from_pulsar(
                        obj.graph_embeddings.metadata.metadata
                    ),
                    "user": obj.graph_embeddings.metadata.user,
                    "collection": obj.graph_embeddings.metadata.collection,
                },
                "entities": [
                    {
                        "vectors": entity.vectors,
                        "entity": self.value_translator.from_pulsar(entity.entity),
                    }
                    for entity in obj.graph_embeddings.entities
                ],
            }

        return result
|
||||
|
||||
|
||||
class KnowledgeResponseTranslator(MessageTranslator):
    """Translator for KnowledgeResponse schema objects.

    Knowledge responses are polymorphic: an id listing, a streamed
    Triples batch, a streamed GraphEmbeddings batch, an end-of-stream
    marker, or an empty acknowledgement. The branch order below decides
    which form wins when multiple fields are set, so it must not change.
    """

    def __init__(self):
        self.value_translator = ValueTranslator()
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeResponse:
        # Responses only flow from the service to the client.
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: KnowledgeResponse) -> Dict[str, Any]:
        # Response to a list operation ("is not None" preserves empty lists)
        if obj.ids is not None:
            return {"ids": obj.ids}

        # Streaming triples response
        if obj.triples:
            return {
                "triples": {
                    "metadata": {
                        "id": obj.triples.metadata.id,
                        "metadata": self.subgraph_translator.from_pulsar(
                            obj.triples.metadata.metadata
                        ),
                        "user": obj.triples.metadata.user,
                        "collection": obj.triples.metadata.collection,
                    },
                    "triples": self.subgraph_translator.from_pulsar(obj.triples.triples),
                }
            }

        # Streaming graph embeddings response
        if obj.graph_embeddings:
            return {
                "graph-embeddings": {
                    "metadata": {
                        "id": obj.graph_embeddings.metadata.id,
                        "metadata": self.subgraph_translator.from_pulsar(
                            obj.graph_embeddings.metadata.metadata
                        ),
                        "user": obj.graph_embeddings.metadata.user,
                        "collection": obj.graph_embeddings.metadata.collection,
                    },
                    "entities": [
                        {
                            "vectors": entity.vectors,
                            "entity": self.value_translator.from_pulsar(entity.entity),
                        }
                        for entity in obj.graph_embeddings.entities
                    ],
                }
            }

        # End of stream marker
        if obj.eos is True:
            return {"eos": True}

        # Empty response (e.g. successful delete)
        return {}

    def from_response_with_completion(self, obj: KnowledgeResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final)."""
        response = self.from_pulsar(obj)

        # A response is final when it is a list result, an end-of-stream
        # marker, or carries no streamed payload at all.
        is_final = (
            obj.ids is not None or  # List response
            obj.eos is True or  # End of stream
            (not obj.triples and not obj.graph_embeddings)  # Empty response
        )

        return response, is_final
|
||||
124
trustgraph-base/trustgraph/messaging/translators/library.py
Normal file
124
trustgraph-base/trustgraph/messaging/translators/library.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
from typing import Dict, Any, Tuple, Optional
|
||||
from ...schema import LibrarianRequest, LibrarianResponse, DocumentMetadata, ProcessingMetadata, Criteria
|
||||
from .base import MessageTranslator
|
||||
from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator
|
||||
|
||||
|
||||
class LibraryRequestTranslator(MessageTranslator):
    """Translator for LibrarianRequest schema objects.

    Handles the librarian's document-management requests: optional nested
    document/processing metadata blocks, search criteria, and raw
    document content (carried as bytes on the wire).
    """

    def __init__(self):
        # Delegate nested metadata blocks to their own translators.
        self.doc_metadata_translator = DocumentMetadataTranslator()
        self.proc_metadata_translator = ProcessingMetadataTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> LibrarianRequest:
        """Build a LibrarianRequest from a wire dict; all sections optional."""
        # Document metadata (optional nested block)
        doc_metadata = None
        if "document-metadata" in data:
            doc_metadata = self.doc_metadata_translator.to_pulsar(data["document-metadata"])

        # Processing metadata (optional nested block)
        proc_metadata = None
        if "processing-metadata" in data:
            proc_metadata = self.proc_metadata_translator.to_pulsar(data["processing-metadata"])

        # Search criteria; each entry requires key/value/operator.
        criteria = []
        if "criteria" in data:
            criteria = [
                Criteria(
                    key=c["key"],
                    value=c["value"],
                    operator=c["operator"]
                )
                for c in data["criteria"]
            ]

        # Content as bytes; string input is UTF-8 encoded, bytes pass through.
        content = None
        if "content" in data:
            if isinstance(data["content"], str):
                content = data["content"].encode("utf-8")
            else:
                content = data["content"]

        return LibrarianRequest(
            operation=data.get("operation"),
            document_id=data.get("document-id"),
            processing_id=data.get("processing-id"),
            document_metadata=doc_metadata,
            processing_metadata=proc_metadata,
            content=content,
            user=data.get("user"),
            collection=data.get("collection"),
            criteria=criteria
        )

    def from_pulsar(self, obj: LibrarianRequest) -> Dict[str, Any]:
        """Convert a LibrarianRequest back to its wire dict, omitting unset fields."""
        result = {}

        if obj.operation:
            result["operation"] = obj.operation
        if obj.document_id:
            result["document-id"] = obj.document_id
        if obj.processing_id:
            result["processing-id"] = obj.processing_id
        if obj.document_metadata:
            result["document-metadata"] = self.doc_metadata_translator.from_pulsar(obj.document_metadata)
        if obj.processing_metadata:
            result["processing-metadata"] = self.proc_metadata_translator.from_pulsar(obj.processing_metadata)
        if obj.content:
            # Bytes content is decoded back to text for the wire dict.
            result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content
        if obj.user:
            result["user"] = obj.user
        if obj.collection:
            result["collection"] = obj.collection
        # "is not None" keeps an explicitly-empty criteria list on the wire.
        if obj.criteria is not None:
            result["criteria"] = [
                {
                    "key": c.key,
                    "value": c.value,
                    "operator": c.operator
                }
                for c in obj.criteria
            ]

        return result
|
||||
|
||||
|
||||
class LibraryResponseTranslator(MessageTranslator):
    """Converts LibrarianResponse messages into plain dicts."""

    def __init__(self):
        self.doc_metadata_translator = DocumentMetadataTranslator()
        self.proc_metadata_translator = ProcessingMetadataTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> LibrarianResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: LibrarianResponse) -> Dict[str, Any]:
        out: Dict[str, Any] = {}

        if obj.document_metadata:
            out["document-metadata"] = self.doc_metadata_translator.from_pulsar(obj.document_metadata)

        if obj.content:
            content = obj.content
            # Content is stored as bytes; present it as text on the wire.
            out["content"] = content.decode("utf-8") if isinstance(content, bytes) else content

        # List results distinguish "absent" (None) from "empty" ([]).
        if obj.document_metadatas is not None:
            out["document-metadatas"] = [
                self.doc_metadata_translator.from_pulsar(item)
                for item in obj.document_metadatas
            ]

        if obj.processing_metadatas is not None:
            out["processing-metadatas"] = [
                self.proc_metadata_translator.from_pulsar(item)
                for item in obj.processing_metadatas
            ]

        return out

    def from_response_with_completion(self, obj: LibrarianResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); librarian replies are single-shot."""
        return self.from_pulsar(obj), True
|
||||
81
trustgraph-base/trustgraph/messaging/translators/metadata.py
Normal file
81
trustgraph-base/trustgraph/messaging/translators/metadata.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
from typing import Dict, Any, Optional
|
||||
from ...schema import DocumentMetadata, ProcessingMetadata
|
||||
from .base import Translator
|
||||
from .primitives import SubgraphTranslator
|
||||
|
||||
|
||||
class DocumentMetadataTranslator(Translator):
    """Converts document-metadata dicts to/from DocumentMetadata messages."""

    def __init__(self):
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentMetadata:
        metadata = data.get("metadata", [])
        return DocumentMetadata(
            id=data.get("id"),
            time=data.get("time"),
            kind=data.get("kind"),
            title=data.get("title"),
            comments=data.get("comments"),
            # An explicit null in the input still yields an empty triple list.
            metadata=self.subgraph_translator.to_pulsar(metadata) if metadata is not None else [],
            user=data.get("user"),
            tags=data.get("tags"),
        )

    def from_pulsar(self, obj: DocumentMetadata) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        # Truthy scalar fields only.
        for field in ("id", "time", "kind", "title", "comments"):
            value = getattr(obj, field)
            if value:
                out[field] = value
        # Metadata and tags keep empty-collection semantics: only None is "absent".
        if obj.metadata is not None:
            out["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata)
        if obj.user:
            out["user"] = obj.user
        if obj.tags is not None:
            out["tags"] = obj.tags
        return out
|
||||
|
||||
|
||||
class ProcessingMetadataTranslator(Translator):
    """Converts processing-metadata dicts to/from ProcessingMetadata messages."""

    def to_pulsar(self, data: Dict[str, Any]) -> ProcessingMetadata:
        return ProcessingMetadata(
            id=data.get("id"),
            document_id=data.get("document-id"),
            time=data.get("time"),
            flow=data.get("flow"),
            user=data.get("user"),
            collection=data.get("collection"),
            tags=data.get("tags"),
        )

    def from_pulsar(self, obj: ProcessingMetadata) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        pairs = (
            ("id", obj.id),
            ("document-id", obj.document_id),
            ("time", obj.time),
            ("flow", obj.flow),
            ("user", obj.user),
            ("collection", obj.collection),
        )
        for key, value in pairs:
            if value:
                out[key] = value
        # Tags keep empty-list semantics: only None means "absent".
        if obj.tags is not None:
            out["tags"] = obj.tags
        return out
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
from typing import Dict, Any, List
|
||||
from ...schema import Value, Triple
|
||||
from .base import Translator
|
||||
|
||||
|
||||
class ValueTranslator(Translator):
    """Maps the compact {"v": ..., "e": ...} wire form to Value objects."""

    def to_pulsar(self, data: Dict[str, Any]) -> Value:
        # "v" carries the value, "e" flags whether it is a URI (entity).
        return Value(value=data["v"], is_uri=data["e"])

    def from_pulsar(self, obj: Value) -> Dict[str, Any]:
        return {"v": obj.value, "e": obj.is_uri}
|
||||
|
||||
|
||||
class TripleTranslator(Translator):
    """Converts subject/predicate/object dicts to/from Triple objects."""

    def __init__(self):
        self.value_translator = ValueTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> Triple:
        convert = self.value_translator.to_pulsar
        return Triple(s=convert(data["s"]), p=convert(data["p"]), o=convert(data["o"]))

    def from_pulsar(self, obj: Triple) -> Dict[str, Any]:
        convert = self.value_translator.from_pulsar
        return {"s": convert(obj.s), "p": convert(obj.p), "o": convert(obj.o)}
|
||||
|
||||
|
||||
class SubgraphTranslator(Translator):
    """Converts lists of triple dicts to/from lists of Triple objects."""

    def __init__(self):
        self.triple_translator = TripleTranslator()

    def to_pulsar(self, data: List[Dict[str, Any]]) -> List[Triple]:
        convert = self.triple_translator.to_pulsar
        return [convert(item) for item in data]

    def from_pulsar(self, obj: List[Triple]) -> List[Dict[str, Any]]:
        convert = self.triple_translator.from_pulsar
        return [convert(item) for item in obj]
|
||||
54
trustgraph-base/trustgraph/messaging/translators/prompt.py
Normal file
54
trustgraph-base/trustgraph/messaging/translators/prompt.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
import json
|
||||
from typing import Dict, Any, Tuple
|
||||
from ...schema import PromptRequest, PromptResponse
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class PromptRequestTranslator(MessageTranslator):
    """Converts prompt-request dicts to/from PromptRequest messages."""

    def to_pulsar(self, data: Dict[str, Any]) -> PromptRequest:
        # Accept either pre-encoded "terms" or raw "variables"; the latter
        # take precedence and are JSON-encoded per value, as the prompt
        # service expects.
        if "variables" in data:
            terms = {key: json.dumps(value) for key, value in data["variables"].items()}
        else:
            terms = data.get("terms", {})
        return PromptRequest(id=data.get("id"), terms=terms)

    def from_pulsar(self, obj: PromptRequest) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if obj.id:
            out["id"] = obj.id
        if obj.terms:
            out["terms"] = obj.terms
        return out
|
||||
|
||||
|
||||
class PromptResponseTranslator(MessageTranslator):
    """Converts PromptResponse messages into plain dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> PromptResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        # A prompt reply carries free text and/or a structured object.
        if obj.text:
            out["text"] = obj.text
        if obj.object:
            out["object"] = obj.object
        return out

    def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); prompt replies are single-shot."""
        return self.from_pulsar(obj), True
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import DocumentRagQuery, DocumentRagResponse, GraphRagQuery, GraphRagResponse
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class DocumentRagRequestTranslator(MessageTranslator):
    """Converts document-RAG query dicts to/from DocumentRagQuery messages."""

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagQuery:
        return DocumentRagQuery(
            query=data["query"],
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
            # Coerce to int: the wire value may arrive as a string.
            doc_limit=int(data.get("doc-limit", 20)),
        )

    def from_pulsar(self, obj: DocumentRagQuery) -> Dict[str, Any]:
        pairs = (
            ("query", obj.query),
            ("user", obj.user),
            ("collection", obj.collection),
            ("doc-limit", obj.doc_limit),
        )
        return dict(pairs)
|
||||
|
||||
|
||||
class DocumentRagResponseTranslator(MessageTranslator):
    """Converts DocumentRagResponse messages into plain dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
        return {"response": obj.response}

    def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); RAG replies are single-shot."""
        return self.from_pulsar(obj), True
|
||||
|
||||
|
||||
class GraphRagRequestTranslator(MessageTranslator):
    """Converts graph-RAG query dicts to/from GraphRagQuery messages."""

    def to_pulsar(self, data: Dict[str, Any]) -> GraphRagQuery:
        return GraphRagQuery(
            query=data["query"],
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
            # Tuning knobs may arrive as strings; coerce to int.
            entity_limit=int(data.get("entity-limit", 50)),
            triple_limit=int(data.get("triple-limit", 30)),
            max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
            max_path_length=int(data.get("max-path-length", 2)),
        )

    def from_pulsar(self, obj: GraphRagQuery) -> Dict[str, Any]:
        pairs = (
            ("query", obj.query),
            ("user", obj.user),
            ("collection", obj.collection),
            ("entity-limit", obj.entity_limit),
            ("triple-limit", obj.triple_limit),
            ("max-subgraph-size", obj.max_subgraph_size),
            ("max-path-length", obj.max_path_length),
        )
        return dict(pairs)
|
||||
|
||||
|
||||
class GraphRagResponseTranslator(MessageTranslator):
    """Converts GraphRagResponse messages into plain dicts."""

    def to_pulsar(self, data: Dict[str, Any]) -> GraphRagResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
        return {"response": obj.response}

    def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); RAG replies are single-shot."""
        return self.from_pulsar(obj), True
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
from typing import Dict, Any, Tuple
|
||||
from ...schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .base import MessageTranslator
|
||||
|
||||
|
||||
class TextCompletionRequestTranslator(MessageTranslator):
    """Converts text-completion request dicts to/from TextCompletionRequest."""

    def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionRequest:
        # Both the system prompt and the user prompt are required keys.
        return TextCompletionRequest(system=data["system"], prompt=data["prompt"])

    def from_pulsar(self, obj: TextCompletionRequest) -> Dict[str, Any]:
        return {"system": obj.system, "prompt": obj.prompt}
|
||||
|
||||
|
||||
class TextCompletionResponseTranslator(MessageTranslator):
    """Converts TextCompletionResponse messages into plain dicts.

    The generated text is always present; token counts and the model
    identifier are added when the service supplied them.
    """

    def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: TextCompletionResponse) -> Dict[str, Any]:
        result = {"response": obj.response}

        # Fix: use explicit None checks for the token counters. A count of
        # 0 is a legitimate value and the previous truthiness test silently
        # dropped it from the wire dict.
        if obj.in_token is not None:
            result["in_token"] = obj.in_token
        if obj.out_token is not None:
            result["out_token"] = obj.out_token
        if obj.model:
            result["model"] = obj.model

        return result

    def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); completion replies are single-shot."""
        return self.from_pulsar(obj), True
|
||||
60
trustgraph-base/trustgraph/messaging/translators/triples.py
Normal file
60
trustgraph-base/trustgraph/messaging/translators/triples.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
from typing import Dict, Any, Tuple, Optional
|
||||
from ...schema import TriplesQueryRequest, TriplesQueryResponse
|
||||
from .base import MessageTranslator
|
||||
from .primitives import ValueTranslator, SubgraphTranslator
|
||||
|
||||
|
||||
class TriplesQueryRequestTranslator(MessageTranslator):
    """Converts triples-query dicts to/from TriplesQueryRequest messages."""

    def __init__(self):
        self.value_translator = ValueTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> TriplesQueryRequest:
        def term(key):
            # Absent s/p/o terms act as wildcards and stay None.
            return self.value_translator.to_pulsar(data[key]) if key in data else None

        return TriplesQueryRequest(
            s=term("s"),
            p=term("p"),
            o=term("o"),
            limit=int(data.get("limit", 10000)),
            user=data.get("user", "trustgraph"),
            collection=data.get("collection", "default"),
        )

    def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]:
        out: Dict[str, Any] = {
            "limit": obj.limit,
            "user": obj.user,
            "collection": obj.collection,
        }
        # Only include query terms that were actually set.
        for key, value in (("s", obj.s), ("p", obj.p), ("o", obj.o)):
            if value:
                out[key] = self.value_translator.from_pulsar(value)
        return out
|
||||
|
||||
|
||||
class TriplesQueryResponseTranslator(MessageTranslator):
    """Converts TriplesQueryResponse messages into plain dicts."""

    def __init__(self):
        self.subgraph_translator = SubgraphTranslator()

    def to_pulsar(self, data: Dict[str, Any]) -> TriplesQueryResponse:
        raise NotImplementedError("Response translation to Pulsar not typically needed")

    def from_pulsar(self, obj: TriplesQueryResponse) -> Dict[str, Any]:
        # Matching triples are returned under the "response" key.
        return {"response": self.subgraph_translator.from_pulsar(obj.triples)}

    def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]:
        """Returns (response_dict, is_final); query replies are single-shot."""
        return self.from_pulsar(obj), True
|
||||
|
|
@ -1,80 +1,68 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Loads Graph embeddings into TrustGraph processing.
|
||||
|
||||
FIXME: This hasn't been updated following API gateway change.
|
||||
Loads triples into the knowledge graph.
|
||||
"""
|
||||
|
||||
import pulsar
|
||||
from pulsar.schema import JsonSchema
|
||||
from trustgraph.schema import Triples, Triple, Value, Metadata
|
||||
import asyncio
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import pyarrow as pa
|
||||
import rdflib
|
||||
import json
|
||||
from websockets.asyncio.client import connect
|
||||
|
||||
from trustgraph.log_level import LogLevel
|
||||
|
||||
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
|
||||
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
|
||||
|
||||
default_output_queue = triples_store_queue
|
||||
|
||||
class Loader:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host,
|
||||
output_queue,
|
||||
log_level,
|
||||
files,
|
||||
flow,
|
||||
user,
|
||||
collection,
|
||||
pulsar_api_key=None,
|
||||
document_id,
|
||||
url = default_url,
|
||||
):
|
||||
|
||||
if pulsar_api_key:
|
||||
auth = pulsar.AuthenticationToken(pulsar_api_key)
|
||||
self.client = pulsar.Client(
|
||||
pulsar_host,
|
||||
authentication=auth,
|
||||
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
|
||||
)
|
||||
else:
|
||||
self.client = pulsar.Client(
|
||||
pulsar_host,
|
||||
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
|
||||
)
|
||||
if not url.endswith("/"):
|
||||
url += "/"
|
||||
|
||||
self.producer = self.client.create_producer(
|
||||
topic=output_queue,
|
||||
schema=JsonSchema(Triples),
|
||||
chunking_enabled=True,
|
||||
)
|
||||
url = url + f"api/v1/flow/{flow}/import/triples"
|
||||
|
||||
self.url = url
|
||||
|
||||
self.files = files
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
self.document_id = document_id
|
||||
|
||||
def run(self):
|
||||
async def run(self):
|
||||
|
||||
try:
|
||||
|
||||
for file in self.files:
|
||||
self.load_file(file)
|
||||
async with connect(self.url) as ws:
|
||||
for file in self.files:
|
||||
await self.load_file(file, ws)
|
||||
|
||||
except Exception as e:
|
||||
print(e, flush=True)
|
||||
|
||||
def load_file(self, file):
|
||||
async def load_file(self, file, ws):
|
||||
|
||||
g = rdflib.Graph()
|
||||
g.parse(file, format="turtle")
|
||||
|
||||
def Value(value, is_uri):
|
||||
return { "v": value, "e": is_uri }
|
||||
|
||||
triples = []
|
||||
|
||||
for e in g:
|
||||
s = Value(value=str(e[0]), is_uri=True)
|
||||
p = Value(value=str(e[1]), is_uri=True)
|
||||
|
|
@ -83,20 +71,23 @@ class Loader:
|
|||
else:
|
||||
o = Value(value=str(e[2]), is_uri=False)
|
||||
|
||||
r = Triples(
|
||||
metadata=Metadata(
|
||||
id=None,
|
||||
metadata=[],
|
||||
user=self.user,
|
||||
collection=self.collection,
|
||||
),
|
||||
triples=[ Triple(s=s, p=p, o=o) ]
|
||||
)
|
||||
req = {
|
||||
"metadata": {
|
||||
"id": self.document_id,
|
||||
"metadata": [],
|
||||
"user": self.user,
|
||||
"collection": self.collection
|
||||
},
|
||||
"triples": [
|
||||
{
|
||||
"s": s,
|
||||
"p": p,
|
||||
"o": o,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
self.producer.send(r)
|
||||
|
||||
def __del__(self):
|
||||
self.client.close()
|
||||
await ws.send(json.dumps(req))
|
||||
|
||||
def main():
|
||||
|
||||
|
|
@ -106,9 +97,15 @@ def main():
|
|||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-p', '--pulsar-host',
|
||||
default=default_pulsar_host,
|
||||
help=f'Pulsar host (default: {default_pulsar_host})',
|
||||
'-u', '--api-url',
|
||||
default=default_url,
|
||||
help=f'API URL (default: {default_url})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-i', '--document-id',
|
||||
required=True,
|
||||
help=f'Document ID)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -116,39 +113,19 @@ def main():
|
|||
default="default",
|
||||
help=f'Flow ID (default: default)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--pulsar-api-key',
|
||||
default=default_pulsar_api_key,
|
||||
help=f'Pulsar API key',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-o', '--output-queue',
|
||||
default=default_output_queue,
|
||||
help=f'Output queue (default: {default_output_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-u', '--user',
|
||||
'-U', '--user',
|
||||
default=default_user,
|
||||
help=f'User ID (default: {default_user})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--collection',
|
||||
'-C', '--collection',
|
||||
default=default_collection,
|
||||
help=f'Collection ID (default: {default_collection})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-l', '--log-level',
|
||||
type=LogLevel,
|
||||
default=LogLevel.ERROR,
|
||||
choices=list(LogLevel),
|
||||
help=f'Output queue (default: info)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'files', nargs='+',
|
||||
help=f'Turtle files to load'
|
||||
|
|
@ -160,16 +137,15 @@ def main():
|
|||
|
||||
try:
|
||||
p = Loader(
|
||||
pulsar_host=args.pulsar_host,
|
||||
pulsar_api_key=args.pulsar_api_key,
|
||||
output_queue=args.output_queue,
|
||||
log_level=args.log_level,
|
||||
files=args.files,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
document_id = args.document_id,
|
||||
url = args.api_url,
|
||||
flow = args.flow_id,
|
||||
files = args.files,
|
||||
user = args.user,
|
||||
collection = args.collection,
|
||||
)
|
||||
|
||||
p.run()
|
||||
asyncio.run(p.run())
|
||||
|
||||
print("File loaded.")
|
||||
break
|
||||
|
|
@ -181,6 +157,5 @@ def main():
|
|||
|
||||
time.sleep(10)
|
||||
|
||||
print("Not implemented.")
|
||||
#main()
|
||||
main()
|
||||
|
||||
|
|
|
|||
109
trustgraph-cli/scripts/tg-show-token-rate
Executable file
109
trustgraph-cli/scripts/tg-show-token-rate
Executable file
|
|
@ -0,0 +1,109 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Dump out a stream of token rates, input, output and total. This is averaged
|
||||
across the time since tg-show-token-rate is started.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
|
||||
default_metrics_url = "http://localhost:8088/api/metrics"
|
||||
|
||||
class Collate:
|
||||
|
||||
def look(self, data):
|
||||
return sum(
|
||||
[
|
||||
float(x["value"][1])
|
||||
for x in data["data"]["result"]
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(self, data):
|
||||
self.last = self.look(data)
|
||||
self.total = 0
|
||||
self.time = 0
|
||||
|
||||
def record(self, data, time):
|
||||
|
||||
value = self.look(data)
|
||||
delta = value - self.last
|
||||
self.last = value
|
||||
|
||||
self.total += delta
|
||||
self.time += time
|
||||
|
||||
return delta/time, self.total/self.time
|
||||
|
||||
def dump_status(metrics_url, number_samples, period):
|
||||
|
||||
input_url = f"{metrics_url}/query?query=input_tokens_total"
|
||||
output_url = f"{metrics_url}/query?query=output_tokens_total"
|
||||
|
||||
resp = requests.get(input_url)
|
||||
obj = resp.json()
|
||||
input = Collate(obj)
|
||||
|
||||
resp = requests.get(output_url)
|
||||
obj = resp.json()
|
||||
output = Collate(obj)
|
||||
|
||||
print(f"{'Input':>10s} {'Output':>10s} {'Total':>10s}")
|
||||
print(f"{'-----':>10s} {'------':>10s} {'-----':>10s}")
|
||||
|
||||
for i in range(number_samples):
|
||||
|
||||
time.sleep(period)
|
||||
|
||||
resp = requests.get(input_url)
|
||||
obj = resp.json()
|
||||
inr, inl = input.record(obj, period)
|
||||
|
||||
resp = requests.get(output_url)
|
||||
obj = resp.json()
|
||||
outr, outl = output.record(obj, period)
|
||||
|
||||
print(f"{inl:10.1f} {outl:10.1f} {inl+outl:10.1f}")
|
||||
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='tg-show-processor-state',
|
||||
description=__doc__,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--metrics-url',
|
||||
default=default_metrics_url,
|
||||
help=f'Metrics URL (default: {default_metrics_url})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-p', '--period',
|
||||
type=int,
|
||||
default=1,
|
||||
help=f'Metrics period (default: 1)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-n', '--number-samples',
|
||||
type=int,
|
||||
default=100,
|
||||
help=f'Metrics period (default: 100)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
dump_status(**vars(args))
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
main()
|
||||
|
||||
|
|
@ -80,6 +80,7 @@ setuptools.setup(
|
|||
"scripts/tg-show-processor-state",
|
||||
"scripts/tg-show-prompts",
|
||||
"scripts/tg-show-token-costs",
|
||||
"scripts/tg-show-token-rate",
|
||||
"scripts/tg-show-tools",
|
||||
"scripts/tg-start-flow",
|
||||
"scripts/tg-unload-kg-core",
|
||||
|
|
|
|||
6
trustgraph-flow/scripts/rev-gateway
Executable file
6
trustgraph-flow/scripts/rev-gateway
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.rev_gateway import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/text-completion-vllm
Executable file
6
trustgraph-flow/scripts/text-completion-vllm
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.text_completion.vllm import run
|
||||
|
||||
run()
|
||||
|
||||
|
|
@ -71,6 +71,7 @@ setuptools.setup(
|
|||
scripts=[
|
||||
"scripts/agent-manager-react",
|
||||
"scripts/api-gateway",
|
||||
"scripts/rev-gateway",
|
||||
"scripts/chunker-recursive",
|
||||
"scripts/chunker-token",
|
||||
"scripts/config-svc",
|
||||
|
|
@ -118,6 +119,7 @@ setuptools.setup(
|
|||
"scripts/text-completion-ollama",
|
||||
"scripts/text-completion-openai",
|
||||
"scripts/text-completion-tgi",
|
||||
"scripts/text-completion-vllm",
|
||||
"scripts/triples-query-cassandra",
|
||||
"scripts/triples-query-falkordb",
|
||||
"scripts/triples-query-memgraph",
|
||||
|
|
|
|||
|
|
@ -21,16 +21,19 @@ RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
|||
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
||||
|
||||
default_ident = "kg-extract-definitions"
|
||||
default_concurrency = 1
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"concurrency": concurrency,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -38,7 +41,8 @@ class Processor(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = Chunk,
|
||||
handler = self.on_message
|
||||
handler = self.on_message,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -190,6 +194,13 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
def run():
|
||||
|
|
|
|||
|
|
@ -20,16 +20,19 @@ RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
|||
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
||||
|
||||
default_ident = "kg-extract-relationships"
|
||||
default_concurrency = 1
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"concurrency": concurrency,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -37,7 +40,8 @@ class Processor(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name = "input",
|
||||
schema = Chunk,
|
||||
handler = self.on_message
|
||||
handler = self.on_message,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -192,6 +196,13 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
def run():
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
from ... schema import AgentRequest, AgentResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
|
|
@ -20,24 +21,12 @@ class AgentRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("agent")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("agent")
|
||||
|
||||
def to_request(self, body):
|
||||
return AgentRequest(
|
||||
question=body["question"]
|
||||
)
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
resp = {
|
||||
}
|
||||
|
||||
if message.answer:
|
||||
resp["answer"] = message.answer
|
||||
|
||||
if message.thought:
|
||||
resp["thought"] = message.thought
|
||||
|
||||
if message.observation:
|
||||
resp["observation"] = message.observation
|
||||
|
||||
# The 2nd boolean expression indicates whether we're done responding
|
||||
return resp, (message.answer is not None)
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
from ... schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
|
||||
from ... schema import config_request_queue
|
||||
from ... schema import config_response_queue
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
|
|
@ -19,60 +20,12 @@ class ConfigRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("config")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("config")
|
||||
|
||||
def to_request(self, body):
|
||||
|
||||
if "keys" in body:
|
||||
keys = [
|
||||
ConfigKey(
|
||||
type = k["type"],
|
||||
key = k["key"],
|
||||
)
|
||||
for k in body["keys"]
|
||||
]
|
||||
else:
|
||||
keys = None
|
||||
|
||||
if "values" in body:
|
||||
values = [
|
||||
ConfigValue(
|
||||
type = v["type"],
|
||||
key = v["key"],
|
||||
value = v["value"],
|
||||
)
|
||||
for v in body["values"]
|
||||
]
|
||||
else:
|
||||
values = None
|
||||
|
||||
return ConfigRequest(
|
||||
operation = body.get("operation", None),
|
||||
keys = keys,
|
||||
type = body.get("type", None),
|
||||
values = values
|
||||
)
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
|
||||
response = { }
|
||||
|
||||
if message.version is not None:
|
||||
response["version"] = message.version
|
||||
|
||||
if message.values is not None:
|
||||
response["values"] = [
|
||||
{
|
||||
"type": v.type,
|
||||
"key": v.key,
|
||||
"value": v.value,
|
||||
}
|
||||
for v in message.values
|
||||
]
|
||||
|
||||
if message.directory is not None:
|
||||
response["directory"] = message.directory
|
||||
|
||||
if message.config is not None:
|
||||
response["config"] = message.config
|
||||
|
||||
return response, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
96
trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
Normal file
96
trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
import msgpack
|
||||
from . knowledge import KnowledgeRequestor
|
||||
|
||||
class CoreExport:
|
||||
|
||||
def __init__(self, pulsar_client):
|
||||
self.pulsar_client = pulsar_client
|
||||
|
||||
async def process(self, data, error, ok, request):
|
||||
|
||||
id = request.query["id"]
|
||||
user = request.query["user"]
|
||||
|
||||
response = await ok()
|
||||
|
||||
kr = KnowledgeRequestor(
|
||||
pulsar_client = self.pulsar_client,
|
||||
consumer = "api-gateway-core-export-" + str(uuid.uuid4()),
|
||||
subscriber = "api-gateway-core-export-" + str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
await kr.start()
|
||||
|
||||
async def responder(resp, fin):
|
||||
|
||||
if "graph-embeddings" in resp:
|
||||
|
||||
data = resp["graph-embeddings"]
|
||||
|
||||
msg = (
|
||||
"ge",
|
||||
{
|
||||
"m": {
|
||||
"i": data["metadata"]["id"],
|
||||
"m": data["metadata"]["metadata"],
|
||||
"u": data["metadata"]["user"],
|
||||
"c": data["metadata"]["collection"],
|
||||
},
|
||||
"e": [
|
||||
{
|
||||
"e": ent["entity"],
|
||||
"v": ent["vectors"],
|
||||
}
|
||||
for ent in data["entities"]
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
enc = msgpack.packb(msg)
|
||||
await response.write(enc)
|
||||
|
||||
if "triples" in resp:
|
||||
|
||||
data = resp["triples"]
|
||||
msg = (
|
||||
"t",
|
||||
{
|
||||
"m": {
|
||||
"i": data["metadata"]["id"],
|
||||
"m": data["metadata"]["metadata"],
|
||||
"u": data["metadata"]["user"],
|
||||
"c": data["metadata"]["collection"],
|
||||
},
|
||||
"t": data["triples"],
|
||||
}
|
||||
)
|
||||
|
||||
enc = msgpack.packb(msg)
|
||||
await response.write(enc)
|
||||
|
||||
await kr.process(
|
||||
{
|
||||
"operation": "get-kg-core",
|
||||
"user": user,
|
||||
"id": id,
|
||||
},
|
||||
responder
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print("Exception:", e)
|
||||
|
||||
finally:
|
||||
|
||||
await kr.stop()
|
||||
|
||||
await response.write_eof()
|
||||
|
||||
return response
|
||||
|
||||
94
trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
Normal file
94
trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
import msgpack
|
||||
from . knowledge import KnowledgeRequestor
|
||||
|
||||
class CoreImport:
|
||||
|
||||
def __init__(self, pulsar_client):
|
||||
self.pulsar_client = pulsar_client
|
||||
|
||||
async def process(self, data, error, ok, request):
|
||||
|
||||
id = request.query["id"]
|
||||
user = request.query["user"]
|
||||
|
||||
kr = KnowledgeRequestor(
|
||||
pulsar_client = self.pulsar_client,
|
||||
consumer = "api-gateway-core-import-" + str(uuid.uuid4()),
|
||||
subscriber = "api-gateway-core-import-" + str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
await kr.start()
|
||||
|
||||
try:
|
||||
|
||||
unpacker = msgpack.Unpacker()
|
||||
|
||||
while True:
|
||||
buf = await data.read(128*1024)
|
||||
if not buf: break
|
||||
|
||||
unpacker.feed(buf)
|
||||
|
||||
for unpacked in unpacker:
|
||||
|
||||
if unpacked[0] == "t":
|
||||
msg = unpacked[1]
|
||||
msg = {
|
||||
"operation": "put-kg-core",
|
||||
"user": user,
|
||||
"id": id,
|
||||
"triples": {
|
||||
"metadata": {
|
||||
"id": id,
|
||||
"metadata": msg["m"]["m"],
|
||||
"user": user,
|
||||
"collection": "default", # Not used?
|
||||
},
|
||||
"triples": msg["t"],
|
||||
}
|
||||
}
|
||||
|
||||
await kr.process(msg)
|
||||
|
||||
elif unpacked[0] == "ge":
|
||||
msg = unpacked[1]
|
||||
msg = {
|
||||
"operation": "put-kg-core",
|
||||
"user": user,
|
||||
"id": id,
|
||||
"graph-embeddings": {
|
||||
"metadata": {
|
||||
"id": id,
|
||||
"metadata": msg["m"]["m"],
|
||||
"user": user,
|
||||
"collection": "default", # Not used?
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"entity": ent["e"],
|
||||
"vectors": ent["v"],
|
||||
}
|
||||
for ent in msg["e"]
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
await kr.process(msg)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e)
|
||||
await error(str(e))
|
||||
|
||||
finally:
|
||||
|
||||
await kr.stop()
|
||||
|
||||
print("All done.")
|
||||
response = await ok()
|
||||
await response.write_eof()
|
||||
|
||||
return response
|
||||
|
|
@ -6,8 +6,7 @@ from aiohttp import WSMsgType
|
|||
from ... schema import Metadata
|
||||
from ... schema import DocumentEmbeddings, ChunkEmbeddings
|
||||
from ... base import Publisher
|
||||
|
||||
from . serialize import to_subgraph
|
||||
from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator
|
||||
|
||||
class DocumentEmbeddingsImport:
|
||||
|
||||
|
|
@ -17,6 +16,7 @@ class DocumentEmbeddingsImport:
|
|||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
self.translator = DocumentEmbeddingsTranslator()
|
||||
|
||||
self.publisher = Publisher(
|
||||
pulsar_client, topic = queue, schema = DocumentEmbeddings
|
||||
|
|
@ -36,23 +36,7 @@ class DocumentEmbeddingsImport:
|
|||
async def receive(self, msg):
|
||||
|
||||
data = msg.json()
|
||||
|
||||
elt = DocumentEmbeddings(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
chunks=[
|
||||
ChunkEmbeddings(
|
||||
chunk=de["chunk"].encode("utf-8"),
|
||||
vectors=de["vectors"],
|
||||
)
|
||||
for de in data["chunks"]
|
||||
],
|
||||
)
|
||||
|
||||
elt = self.translator.to_pulsar(data)
|
||||
await self.publisher.send(None, elt)
|
||||
|
||||
async def run(self):
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
import base64
|
||||
|
||||
from ... schema import Document, Metadata
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . sender import ServiceSender
|
||||
from . serialize import to_subgraph
|
||||
|
||||
class DocumentLoad(ServiceSender):
|
||||
def __init__(self, pulsar_client, queue):
|
||||
|
|
@ -15,26 +15,9 @@ class DocumentLoad(ServiceSender):
|
|||
schema = Document,
|
||||
)
|
||||
|
||||
self.translator = TranslatorRegistry.get_request_translator("document")
|
||||
|
||||
def to_request(self, body):
|
||||
|
||||
if "metadata" in body:
|
||||
metadata = to_subgraph(body["metadata"])
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
# Doing a base64 decoe/encode here to make sure the
|
||||
# content is valid base64
|
||||
doc = base64.b64decode(body["data"])
|
||||
|
||||
print("Document received")
|
||||
|
||||
return Document(
|
||||
metadata=Metadata(
|
||||
id=body.get("id"),
|
||||
metadata=metadata,
|
||||
user=body.get("user", "trustgraph"),
|
||||
collection=body.get("collection", "default"),
|
||||
),
|
||||
data=base64.b64encode(doc).decode("utf-8")
|
||||
)
|
||||
return self.translator.to_pulsar(body)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
from ... schema import DocumentRagQuery, DocumentRagResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
|
|
@ -20,14 +21,12 @@ class DocumentRagRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("document-rag")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("document-rag")
|
||||
|
||||
def to_request(self, body):
|
||||
return DocumentRagQuery(
|
||||
query=body["query"],
|
||||
user=body.get("user", "trustgraph"),
|
||||
collection=body.get("collection", "default"),
|
||||
doc_limit=int(body.get("doc-limit", 20)),
|
||||
)
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return { "response": message.response }, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
|
|
@ -20,11 +21,12 @@ class EmbeddingsRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("embeddings")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("embeddings")
|
||||
|
||||
def to_request(self, body):
|
||||
return EmbeddingsRequest(
|
||||
text=body["text"]
|
||||
)
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return { "vectors": message.vectors }, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
from ... schema import FlowRequest, FlowResponse
|
||||
from ... schema import flow_request_queue
|
||||
from ... schema import flow_response_queue
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
|
|
@ -19,34 +20,12 @@ class FlowRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("flow")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("flow")
|
||||
|
||||
return FlowRequest(
|
||||
operation = body.get("operation", None),
|
||||
class_name = body.get("class-name", None),
|
||||
class_definition = body.get("class-definition", None),
|
||||
description = body.get("description", None),
|
||||
flow_id = body.get("flow-id", None),
|
||||
)
|
||||
def to_request(self, body):
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
|
||||
response = { }
|
||||
|
||||
if message.class_names is not None:
|
||||
response["class-names"] = message.class_names
|
||||
|
||||
if message.flow_ids is not None:
|
||||
response["flow-ids"] = message.flow_ids
|
||||
|
||||
if message.class_definition is not None:
|
||||
response["class-definition"] = message.class_definition
|
||||
|
||||
if message.flow is not None:
|
||||
response["flow"] = message.flow
|
||||
|
||||
if message.description is not None:
|
||||
response["description"] = message.description
|
||||
|
||||
return response, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
|
||||
from ... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
from . serialize import serialize_value
|
||||
|
||||
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
|
|
@ -21,22 +21,12 @@ class GraphEmbeddingsQueryRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("graph-embeddings-query")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("graph-embeddings-query")
|
||||
|
||||
def to_request(self, body):
|
||||
|
||||
limit = int(body.get("limit", 20))
|
||||
|
||||
return GraphEmbeddingsRequest(
|
||||
vectors = body["vectors"],
|
||||
limit = limit,
|
||||
user = body.get("user", "trustgraph"),
|
||||
collection = body.get("collection", "default"),
|
||||
)
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
|
||||
return {
|
||||
"entities": [
|
||||
serialize_value(ent) for ent in message.entities
|
||||
]
|
||||
}, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
from ... schema import GraphRagQuery, GraphRagResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
|
|
@ -20,17 +21,12 @@ class GraphRagRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("graph-rag")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("graph-rag")
|
||||
|
||||
def to_request(self, body):
|
||||
return GraphRagQuery(
|
||||
query=body["query"],
|
||||
user=body.get("user", "trustgraph"),
|
||||
collection=body.get("collection", "default"),
|
||||
entity_limit=int(body.get("entity-limit", 50)),
|
||||
triple_limit=int(body.get("triple-limit", 30)),
|
||||
max_subgraph_size=int(body.get("max-subgraph-size", 1000)),
|
||||
max_path_length=int(body.get("max-path-length", 2)),
|
||||
)
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return { "response": message.response }, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,11 +5,9 @@ from ... schema import KnowledgeRequest, KnowledgeResponse, Triples
|
|||
from ... schema import GraphEmbeddings, Metadata, EntityEmbeddings
|
||||
from ... schema import knowledge_request_queue
|
||||
from ... schema import knowledge_response_queue
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
from . serialize import serialize_graph_embeddings
|
||||
from . serialize import serialize_triples, to_subgraph, to_value
|
||||
from . serialize import to_document_metadata, to_processing_metadata
|
||||
|
||||
class KnowledgeRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
|
||||
|
|
@ -25,73 +23,12 @@ class KnowledgeRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("knowledge")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("knowledge")
|
||||
|
||||
def to_request(self, body):
|
||||
|
||||
if "triples" in body:
|
||||
triples = Triples(
|
||||
metadata=Metadata(
|
||||
id = body["triples"]["metadata"]["id"],
|
||||
metadata = to_subgraph(body["triples"]["metadata"]["metadata"]),
|
||||
user = body["triples"]["metadata"]["user"],
|
||||
),
|
||||
triples = to_subgraph(body["triples"]["triples"]),
|
||||
)
|
||||
else:
|
||||
triples = None
|
||||
|
||||
if "graph-embeddings" in body:
|
||||
ge = GraphEmbeddings(
|
||||
metadata = Metadata(
|
||||
id = body["graph-embeddings"]["metadata"]["id"],
|
||||
metadata = to_subgraph(body["graph-embeddings"]["metadata"]["metadata"]),
|
||||
user = body["graph-embeddings"]["metadata"]["user"],
|
||||
),
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity = to_value(ent["entity"]),
|
||||
vectors = ent["vectors"],
|
||||
)
|
||||
for ent in body["graph-embeddings"]["entities"]
|
||||
]
|
||||
)
|
||||
else:
|
||||
ge = None
|
||||
|
||||
return KnowledgeRequest(
|
||||
operation = body.get("operation", None),
|
||||
user = body.get("user", None),
|
||||
id = body.get("id", None),
|
||||
flow = body.get("flow", None),
|
||||
collection = body.get("collection", None),
|
||||
triples = triples,
|
||||
graph_embeddings = ge,
|
||||
)
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
|
||||
# Response to list,
|
||||
if message.ids is not None:
|
||||
return {
|
||||
"ids": message.ids
|
||||
}, True
|
||||
|
||||
if message.triples:
|
||||
return {
|
||||
"triples": serialize_triples(message.triples)
|
||||
}, False
|
||||
|
||||
if message.graph_embeddings:
|
||||
return {
|
||||
"graph-embeddings": serialize_graph_embeddings(
|
||||
message.graph_embeddings
|
||||
)
|
||||
}, False
|
||||
|
||||
if message.eos is True:
|
||||
return {
|
||||
"eos": True
|
||||
}, True
|
||||
|
||||
# Empty case, return from successful delete.
|
||||
return {}, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,9 @@ import base64
|
|||
from ... schema import LibrarianRequest, LibrarianResponse
|
||||
from ... schema import librarian_request_queue
|
||||
from ... schema import librarian_response_queue
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
from . serialize import serialize_document_metadata
|
||||
from . serialize import serialize_processing_metadata
|
||||
from . serialize import to_document_metadata, to_processing_metadata
|
||||
from . serialize import to_criteria
|
||||
|
||||
class LibrarianRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
|
||||
|
|
@ -25,67 +22,20 @@ class LibrarianRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("librarian")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("librarian")
|
||||
|
||||
def to_request(self, body):
|
||||
|
||||
# Content gets base64 decoded & encoded again. It at least makes
|
||||
# sure payload is valid base64.
|
||||
|
||||
if "document-metadata" in body:
|
||||
dm = to_document_metadata(body["document-metadata"])
|
||||
else:
|
||||
dm = None
|
||||
|
||||
if "processing-metadata" in body:
|
||||
pm = to_processing_metadata(body["processing-metadata"])
|
||||
else:
|
||||
pm = None
|
||||
|
||||
if "criteria" in body:
|
||||
criteria = to_criteria(body["criteria"])
|
||||
else:
|
||||
criteria = None
|
||||
|
||||
# Handle base64 content processing
|
||||
if "content" in body:
|
||||
# Content gets base64 decoded & encoded again to ensure valid base64
|
||||
content = base64.b64decode(body["content"].encode("utf-8"))
|
||||
content = base64.b64encode(content).decode("utf-8")
|
||||
else:
|
||||
content = None
|
||||
|
||||
return LibrarianRequest(
|
||||
operation = body.get("operation", None),
|
||||
document_id = body.get("document-id", None),
|
||||
processing_id = body.get("processing-id", None),
|
||||
document_metadata = dm,
|
||||
processing_metadata = pm,
|
||||
content = content,
|
||||
user = body.get("user", None),
|
||||
collection = body.get("collection", None),
|
||||
criteria = criteria,
|
||||
)
|
||||
body = body.copy()
|
||||
body["content"] = content
|
||||
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
|
||||
response = {}
|
||||
|
||||
if message.document_metadata:
|
||||
response["document-metadata"] = serialize_document_metadata(
|
||||
message.document_metadata
|
||||
)
|
||||
|
||||
if message.content:
|
||||
response["content"] = message.content.decode("utf-8")
|
||||
|
||||
if message.document_metadatas != None:
|
||||
response["document-metadatas"] = [
|
||||
serialize_document_metadata(v)
|
||||
for v in message.document_metadatas
|
||||
]
|
||||
|
||||
if message.processing_metadatas != None:
|
||||
response["processing-metadatas"] = [
|
||||
serialize_processing_metadata(v)
|
||||
for v in message.processing_metadatas
|
||||
]
|
||||
|
||||
return response, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import asyncio
|
||||
from aiohttp import web
|
||||
import uuid
|
||||
|
||||
from . config import ConfigRequestor
|
||||
|
|
@ -30,6 +31,9 @@ from . graph_embeddings_import import GraphEmbeddingsImport
|
|||
from . document_embeddings_import import DocumentEmbeddingsImport
|
||||
from . entity_contexts_import import EntityContextsImport
|
||||
|
||||
from . core_export import CoreExport
|
||||
from . core_import import CoreImport
|
||||
|
||||
from . mux import Mux
|
||||
|
||||
request_response_dispatchers = {
|
||||
|
|
@ -77,10 +81,11 @@ class DispatcherWrapper:
|
|||
|
||||
class DispatcherManager:
|
||||
|
||||
def __init__(self, pulsar_client, config_receiver):
|
||||
def __init__(self, pulsar_client, config_receiver, prefix="api-gateway"):
|
||||
self.pulsar_client = pulsar_client
|
||||
self.config_receiver = config_receiver
|
||||
self.config_receiver.add_handler(self)
|
||||
self.prefix = prefix
|
||||
|
||||
self.flows = {}
|
||||
self.dispatchers = {}
|
||||
|
|
@ -98,6 +103,22 @@ class DispatcherManager:
|
|||
def dispatch_global_service(self):
|
||||
return DispatcherWrapper(self.process_global_service)
|
||||
|
||||
def dispatch_core_export(self):
|
||||
return DispatcherWrapper(self.process_core_export)
|
||||
|
||||
def dispatch_core_import(self):
|
||||
return DispatcherWrapper(self.process_core_import)
|
||||
|
||||
async def process_core_import(self, data, error, ok, request):
|
||||
|
||||
ci = CoreImport(self.pulsar_client)
|
||||
return await ci.process(data, error, ok, request)
|
||||
|
||||
async def process_core_export(self, data, error, ok, request):
|
||||
|
||||
ce = CoreExport(self.pulsar_client)
|
||||
return await ce.process(data, error, ok, request)
|
||||
|
||||
async def process_global_service(self, data, responder, params):
|
||||
|
||||
kind = params.get("kind")
|
||||
|
|
@ -113,8 +134,8 @@ class DispatcherManager:
|
|||
dispatcher = global_dispatchers[kind](
|
||||
pulsar_client = self.pulsar_client,
|
||||
timeout = 120,
|
||||
consumer = f"api-gateway-{kind}-request",
|
||||
subscriber = f"api-gateway-{kind}-request",
|
||||
consumer = f"{self.prefix}-{kind}-request",
|
||||
subscriber = f"{self.prefix}-{kind}-request",
|
||||
)
|
||||
|
||||
await dispatcher.start()
|
||||
|
|
@ -206,8 +227,8 @@ class DispatcherManager:
|
|||
ws = ws,
|
||||
running = running,
|
||||
queue = qconfig,
|
||||
consumer = f"api-gateway-{id}",
|
||||
subscriber = f"api-gateway-{id}",
|
||||
consumer = f"{self.prefix}-{id}",
|
||||
subscriber = f"{self.prefix}-{id}",
|
||||
)
|
||||
|
||||
return dispatcher
|
||||
|
|
@ -248,8 +269,8 @@ class DispatcherManager:
|
|||
request_queue = qconfig["request"],
|
||||
response_queue = qconfig["response"],
|
||||
timeout = 120,
|
||||
consumer = f"api-gateway-{flow}-{kind}-request",
|
||||
subscriber = f"api-gateway-{flow}-{kind}-request",
|
||||
consumer = f"{self.prefix}-{flow}-{kind}-request",
|
||||
subscriber = f"{self.prefix}-{flow}-{kind}-request",
|
||||
)
|
||||
elif kind in sender_dispatchers:
|
||||
dispatcher = sender_dispatchers[kind](
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
import json
|
||||
|
||||
from ... schema import PromptRequest, PromptResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
|
|
@ -22,22 +23,12 @@ class PromptRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("prompt")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("prompt")
|
||||
|
||||
def to_request(self, body):
|
||||
return PromptRequest(
|
||||
id=body["id"],
|
||||
terms={
|
||||
k: json.dumps(v)
|
||||
for k, v in body["variables"].items()
|
||||
}
|
||||
)
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
if message.object:
|
||||
return {
|
||||
"object": message.object
|
||||
}, True
|
||||
else:
|
||||
return {
|
||||
"text": message.text
|
||||
}, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,13 @@ import base64
|
|||
|
||||
from ... schema import Value, Triple, DocumentMetadata, ProcessingMetadata
|
||||
|
||||
# DEPRECATED: These functions have been moved to trustgraph.... messaging.translators
|
||||
# Use the new messaging translation system instead for consistency and reusability.
|
||||
# Examples:
|
||||
# from trustgraph.... messaging.translators.primitives import ValueTranslator
|
||||
# value_translator = ValueTranslator()
|
||||
# pulsar_value = value_translator.to_pulsar({"v": "example", "e": True})
|
||||
|
||||
def to_value(x):
|
||||
return Value(value=x["v"], is_uri=x["e"])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
from ... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
|
|
@ -20,12 +21,12 @@ class TextCompletionRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("text-completion")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("text-completion")
|
||||
|
||||
def to_request(self, body):
|
||||
return TextCompletionRequest(
|
||||
system=body["system"],
|
||||
prompt=body["prompt"]
|
||||
)
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return { "response": message.response }, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
import base64
|
||||
|
||||
from ... schema import TextDocument, Metadata
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . sender import ServiceSender
|
||||
from . serialize import to_subgraph
|
||||
|
||||
class TextLoad(ServiceSender):
|
||||
def __init__(self, pulsar_client, queue):
|
||||
|
|
@ -15,30 +15,9 @@ class TextLoad(ServiceSender):
|
|||
schema = TextDocument,
|
||||
)
|
||||
|
||||
self.translator = TranslatorRegistry.get_request_translator("text-document")
|
||||
|
||||
def to_request(self, body):
|
||||
|
||||
if "metadata" in body:
|
||||
metadata = to_subgraph(body["metadata"])
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
if "charset" in body:
|
||||
charset = body["charset"]
|
||||
else:
|
||||
charset = "utf-8"
|
||||
|
||||
# Text is base64 encoded
|
||||
text = base64.b64decode(body["text"]).decode(charset)
|
||||
|
||||
print("Text document received")
|
||||
|
||||
return TextDocument(
|
||||
metadata=Metadata(
|
||||
id=body.get("id"),
|
||||
metadata=metadata,
|
||||
user=body.get("user", "trustgraph"),
|
||||
collection=body.get("collection", "default"),
|
||||
),
|
||||
text=text,
|
||||
)
|
||||
return self.translator.to_pulsar(body)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
|
||||
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
|
||||
from ... messaging import TranslatorRegistry
|
||||
|
||||
from . requestor import ServiceRequestor
|
||||
from . serialize import to_value, serialize_subgraph
|
||||
|
||||
class TriplesQueryRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
|
|
@ -21,34 +21,12 @@ class TriplesQueryRequestor(ServiceRequestor):
|
|||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.request_translator = TranslatorRegistry.get_request_translator("triples-query")
|
||||
self.response_translator = TranslatorRegistry.get_response_translator("triples-query")
|
||||
|
||||
def to_request(self, body):
|
||||
|
||||
if "s" in body:
|
||||
s = to_value(body["s"])
|
||||
else:
|
||||
s = None
|
||||
|
||||
if "p" in body:
|
||||
p = to_value(body["p"])
|
||||
else:
|
||||
p = None
|
||||
|
||||
if "o" in body:
|
||||
o = to_value(body["o"])
|
||||
else:
|
||||
o = None
|
||||
|
||||
limit = int(body.get("limit", 10000))
|
||||
|
||||
return TriplesQueryRequest(
|
||||
s = s, p = p, o = o,
|
||||
limit = limit,
|
||||
user = body.get("user", "trustgraph"),
|
||||
collection = body.get("collection", "default"),
|
||||
)
|
||||
return self.request_translator.to_pulsar(body)
|
||||
|
||||
def from_response(self, message):
|
||||
return {
|
||||
"response": serialize_subgraph(message.triples)
|
||||
}, True
|
||||
return self.response_translator.from_response_with_completion(message)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import asyncio
|
|||
|
||||
from aiohttp import web
|
||||
|
||||
from . constant_endpoint import ConstantEndpoint
|
||||
from . stream_endpoint import StreamEndpoint
|
||||
from . variable_endpoint import VariableEndpoint
|
||||
from . socket import SocketEndpoint
|
||||
from . metrics import MetricsEndpoint
|
||||
|
|
@ -52,6 +52,18 @@ class EndpointManager:
|
|||
auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_flow_export()
|
||||
),
|
||||
StreamEndpoint(
|
||||
endpoint_path = "/api/v1/import-core",
|
||||
auth = auth,
|
||||
method = "POST",
|
||||
dispatcher = dispatcher_manager.dispatch_core_import(),
|
||||
),
|
||||
StreamEndpoint(
|
||||
endpoint_path = "/api/v1/export-core",
|
||||
auth = auth,
|
||||
method = "GET",
|
||||
dispatcher = dispatcher_manager.dispatch_core_export(),
|
||||
),
|
||||
]
|
||||
|
||||
def add_routes(self, app):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,82 @@
|
|||
|
||||
import asyncio
|
||||
from aiohttp import web
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("endpoint")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class StreamEndpoint:
|
||||
|
||||
def __init__(self, endpoint_path, auth, dispatcher, method="POST"):
|
||||
|
||||
self.path = endpoint_path
|
||||
|
||||
self.auth = auth
|
||||
self.operation = "service"
|
||||
self.method = method
|
||||
|
||||
self.dispatcher = dispatcher
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
def add_routes(self, app):
|
||||
|
||||
if self.method == "POST":
|
||||
app.add_routes([
|
||||
web.post(self.path, self.handle),
|
||||
])
|
||||
elif self.method == "GET":
|
||||
app.add_routes([
|
||||
web.get(self.path, self.handle),
|
||||
])
|
||||
else:
|
||||
raise RuntimeError("Bad method" + self.method)
|
||||
|
||||
async def handle(self, request):
|
||||
|
||||
print(request.path, "...")
|
||||
|
||||
try:
|
||||
ht = request.headers["Authorization"]
|
||||
tokens = ht.split(" ", 2)
|
||||
if tokens[0] != "Bearer":
|
||||
return web.HTTPUnauthorized()
|
||||
token = tokens[1]
|
||||
except:
|
||||
token = ""
|
||||
|
||||
if not self.auth.permitted(token, self.operation):
|
||||
return web.HTTPUnauthorized()
|
||||
|
||||
try:
|
||||
|
||||
data = request.content
|
||||
|
||||
async def error(err):
|
||||
return web.HTTPInternalServerError(text = err)
|
||||
|
||||
async def ok(
|
||||
status=200, reason="OK", type="application/octet-stream"
|
||||
):
|
||||
response = web.StreamResponse(
|
||||
status = status, reason = reason,
|
||||
headers = {"Content-Type": type}
|
||||
)
|
||||
await response.prepare(request)
|
||||
return response
|
||||
|
||||
resp = await self.dispatcher.process(
|
||||
data, error, ok, request
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
|
||||
import asyncio
|
||||
from aiohttp import web
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("endpoint")
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class Api:
|
|||
self.dispatcher_manager = DispatcherManager(
|
||||
pulsar_client = self.pulsar_client,
|
||||
config_receiver = self.config_receiver,
|
||||
prefix = "gateway",
|
||||
)
|
||||
|
||||
self.endpoint_manager = EndpointManager(
|
||||
|
|
|
|||
|
|
@ -18,12 +18,14 @@ from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
|
|||
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
|
||||
|
||||
default_ident = "prompt"
|
||||
default_concurrency = 1
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id")
|
||||
concurrency = params.get("concurrency", 1)
|
||||
|
||||
# Config key for prompts
|
||||
self.config_key = params.get("config_type", "prompt")
|
||||
|
|
@ -31,6 +33,7 @@ class Processor(FlowProcessor):
|
|||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"concurrency": concurrency,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -38,7 +41,8 @@ class Processor(FlowProcessor):
|
|||
ConsumerSpec(
|
||||
name = "request",
|
||||
schema = PromptRequest,
|
||||
handler = self.on_request
|
||||
handler = self.on_request,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -219,6 +223,13 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/vllm/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/vllm/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
138
trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py
Executable file
138
trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py
Executable file
|
|
@ -0,0 +1,138 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using vLLM
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import os
|
||||
import aiohttp
|
||||
|
||||
from .... exceptions import TooManyRequests
|
||||
from .... base import LlmService, LlmResult
|
||||
|
||||
default_ident = "text-completion"
|
||||
|
||||
default_temperature = 0.0
|
||||
default_max_output = 2048
|
||||
default_base_url = os.getenv("VLLM_BASE_URL")
|
||||
default_model = "TheBloke/Mistral-7B-v0.1-AWQ"
|
||||
|
||||
if default_base_url == "" or default_base_url is None:
|
||||
default_base_url = "http://vllm-service:8899/v1"
|
||||
|
||||
class Processor(LlmService):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
base_url = params.get("url", default_base_url)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
model = params.get("model", default_model)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
"url": base_url,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
|
||||
self.base_url = base_url
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
print("Using vLLM service at", base_url)
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
async def generate_content(self, system, prompt):
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
request = {
|
||||
"model": self.model,
|
||||
"prompt": system + "\n\n" + prompt,
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
try:
|
||||
|
||||
url = f"{self.base_url}/completions"
|
||||
|
||||
async with self.session.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=request,
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
raise RuntimeError("Bad status: " + str(response.status))
|
||||
|
||||
resp = await response.json()
|
||||
|
||||
inputtokens = resp["usage"]["prompt_tokens"]
|
||||
outputtokens = resp["usage"]["completion_tokens"]
|
||||
ans = resp["choices"][0]["text"]
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
print(ans, flush=True)
|
||||
|
||||
resp = LlmResult(
|
||||
text = ans,
|
||||
in_token = inputtokens,
|
||||
out_token = outputtokens,
|
||||
model = self.model,
|
||||
)
|
||||
|
||||
return resp
|
||||
|
||||
# FIXME: Assuming vLLM won't produce rate limits?
|
||||
|
||||
except Exception as e:
|
||||
|
||||
# Apart from rate limits, treat all exceptions as unrecoverable
|
||||
|
||||
print(f"Exception: {type(e)} {e}")
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
LlmService.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-u', '--url',
|
||||
default=default_base_url,
|
||||
help=f'vLLM service base URL (default: {default_base_url})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default=default_model,
|
||||
help=f'LLM model (default: {default_model})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.launch(default_ident, __doc__)
|
||||
|
|
@ -11,12 +11,14 @@ from ... base import PromptClientSpec, EmbeddingsClientSpec
|
|||
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
|
||||
|
||||
default_ident = "graph-rag"
|
||||
default_concurrency = 1
|
||||
|
||||
class Processor(FlowProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
id = params.get("id", default_ident)
|
||||
concurrency = params.get("concurrency", 1)
|
||||
|
||||
entity_limit = params.get("entity_limit", 50)
|
||||
triple_limit = params.get("triple_limit", 30)
|
||||
|
|
@ -26,6 +28,7 @@ class Processor(FlowProcessor):
|
|||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"id": id,
|
||||
"concurrency": concurrency,
|
||||
"entity_limit": entity_limit,
|
||||
"triple_limit": triple_limit,
|
||||
"max_subgraph_size": max_subgraph_size,
|
||||
|
|
@ -43,6 +46,7 @@ class Processor(FlowProcessor):
|
|||
name = "request",
|
||||
schema = GraphRagQuery,
|
||||
handler = self.on_request,
|
||||
concurrency = concurrency,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -157,6 +161,13 @@ class Processor(FlowProcessor):
|
|||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--concurrency',
|
||||
type=int,
|
||||
default=default_concurrency,
|
||||
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||
)
|
||||
|
||||
FlowProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
|
|||
1
trustgraph-flow/trustgraph/rev_gateway/__init__.py
Normal file
1
trustgraph-flow/trustgraph/rev_gateway/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
from . service import run
|
||||
11
trustgraph-flow/trustgraph/rev_gateway/__main__.py
Normal file
11
trustgraph-flow/trustgraph/rev_gateway/__main__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import logging
|
||||
from .service import run
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run()
|
||||
|
||||
130
trustgraph-flow/trustgraph/rev_gateway/dispatcher.py
Normal file
130
trustgraph-flow/trustgraph/rev_gateway/dispatcher.py
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional
|
||||
from trustgraph.messaging import TranslatorRegistry
|
||||
from ..gateway.dispatch.manager import DispatcherManager
|
||||
|
||||
logger = logging.getLogger("dispatcher")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class WebSocketResponder:
|
||||
"""Simple responder that captures response for websocket return"""
|
||||
def __init__(self):
|
||||
self.response = None
|
||||
self.completed = False
|
||||
|
||||
async def send(self, data):
|
||||
"""Capture the response data"""
|
||||
self.response = data
|
||||
self.completed = True
|
||||
|
||||
async def __call__(self, data, final=False):
|
||||
"""Make the responder callable for compatibility with requestor"""
|
||||
await self.send(data)
|
||||
if final:
|
||||
self.completed = True
|
||||
|
||||
class MessageDispatcher:
|
||||
|
||||
def __init__(self, max_workers: int = 10, config_receiver=None, pulsar_client=None):
|
||||
self.max_workers = max_workers
|
||||
self.semaphore = asyncio.Semaphore(max_workers)
|
||||
self.active_tasks = set()
|
||||
self.pulsar_client = pulsar_client
|
||||
|
||||
# Use DispatcherManager for flow and service management
|
||||
if pulsar_client and config_receiver:
|
||||
self.dispatcher_manager = DispatcherManager(pulsar_client, config_receiver, prefix="rev-gateway")
|
||||
else:
|
||||
self.dispatcher_manager = None
|
||||
logger.warning("No pulsar_client or config_receiver provided - using fallback mode")
|
||||
|
||||
# Service name mapping from websocket protocol to translator registry
|
||||
self.service_mapping = {
|
||||
"text-completion": "text-completion",
|
||||
"graph-rag": "graph-rag",
|
||||
"agent": "agent",
|
||||
"embeddings": "embeddings",
|
||||
"graph-embeddings": "graph-embeddings",
|
||||
"triples": "triples",
|
||||
"document-load": "document",
|
||||
"text-load": "text-document",
|
||||
"flow": "flow",
|
||||
"knowledge": "knowledge",
|
||||
"config": "config",
|
||||
"librarian": "librarian",
|
||||
"document-rag": "document-rag"
|
||||
}
|
||||
|
||||
async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
|
||||
async with self.semaphore:
|
||||
task = asyncio.create_task(self._process_message(message))
|
||||
self.active_tasks.add(task)
|
||||
|
||||
try:
|
||||
result = await task
|
||||
return result
|
||||
finally:
|
||||
self.active_tasks.discard(task)
|
||||
|
||||
async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
request_id = message.get('id', str(uuid.uuid4()))
|
||||
service = message.get('service')
|
||||
request_data = message.get('request', {})
|
||||
flow_id = message.get('flow', 'default') # Default flow
|
||||
|
||||
logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
|
||||
|
||||
try:
|
||||
if not self.dispatcher_manager:
|
||||
raise RuntimeError("DispatcherManager not available - pulsar_client and config_receiver required")
|
||||
|
||||
# Use DispatcherManager for flow-based processing
|
||||
responder = WebSocketResponder()
|
||||
|
||||
# Map websocket service name to dispatcher service name
|
||||
dispatcher_service = self.service_mapping.get(service, service)
|
||||
|
||||
# Check if this is a global service or flow service
|
||||
from ..gateway.dispatch.manager import global_dispatchers
|
||||
if dispatcher_service in global_dispatchers:
|
||||
# Use global service dispatcher
|
||||
await self.dispatcher_manager.invoke_global_service(
|
||||
request_data, responder, dispatcher_service
|
||||
)
|
||||
else:
|
||||
# Use DispatcherManager to process the request through Pulsar queues
|
||||
await self.dispatcher_manager.invoke_flow_service(
|
||||
request_data, responder, flow_id, dispatcher_service
|
||||
)
|
||||
|
||||
# Get the response from the responder
|
||||
if responder.completed:
|
||||
response_data = responder.response
|
||||
else:
|
||||
response_data = {'error': 'No response received'}
|
||||
|
||||
response = {
|
||||
'id': request_id,
|
||||
'response': response_data
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message {request_id}: {e}")
|
||||
response = {
|
||||
'id': request_id,
|
||||
'response': {'error': str(e)}
|
||||
}
|
||||
|
||||
logger.info(f"Completed processing message {request_id}")
|
||||
return response
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
if self.active_tasks:
|
||||
logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete")
|
||||
await asyncio.gather(*self.active_tasks, return_exceptions=True)
|
||||
|
||||
# DispatcherManager handles its own cleanup
|
||||
logger.info("Dispatcher shutdown complete")
|
||||
242
trustgraph-flow/trustgraph/rev_gateway/service.py
Normal file
242
trustgraph-flow/trustgraph/rev_gateway/service.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
import asyncio
|
||||
import argparse
|
||||
import logging
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from aiohttp import ClientSession, WSMsgType, ClientWebSocketResponse
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
import pulsar
|
||||
|
||||
from .dispatcher import MessageDispatcher
|
||||
from ..gateway.config.receiver import ConfigReceiver
|
||||
|
||||
logger = logging.getLogger("rev_gateway")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
default_websocket = "ws://localhost:7650/out"
|
||||
|
||||
class ReverseGateway:
|
||||
|
||||
def __init__(self, websocket_uri: str = None, max_workers: int = 10,
|
||||
pulsar_host: str = None, pulsar_api_key: str = None,
|
||||
pulsar_listener: str = None):
|
||||
# Set default WebSocket URI with environment variable support
|
||||
if websocket_uri is None:
|
||||
websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)
|
||||
|
||||
# Parse and validate the WebSocket URI
|
||||
parsed_uri = urlparse(websocket_uri)
|
||||
if parsed_uri.scheme not in ('ws', 'wss'):
|
||||
raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}")
|
||||
if not parsed_uri.netloc:
|
||||
raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}")
|
||||
|
||||
# Store parsed components for debugging/logging
|
||||
self.websocket_uri = websocket_uri
|
||||
self.host = parsed_uri.hostname
|
||||
self.port = parsed_uri.port
|
||||
self.scheme = parsed_uri.scheme
|
||||
self.path = parsed_uri.path or "/ws"
|
||||
|
||||
# Construct the full URL (in case path was missing)
|
||||
if not parsed_uri.path:
|
||||
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
|
||||
else:
|
||||
self.url = websocket_uri
|
||||
|
||||
self.max_workers = max_workers
|
||||
self.ws: Optional[ClientWebSocketResponse] = None
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.running = False
|
||||
self.reconnect_delay = 3.0
|
||||
|
||||
# Pulsar configuration
|
||||
self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
||||
self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
|
||||
self.pulsar_listener = pulsar_listener
|
||||
|
||||
# Initialize Pulsar client
|
||||
if self.pulsar_api_key:
|
||||
self.pulsar_client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
listener_name=self.pulsar_listener,
|
||||
authentication=pulsar.AuthenticationToken(self.pulsar_api_key)
|
||||
)
|
||||
else:
|
||||
self.pulsar_client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
listener_name=self.pulsar_listener
|
||||
)
|
||||
|
||||
# Initialize config receiver
|
||||
self.config_receiver = ConfigReceiver(self.pulsar_client)
|
||||
|
||||
# Initialize dispatcher with config_receiver and pulsar_client - must be created after config_receiver
|
||||
self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.pulsar_client)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
try:
|
||||
if self.session is None:
|
||||
self.session = ClientSession()
|
||||
|
||||
logger.info(f"Connecting to {self.url}")
|
||||
self.ws = await self.session.ws_connect(self.url)
|
||||
logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to {self.url}: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect(self):
|
||||
if self.ws and not self.ws.closed:
|
||||
await self.ws.close()
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
self.ws = None
|
||||
self.session = None
|
||||
|
||||
async def send_message(self, message: dict):
|
||||
if self.ws and not self.ws.closed:
|
||||
try:
|
||||
await self.ws.send_str(json.dumps(message))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message: {e}")
|
||||
|
||||
async def handle_message(self, message: str):
|
||||
try:
|
||||
print(f"Received: {message}", flush=True)
|
||||
|
||||
msg_data = json.loads(message)
|
||||
response = await self.dispatcher.handle_message(msg_data)
|
||||
|
||||
if response:
|
||||
await self.send_message(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
|
||||
async def listen(self):
|
||||
while self.running and self.ws and not self.ws.closed:
|
||||
try:
|
||||
msg = await self.ws.receive()
|
||||
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
await self.handle_message(msg.data)
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
await self.handle_message(msg.data.decode('utf-8'))
|
||||
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
|
||||
logger.warning("WebSocket closed or error occurred")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in listen loop: {e}")
|
||||
break
|
||||
|
||||
async def run(self):
    """Main loop: keep the WebSocket connection alive, reconnecting on
    failure, until stop()/shutdown() clears the running flag.
    """
    self.running = True
    logger.info("Starting reverse gateway")

    # Start config receiver before entering the connection loop
    logger.info("Starting config receiver")
    await self.config_receiver.start()

    while self.running:
        try:
            connected = await self.connect()

            if connected:
                # Blocks until the connection drops or stop is requested
                await self.listen()
            else:
                logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")

            # Always tear down before the next attempt
            await self.disconnect()

            if self.running:
                await asyncio.sleep(self.reconnect_delay)

        except KeyboardInterrupt:
            logger.info("Shutdown requested")
            break
        except Exception as exc:
            # Unexpected failure: log, back off, and try again
            logger.error(f"Unexpected error: {exc}")
            if self.running:
                await asyncio.sleep(self.reconnect_delay)

    await self.shutdown()
async def shutdown(self):
    """Stop the gateway: dispatcher first, then the WebSocket, then the
    Pulsar client. Idempotent with respect to the running flag.
    """
    logger.info("Shutting down reverse gateway")
    self.running = False

    await self.dispatcher.shutdown()
    await self.disconnect()

    # Close Pulsar client last; guarded because construction may have
    # failed before the attribute was set
    if hasattr(self, 'pulsar_client'):
        self.pulsar_client.close()
def stop(self):
    """Request termination: clears the flag that run()'s loop checks."""
    self.running = False
def parse_args():
    """Build and parse the reverse-gateway command-line arguments."""
    p = argparse.ArgumentParser(
        prog="reverse-gateway",
        description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge",
    )

    p.add_argument(
        '--websocket-uri', default=None,
        help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)',
    )

    p.add_argument(
        '--max-workers', type=int, default=10,
        help='Maximum concurrent message handlers (default: 10)',
    )

    p.add_argument(
        '-p', '--pulsar-host', default=None,
        help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)',
    )

    p.add_argument(
        '--pulsar-api-key', default=None,
        help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)',
    )

    p.add_argument(
        '--pulsar-listener', default=None,
        help='Pulsar listener name',
    )

    return p.parse_args()
def run():
    """CLI entry point: parse arguments, build the gateway, and run its
    event loop until interrupted or a fatal error occurs.
    """
    args = parse_args()

    gateway = ReverseGateway(
        websocket_uri=args.websocket_uri,
        max_workers=args.max_workers,
        pulsar_host=args.pulsar_host,
        pulsar_api_key=args.pulsar_api_key,
        pulsar_listener=args.pulsar_listener,
    )

    # Announce the effective configuration before starting
    print(f"Starting reverse gateway:")
    print(f" WebSocket URI: {gateway.url}")
    print(f" Max workers: {args.max_workers}")
    print(f" Pulsar host: {gateway.pulsar_host}")

    try:
        asyncio.run(gateway.run())
    except KeyboardInterrupt:
        print("\nShutdown requested by user")
    except Exception as exc:
        print(f"Fatal error: {exc}")
        sys.exit(1)
Loading…
Add table
Add a link
Reference in a new issue