mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-07-03 23:11:00 +02:00
Merge branch 'release/v1.0'
This commit is contained in:
commit
dbe78ebe46
66 changed files with 2706 additions and 528 deletions
|
|
@ -9,6 +9,7 @@ FROM docker.io/fedora:42 AS base
|
||||||
ENV PIP_BREAK_SYSTEM_PACKAGES=1
|
ENV PIP_BREAK_SYSTEM_PACKAGES=1
|
||||||
|
|
||||||
RUN dnf install -y python3.12 && \
|
RUN dnf install -y python3.12 && \
|
||||||
|
dnf install -y tesseract poppler-utils && \
|
||||||
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
|
alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
|
||||||
python -m ensurepip --upgrade && \
|
python -m ensurepip --upgrade && \
|
||||||
pip3 install --no-cache-dir wheel aiohttp && \
|
pip3 install --no-cache-dir wheel aiohttp && \
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ from prometheus_client import start_http_server, Info
|
||||||
|
|
||||||
from .. schema import ConfigPush, config_push_queue
|
from .. schema import ConfigPush, config_push_queue
|
||||||
from .. log_level import LogLevel
|
from .. log_level import LogLevel
|
||||||
from .. exceptions import TooManyRequests
|
|
||||||
from . pubsub import PulsarClient
|
from . pubsub import PulsarClient
|
||||||
from . producer import Producer
|
from . producer import Producer
|
||||||
from . consumer import Consumer
|
from . consumer import Consumer
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,14 @@
|
||||||
|
|
||||||
|
# Consumer is similar to subscriber: It takes information from a queue
|
||||||
|
# and passes on to a processor function. This is the main receiving
|
||||||
|
# loop for TrustGraph processors. Incorporates retry functionality
|
||||||
|
|
||||||
|
# Note: there is a 'defect' in the system which is tolerated, althought
|
||||||
|
# the processing handlers are async functions, ideally implementation
|
||||||
|
# would use all async code. In practice if the processor only implements
|
||||||
|
# one handler, and a single thread of concurrency, nothing too outrageous
|
||||||
|
# will happen if synchronous / blocking code is used
|
||||||
|
|
||||||
from pulsar.schema import JsonSchema
|
from pulsar.schema import JsonSchema
|
||||||
import pulsar
|
import pulsar
|
||||||
import _pulsar
|
import _pulsar
|
||||||
|
|
@ -16,6 +26,7 @@ class Consumer:
|
||||||
start_of_messages=False,
|
start_of_messages=False,
|
||||||
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
|
rate_limit_retry_time = 10, rate_limit_timeout = 7200,
|
||||||
reconnect_time = 5,
|
reconnect_time = 5,
|
||||||
|
concurrency = 1, # Number of concurrent requests to handle
|
||||||
):
|
):
|
||||||
|
|
||||||
self.taskgroup = taskgroup
|
self.taskgroup = taskgroup
|
||||||
|
|
@ -34,7 +45,9 @@ class Consumer:
|
||||||
self.start_of_messages = start_of_messages
|
self.start_of_messages = start_of_messages
|
||||||
|
|
||||||
self.running = True
|
self.running = True
|
||||||
self.task = None
|
self.consumer_task = None
|
||||||
|
|
||||||
|
self.concurrency = concurrency
|
||||||
|
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
|
|
||||||
|
|
@ -52,7 +65,11 @@ class Consumer:
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
|
|
||||||
self.running = False
|
self.running = False
|
||||||
await self.task
|
|
||||||
|
if self.consumer_task:
|
||||||
|
await self.consumer_task
|
||||||
|
|
||||||
|
self.consumer_task = None
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
|
||||||
|
|
@ -62,9 +79,9 @@ class Consumer:
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.state("stopped")
|
self.metrics.state("stopped")
|
||||||
|
|
||||||
self.task = self.taskgroup.create_task(self.run())
|
self.consumer_task = self.taskgroup.create_task(self.consumer_run())
|
||||||
|
|
||||||
async def run(self):
|
async def consumer_run(self):
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
|
|
||||||
|
|
@ -102,7 +119,19 @@ class Consumer:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
await self.consume()
|
print(
|
||||||
|
"Starting", self.concurrency, "receiver threads",
|
||||||
|
flush=True
|
||||||
|
)
|
||||||
|
|
||||||
|
async with asyncio.TaskGroup() as tg:
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
for i in range(0, self.concurrency):
|
||||||
|
tasks.append(
|
||||||
|
tg.create_task(self.consume_from_queue())
|
||||||
|
)
|
||||||
|
|
||||||
if self.metrics:
|
if self.metrics:
|
||||||
self.metrics.state("stopped")
|
self.metrics.state("stopped")
|
||||||
|
|
@ -120,7 +149,7 @@ class Consumer:
|
||||||
self.consumer.unsubscribe()
|
self.consumer.unsubscribe()
|
||||||
self.consumer.close()
|
self.consumer.close()
|
||||||
|
|
||||||
async def consume(self):
|
async def consume_from_queue(self):
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
|
|
||||||
|
|
@ -134,71 +163,75 @@ class Consumer:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
expiry = time.time() + self.rate_limit_timeout
|
await self.handle_one_from_queue(msg)
|
||||||
|
|
||||||
# This loop is for retry on rate-limit / resource limits
|
async def handle_one_from_queue(self, msg):
|
||||||
while self.running:
|
|
||||||
|
|
||||||
if time.time() > expiry:
|
expiry = time.time() + self.rate_limit_timeout
|
||||||
|
|
||||||
print("Gave up waiting for rate-limit retry", flush=True)
|
# This loop is for retry on rate-limit / resource limits
|
||||||
|
while self.running:
|
||||||
|
|
||||||
# Message failed to be processed, this causes it to
|
if time.time() > expiry:
|
||||||
# be retried
|
|
||||||
self.consumer.negative_acknowledge(msg)
|
|
||||||
|
|
||||||
if self.metrics:
|
print("Gave up waiting for rate-limit retry", flush=True)
|
||||||
self.metrics.process("error")
|
|
||||||
|
|
||||||
# Break out of retry loop, processes next message
|
# Message failed to be processed, this causes it to
|
||||||
break
|
# be retried
|
||||||
|
self.consumer.negative_acknowledge(msg)
|
||||||
|
|
||||||
try:
|
if self.metrics:
|
||||||
|
self.metrics.process("error")
|
||||||
|
|
||||||
print("Handle...", flush=True)
|
# Break out of retry loop, processes next message
|
||||||
|
break
|
||||||
|
|
||||||
if self.metrics:
|
try:
|
||||||
|
|
||||||
with self.metrics.record_time():
|
print("Handle...", flush=True)
|
||||||
await self.handler(msg, self, self.flow)
|
|
||||||
|
|
||||||
else:
|
if self.metrics:
|
||||||
|
|
||||||
|
with self.metrics.record_time():
|
||||||
await self.handler(msg, self, self.flow)
|
await self.handler(msg, self, self.flow)
|
||||||
|
|
||||||
print("Handled.", flush=True)
|
else:
|
||||||
|
await self.handler(msg, self, self.flow)
|
||||||
|
|
||||||
# Acknowledge successful processing of the message
|
print("Handled.", flush=True)
|
||||||
self.consumer.acknowledge(msg)
|
|
||||||
|
|
||||||
if self.metrics:
|
# Acknowledge successful processing of the message
|
||||||
self.metrics.process("success")
|
self.consumer.acknowledge(msg)
|
||||||
|
|
||||||
# Break out of retry loop
|
if self.metrics:
|
||||||
break
|
self.metrics.process("success")
|
||||||
|
|
||||||
except TooManyRequests:
|
# Break out of retry loop
|
||||||
|
break
|
||||||
|
|
||||||
print("TooManyRequests: will retry...", flush=True)
|
except TooManyRequests:
|
||||||
|
|
||||||
if self.metrics:
|
print("TooManyRequests: will retry...", flush=True)
|
||||||
self.metrics.rate_limit()
|
|
||||||
|
|
||||||
# Sleep
|
if self.metrics:
|
||||||
await asyncio.sleep(self.rate_limit_retry_time)
|
self.metrics.rate_limit()
|
||||||
|
|
||||||
# Contine from retry loop, just causes a reprocessing
|
# Sleep
|
||||||
continue
|
await asyncio.sleep(self.rate_limit_retry_time)
|
||||||
|
|
||||||
except Exception as e:
|
# Contine from retry loop, just causes a reprocessing
|
||||||
|
continue
|
||||||
|
|
||||||
print("consume exception:", e, flush=True)
|
except Exception as e:
|
||||||
|
|
||||||
# Message failed to be processed, this causes it to
|
print("consume exception:", e, flush=True)
|
||||||
# be retried
|
|
||||||
self.consumer.negative_acknowledge(msg)
|
|
||||||
|
|
||||||
if self.metrics:
|
# Message failed to be processed, this causes it to
|
||||||
self.metrics.process("error")
|
# be retried
|
||||||
|
self.consumer.negative_acknowledge(msg)
|
||||||
|
|
||||||
# Break out of retry loop, processes next message
|
if self.metrics:
|
||||||
break
|
self.metrics.process("error")
|
||||||
|
|
||||||
|
# Break out of retry loop, processes next message
|
||||||
|
break
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,11 @@ from . consumer import Consumer
|
||||||
from . spec import Spec
|
from . spec import Spec
|
||||||
|
|
||||||
class ConsumerSpec(Spec):
|
class ConsumerSpec(Spec):
|
||||||
def __init__(self, name, schema, handler):
|
def __init__(self, name, schema, handler, concurrency = 1):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.handler = handler
|
self.handler = handler
|
||||||
|
self.concurrency = concurrency
|
||||||
|
|
||||||
def add(self, flow, processor, definition):
|
def add(self, flow, processor, definition):
|
||||||
|
|
||||||
|
|
@ -24,6 +25,7 @@ class ConsumerSpec(Spec):
|
||||||
schema = self.schema,
|
schema = self.schema,
|
||||||
handler = self.handler,
|
handler = self.handler,
|
||||||
metrics = consumer_metrics,
|
metrics = consumer_metrics,
|
||||||
|
concurrency = self.concurrency
|
||||||
)
|
)
|
||||||
|
|
||||||
# Consumer handle gets access to producers and other
|
# Consumer handle gets access to producers and other
|
||||||
|
|
|
||||||
|
|
@ -11,20 +11,26 @@ from .. exceptions import TooManyRequests
|
||||||
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
|
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
|
|
||||||
default_ident = "embeddings"
|
default_ident = "embeddings"
|
||||||
|
default_concurrency = 1
|
||||||
|
|
||||||
class EmbeddingsService(FlowProcessor):
|
class EmbeddingsService(FlowProcessor):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
id = params.get("id")
|
id = params.get("id")
|
||||||
|
concurrency = params.get("concurrency", 1)
|
||||||
|
|
||||||
super(EmbeddingsService, self).__init__(**params | { "id": id })
|
super(EmbeddingsService, self).__init__(**params | {
|
||||||
|
"id": id,
|
||||||
|
"concurrency": concurrency,
|
||||||
|
})
|
||||||
|
|
||||||
self.register_specification(
|
self.register_specification(
|
||||||
ConsumerSpec(
|
ConsumerSpec(
|
||||||
name = "request",
|
name = "request",
|
||||||
schema = EmbeddingsRequest,
|
schema = EmbeddingsRequest,
|
||||||
handler = self.on_request
|
handler = self.on_request,
|
||||||
|
concurrency = concurrency,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -84,6 +90,13 @@ class EmbeddingsService(FlowProcessor):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-c', '--concurrency',
|
||||||
|
type=int,
|
||||||
|
default=default_concurrency,
|
||||||
|
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||||
|
)
|
||||||
|
|
||||||
FlowProcessor.add_args(parser)
|
FlowProcessor.add_args(parser)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,13 @@ from .. exceptions import TooManyRequests
|
||||||
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
|
from .. base import FlowProcessor, ConsumerSpec, ProducerSpec
|
||||||
|
|
||||||
default_ident = "text-completion"
|
default_ident = "text-completion"
|
||||||
|
default_concurrency = 1
|
||||||
|
|
||||||
class LlmResult:
|
class LlmResult:
|
||||||
def __init__(self, text=None, in_token=None, out_token=None, model=None):
|
def __init__(
|
||||||
|
self, text = None, in_token = None, out_token = None,
|
||||||
|
model = None,
|
||||||
|
):
|
||||||
self.text = text
|
self.text = text
|
||||||
self.in_token = in_token
|
self.in_token = in_token
|
||||||
self.out_token = out_token
|
self.out_token = out_token
|
||||||
|
|
@ -25,14 +29,19 @@ class LlmService(FlowProcessor):
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
id = params.get("id")
|
id = params.get("id")
|
||||||
|
concurrency = params.get("concurrency", 1)
|
||||||
|
|
||||||
super(LlmService, self).__init__(**params | { "id": id })
|
super(LlmService, self).__init__(**params | {
|
||||||
|
"id": id,
|
||||||
|
"concurrency": concurrency,
|
||||||
|
})
|
||||||
|
|
||||||
self.register_specification(
|
self.register_specification(
|
||||||
ConsumerSpec(
|
ConsumerSpec(
|
||||||
name = "request",
|
name = "request",
|
||||||
schema = TextCompletionRequest,
|
schema = TextCompletionRequest,
|
||||||
handler = self.on_request
|
handler = self.on_request,
|
||||||
|
concurrency = concurrency,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -115,5 +124,12 @@ class LlmService(FlowProcessor):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-c', '--concurrency',
|
||||||
|
type=int,
|
||||||
|
default=default_concurrency,
|
||||||
|
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||||
|
)
|
||||||
|
|
||||||
FlowProcessor.add_args(parser)
|
FlowProcessor.add_args(parser)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
|
|
||||||
|
# Subscriber is similar to consumer: It provides a service to take stuff
|
||||||
|
# off of a queue and make it available using an internal broker system,
|
||||||
|
# so suitable for when multiple recipients are reading from the same queue
|
||||||
|
|
||||||
from pulsar.schema import JsonSchema
|
from pulsar.schema import JsonSchema
|
||||||
import asyncio
|
import asyncio
|
||||||
import _pulsar
|
import _pulsar
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ Triples store base class
|
||||||
|
|
||||||
from .. schema import Triples
|
from .. schema import Triples
|
||||||
from .. base import FlowProcessor, ConsumerSpec
|
from .. base import FlowProcessor, ConsumerSpec
|
||||||
|
from .. exceptions import TooManyRequests
|
||||||
|
|
||||||
default_ident = "triples-write"
|
default_ident = "triples-write"
|
||||||
|
|
||||||
|
|
|
||||||
105
trustgraph-base/trustgraph/messaging/__init__.py
Normal file
105
trustgraph-base/trustgraph/messaging/__init__.py
Normal file
|
|
@ -0,0 +1,105 @@
|
||||||
|
from .registry import TranslatorRegistry
|
||||||
|
from .translators import *
|
||||||
|
|
||||||
|
# Auto-register all translators
|
||||||
|
from .translators.agent import AgentRequestTranslator, AgentResponseTranslator
|
||||||
|
from .translators.embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator
|
||||||
|
from .translators.text_completion import TextCompletionRequestTranslator, TextCompletionResponseTranslator
|
||||||
|
from .translators.retrieval import (
|
||||||
|
DocumentRagRequestTranslator, DocumentRagResponseTranslator,
|
||||||
|
GraphRagRequestTranslator, GraphRagResponseTranslator
|
||||||
|
)
|
||||||
|
from .translators.triples import TriplesQueryRequestTranslator, TriplesQueryResponseTranslator
|
||||||
|
from .translators.knowledge import KnowledgeRequestTranslator, KnowledgeResponseTranslator
|
||||||
|
from .translators.library import LibraryRequestTranslator, LibraryResponseTranslator
|
||||||
|
from .translators.document_loading import DocumentTranslator, TextDocumentTranslator
|
||||||
|
from .translators.config import ConfigRequestTranslator, ConfigResponseTranslator
|
||||||
|
from .translators.flow import FlowRequestTranslator, FlowResponseTranslator
|
||||||
|
from .translators.prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||||
|
from .translators.embeddings_query import (
|
||||||
|
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||||
|
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register all service translators
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"agent",
|
||||||
|
AgentRequestTranslator(),
|
||||||
|
AgentResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"embeddings",
|
||||||
|
EmbeddingsRequestTranslator(),
|
||||||
|
EmbeddingsResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"text-completion",
|
||||||
|
TextCompletionRequestTranslator(),
|
||||||
|
TextCompletionResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"document-rag",
|
||||||
|
DocumentRagRequestTranslator(),
|
||||||
|
DocumentRagResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"graph-rag",
|
||||||
|
GraphRagRequestTranslator(),
|
||||||
|
GraphRagResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"triples-query",
|
||||||
|
TriplesQueryRequestTranslator(),
|
||||||
|
TriplesQueryResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"knowledge",
|
||||||
|
KnowledgeRequestTranslator(),
|
||||||
|
KnowledgeResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"librarian",
|
||||||
|
LibraryRequestTranslator(),
|
||||||
|
LibraryResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"config",
|
||||||
|
ConfigRequestTranslator(),
|
||||||
|
ConfigResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"flow",
|
||||||
|
FlowRequestTranslator(),
|
||||||
|
FlowResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"prompt",
|
||||||
|
PromptRequestTranslator(),
|
||||||
|
PromptResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"document-embeddings-query",
|
||||||
|
DocumentEmbeddingsRequestTranslator(),
|
||||||
|
DocumentEmbeddingsResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
TranslatorRegistry.register_service(
|
||||||
|
"graph-embeddings-query",
|
||||||
|
GraphEmbeddingsRequestTranslator(),
|
||||||
|
GraphEmbeddingsResponseTranslator()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register single-direction translators for document loading
|
||||||
|
TranslatorRegistry.register_request("document", DocumentTranslator())
|
||||||
|
TranslatorRegistry.register_request("text-document", TextDocumentTranslator())
|
||||||
51
trustgraph-base/trustgraph/messaging/registry.py
Normal file
51
trustgraph-base/trustgraph/messaging/registry.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
from .translators.base import MessageTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class TranslatorRegistry:
|
||||||
|
"""Registry for service translators"""
|
||||||
|
|
||||||
|
_request_translators: Dict[str, MessageTranslator] = {}
|
||||||
|
_response_translators: Dict[str, MessageTranslator] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_request(cls, service_name: str, translator: MessageTranslator):
|
||||||
|
"""Register a request translator for a service"""
|
||||||
|
cls._request_translators[service_name] = translator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_response(cls, service_name: str, translator: MessageTranslator):
|
||||||
|
"""Register a response translator for a service"""
|
||||||
|
cls._response_translators[service_name] = translator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_service(cls, service_name: str, request_translator: MessageTranslator,
|
||||||
|
response_translator: MessageTranslator):
|
||||||
|
"""Register both request and response translators for a service"""
|
||||||
|
cls.register_request(service_name, request_translator)
|
||||||
|
cls.register_response(service_name, response_translator)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_request_translator(cls, service_name: str) -> MessageTranslator:
|
||||||
|
"""Get request translator for a service"""
|
||||||
|
if service_name not in cls._request_translators:
|
||||||
|
raise KeyError(f"No request translator registered for service: {service_name}")
|
||||||
|
return cls._request_translators[service_name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_response_translator(cls, service_name: str) -> MessageTranslator:
|
||||||
|
"""Get response translator for a service"""
|
||||||
|
if service_name not in cls._response_translators:
|
||||||
|
raise KeyError(f"No response translator registered for service: {service_name}")
|
||||||
|
return cls._response_translators[service_name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_services(cls) -> List[str]:
|
||||||
|
"""List all registered services"""
|
||||||
|
return sorted(set(cls._request_translators.keys()) | set(cls._response_translators.keys()))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def has_service(cls, service_name: str) -> bool:
|
||||||
|
"""Check if a service is registered"""
|
||||||
|
return (service_name in cls._request_translators or
|
||||||
|
service_name in cls._response_translators)
|
||||||
19
trustgraph-base/trustgraph/messaging/translators/__init__.py
Normal file
19
trustgraph-base/trustgraph/messaging/translators/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
from .base import Translator, MessageTranslator
|
||||||
|
from .primitives import ValueTranslator, TripleTranslator, SubgraphTranslator
|
||||||
|
from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator
|
||||||
|
from .agent import AgentRequestTranslator, AgentResponseTranslator
|
||||||
|
from .embeddings import EmbeddingsRequestTranslator, EmbeddingsResponseTranslator
|
||||||
|
from .text_completion import TextCompletionRequestTranslator, TextCompletionResponseTranslator
|
||||||
|
from .retrieval import DocumentRagRequestTranslator, DocumentRagResponseTranslator
|
||||||
|
from .retrieval import GraphRagRequestTranslator, GraphRagResponseTranslator
|
||||||
|
from .triples import TriplesQueryRequestTranslator, TriplesQueryResponseTranslator
|
||||||
|
from .knowledge import KnowledgeRequestTranslator, KnowledgeResponseTranslator
|
||||||
|
from .library import LibraryRequestTranslator, LibraryResponseTranslator
|
||||||
|
from .document_loading import DocumentTranslator, TextDocumentTranslator, ChunkTranslator, DocumentEmbeddingsTranslator
|
||||||
|
from .config import ConfigRequestTranslator, ConfigResponseTranslator
|
||||||
|
from .flow import FlowRequestTranslator, FlowResponseTranslator
|
||||||
|
from .prompt import PromptRequestTranslator, PromptResponseTranslator
|
||||||
|
from .embeddings_query import (
|
||||||
|
DocumentEmbeddingsRequestTranslator, DocumentEmbeddingsResponseTranslator,
|
||||||
|
GraphEmbeddingsRequestTranslator, GraphEmbeddingsResponseTranslator
|
||||||
|
)
|
||||||
44
trustgraph-base/trustgraph/messaging/translators/agent.py
Normal file
44
trustgraph-base/trustgraph/messaging/translators/agent.py
Normal file
|
|
@ -0,0 +1,44 @@
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
from ...schema import AgentRequest, AgentResponse
|
||||||
|
from .base import MessageTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for AgentRequest schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> AgentRequest:
|
||||||
|
return AgentRequest(
|
||||||
|
question=data["question"],
|
||||||
|
plan=data.get("plan", ""),
|
||||||
|
state=data.get("state", ""),
|
||||||
|
history=data.get("history", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: AgentRequest) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"question": obj.question,
|
||||||
|
"plan": obj.plan,
|
||||||
|
"state": obj.state,
|
||||||
|
"history": obj.history
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AgentResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for AgentResponse schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> AgentResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: AgentResponse) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
if obj.answer:
|
||||||
|
result["answer"] = obj.answer
|
||||||
|
if obj.thought:
|
||||||
|
result["thought"] = obj.thought
|
||||||
|
if obj.observation:
|
||||||
|
result["observation"] = obj.observation
|
||||||
|
return result
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: AgentResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), (obj.answer is not None)
|
||||||
43
trustgraph-base/trustgraph/messaging/translators/base.py
Normal file
43
trustgraph-base/trustgraph/messaging/translators/base.py
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
from pulsar.schema import Record
|
||||||
|
|
||||||
|
|
||||||
|
class Translator(ABC):
|
||||||
|
"""Base class for bidirectional Pulsar ↔ dict translation"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> Record:
|
||||||
|
"""Convert dict to Pulsar schema object"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def from_pulsar(self, obj: Record) -> Dict[str, Any]:
|
||||||
|
"""Convert Pulsar schema object to dict"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MessageTranslator(Translator):
|
||||||
|
"""For complete request/response message translation"""
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: Record) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final) - for streaming responses"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
|
|
||||||
|
|
||||||
|
class SendTranslator(Translator):
|
||||||
|
"""For fire-and-forget send operations (like ServiceSender)"""
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: Record) -> Dict[str, Any]:
|
||||||
|
"""Usually not needed for send-only operations"""
|
||||||
|
raise NotImplementedError("Send translators typically don't need from_pulsar")
|
||||||
|
|
||||||
|
|
||||||
|
def handle_optional_fields(obj: Record, fields: list) -> Dict[str, Any]:
|
||||||
|
"""Helper to extract optional fields from Pulsar object"""
|
||||||
|
result = {}
|
||||||
|
for field in fields:
|
||||||
|
value = getattr(obj, field, None)
|
||||||
|
if value is not None:
|
||||||
|
result[field] = value
|
||||||
|
return result
|
||||||
100
trustgraph-base/trustgraph/messaging/translators/config.py
Normal file
100
trustgraph-base/trustgraph/messaging/translators/config.py
Normal file
|
|
@ -0,0 +1,100 @@
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
from ...schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
|
||||||
|
from .base import MessageTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for ConfigRequest schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> ConfigRequest:
|
||||||
|
keys = None
|
||||||
|
if "keys" in data:
|
||||||
|
keys = [
|
||||||
|
ConfigKey(
|
||||||
|
type=k["type"],
|
||||||
|
key=k["key"]
|
||||||
|
)
|
||||||
|
for k in data["keys"]
|
||||||
|
]
|
||||||
|
|
||||||
|
values = None
|
||||||
|
if "values" in data:
|
||||||
|
values = [
|
||||||
|
ConfigValue(
|
||||||
|
type=v["type"],
|
||||||
|
key=v["key"],
|
||||||
|
value=v["value"]
|
||||||
|
)
|
||||||
|
for v in data["values"]
|
||||||
|
]
|
||||||
|
|
||||||
|
return ConfigRequest(
|
||||||
|
operation=data.get("operation"),
|
||||||
|
keys=keys,
|
||||||
|
type=data.get("type"),
|
||||||
|
values=values
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: ConfigRequest) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.operation:
|
||||||
|
result["operation"] = obj.operation
|
||||||
|
if obj.type:
|
||||||
|
result["type"] = obj.type
|
||||||
|
|
||||||
|
if obj.keys:
|
||||||
|
result["keys"] = [
|
||||||
|
{
|
||||||
|
"type": k.type,
|
||||||
|
"key": k.key
|
||||||
|
}
|
||||||
|
for k in obj.keys
|
||||||
|
]
|
||||||
|
|
||||||
|
if obj.values:
|
||||||
|
result["values"] = [
|
||||||
|
{
|
||||||
|
"type": v.type,
|
||||||
|
"key": v.key,
|
||||||
|
"value": v.value
|
||||||
|
}
|
||||||
|
for v in obj.values
|
||||||
|
]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for ConfigResponse schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> ConfigResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: ConfigResponse) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.version is not None:
|
||||||
|
result["version"] = obj.version
|
||||||
|
|
||||||
|
if obj.values:
|
||||||
|
result["values"] = [
|
||||||
|
{
|
||||||
|
"type": v.type,
|
||||||
|
"key": v.key,
|
||||||
|
"value": v.value
|
||||||
|
}
|
||||||
|
for v in obj.values
|
||||||
|
]
|
||||||
|
|
||||||
|
if obj.directory:
|
||||||
|
result["directory"] = obj.directory
|
||||||
|
|
||||||
|
if obj.config:
|
||||||
|
result["config"] = obj.config
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: ConfigResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
|
|
@ -0,0 +1,191 @@
|
||||||
|
import base64
|
||||||
|
from typing import Dict, Any
|
||||||
|
from ...schema import Document, TextDocument, Chunk, DocumentEmbeddings, ChunkEmbeddings
|
||||||
|
from .base import SendTranslator
|
||||||
|
from .metadata import DocumentMetadataTranslator
|
||||||
|
from .primitives import SubgraphTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentTranslator(SendTranslator):
|
||||||
|
"""Translator for Document schema objects (PDF docs etc.)"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.subgraph_translator = SubgraphTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> Document:
|
||||||
|
metadata = data.get("metadata", [])
|
||||||
|
|
||||||
|
# Handle base64 content validation
|
||||||
|
doc = base64.b64decode(data["data"])
|
||||||
|
|
||||||
|
from ...schema import Metadata
|
||||||
|
return Document(
|
||||||
|
metadata=Metadata(
|
||||||
|
id=data.get("id"),
|
||||||
|
metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [],
|
||||||
|
user=data.get("user", "trustgraph"),
|
||||||
|
collection=data.get("collection", "default"),
|
||||||
|
),
|
||||||
|
data=base64.b64encode(doc).decode("utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: Document) -> Dict[str, Any]:
|
||||||
|
result = {
|
||||||
|
"data": obj.data
|
||||||
|
}
|
||||||
|
|
||||||
|
if obj.metadata:
|
||||||
|
metadata_dict = {}
|
||||||
|
if obj.metadata.id:
|
||||||
|
metadata_dict["id"] = obj.metadata.id
|
||||||
|
if obj.metadata.user:
|
||||||
|
metadata_dict["user"] = obj.metadata.user
|
||||||
|
if obj.metadata.collection:
|
||||||
|
metadata_dict["collection"] = obj.metadata.collection
|
||||||
|
if obj.metadata.metadata:
|
||||||
|
metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
|
||||||
|
|
||||||
|
result["metadata"] = metadata_dict
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class TextDocumentTranslator(SendTranslator):
|
||||||
|
"""Translator for TextDocument schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.subgraph_translator = SubgraphTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> TextDocument:
|
||||||
|
metadata = data.get("metadata", [])
|
||||||
|
charset = data.get("charset", "utf-8")
|
||||||
|
|
||||||
|
# Text is base64 encoded in input
|
||||||
|
text = base64.b64decode(data["text"]).decode(charset)
|
||||||
|
|
||||||
|
from ...schema import Metadata
|
||||||
|
return TextDocument(
|
||||||
|
metadata=Metadata(
|
||||||
|
id=data.get("id"),
|
||||||
|
metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [],
|
||||||
|
user=data.get("user", "trustgraph"),
|
||||||
|
collection=data.get("collection", "default"),
|
||||||
|
),
|
||||||
|
text=text.encode("utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: TextDocument) -> Dict[str, Any]:
|
||||||
|
result = {
|
||||||
|
"text": obj.text.decode("utf-8") if isinstance(obj.text, bytes) else obj.text
|
||||||
|
}
|
||||||
|
|
||||||
|
if obj.metadata:
|
||||||
|
metadata_dict = {}
|
||||||
|
if obj.metadata.id:
|
||||||
|
metadata_dict["id"] = obj.metadata.id
|
||||||
|
if obj.metadata.user:
|
||||||
|
metadata_dict["user"] = obj.metadata.user
|
||||||
|
if obj.metadata.collection:
|
||||||
|
metadata_dict["collection"] = obj.metadata.collection
|
||||||
|
if obj.metadata.metadata:
|
||||||
|
metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
|
||||||
|
|
||||||
|
result["metadata"] = metadata_dict
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkTranslator(SendTranslator):
|
||||||
|
"""Translator for Chunk schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.subgraph_translator = SubgraphTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> Chunk:
|
||||||
|
metadata = data.get("metadata", [])
|
||||||
|
|
||||||
|
from ...schema import Metadata
|
||||||
|
return Chunk(
|
||||||
|
metadata=Metadata(
|
||||||
|
id=data.get("id"),
|
||||||
|
metadata=self.subgraph_translator.to_pulsar(metadata) if metadata else [],
|
||||||
|
user=data.get("user", "trustgraph"),
|
||||||
|
collection=data.get("collection", "default"),
|
||||||
|
),
|
||||||
|
chunk=data["chunk"].encode("utf-8") if isinstance(data["chunk"], str) else data["chunk"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: Chunk) -> Dict[str, Any]:
|
||||||
|
result = {
|
||||||
|
"chunk": obj.chunk.decode("utf-8") if isinstance(obj.chunk, bytes) else obj.chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
if obj.metadata:
|
||||||
|
metadata_dict = {}
|
||||||
|
if obj.metadata.id:
|
||||||
|
metadata_dict["id"] = obj.metadata.id
|
||||||
|
if obj.metadata.user:
|
||||||
|
metadata_dict["user"] = obj.metadata.user
|
||||||
|
if obj.metadata.collection:
|
||||||
|
metadata_dict["collection"] = obj.metadata.collection
|
||||||
|
if obj.metadata.metadata:
|
||||||
|
metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
|
||||||
|
|
||||||
|
result["metadata"] = metadata_dict
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentEmbeddingsTranslator(SendTranslator):
|
||||||
|
"""Translator for DocumentEmbeddings schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.subgraph_translator = SubgraphTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddings:
|
||||||
|
metadata = data.get("metadata", {})
|
||||||
|
|
||||||
|
chunks = [
|
||||||
|
ChunkEmbeddings(
|
||||||
|
chunk=chunk["chunk"].encode("utf-8") if isinstance(chunk["chunk"], str) else chunk["chunk"],
|
||||||
|
vectors=chunk["vectors"]
|
||||||
|
)
|
||||||
|
for chunk in data.get("chunks", [])
|
||||||
|
]
|
||||||
|
|
||||||
|
from ...schema import Metadata
|
||||||
|
return DocumentEmbeddings(
|
||||||
|
metadata=Metadata(
|
||||||
|
id=metadata.get("id"),
|
||||||
|
metadata=self.subgraph_translator.to_pulsar(metadata.get("metadata", [])),
|
||||||
|
user=metadata.get("user", "trustgraph"),
|
||||||
|
collection=metadata.get("collection", "default"),
|
||||||
|
),
|
||||||
|
chunks=chunks
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: DocumentEmbeddings) -> Dict[str, Any]:
|
||||||
|
result = {
|
||||||
|
"chunks": [
|
||||||
|
{
|
||||||
|
"chunk": chunk.chunk.decode("utf-8") if isinstance(chunk.chunk, bytes) else chunk.chunk,
|
||||||
|
"vectors": chunk.vectors
|
||||||
|
}
|
||||||
|
for chunk in obj.chunks
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
if obj.metadata:
|
||||||
|
metadata_dict = {}
|
||||||
|
if obj.metadata.id:
|
||||||
|
metadata_dict["id"] = obj.metadata.id
|
||||||
|
if obj.metadata.user:
|
||||||
|
metadata_dict["user"] = obj.metadata.user
|
||||||
|
if obj.metadata.collection:
|
||||||
|
metadata_dict["collection"] = obj.metadata.collection
|
||||||
|
if obj.metadata.metadata:
|
||||||
|
metadata_dict["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata.metadata)
|
||||||
|
|
||||||
|
result["metadata"] = metadata_dict
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
from ...schema import EmbeddingsRequest, EmbeddingsResponse
|
||||||
|
from .base import MessageTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for EmbeddingsRequest schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsRequest:
|
||||||
|
return EmbeddingsRequest(
|
||||||
|
text=data["text"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: EmbeddingsRequest) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"text": obj.text
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for EmbeddingsResponse schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: EmbeddingsResponse) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"vectors": obj.vectors
|
||||||
|
}
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: EmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
from ...schema import (
|
||||||
|
DocumentEmbeddingsRequest, DocumentEmbeddingsResponse,
|
||||||
|
GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||||
|
)
|
||||||
|
from .base import MessageTranslator
|
||||||
|
from .primitives import ValueTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for DocumentEmbeddingsRequest schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsRequest:
|
||||||
|
return DocumentEmbeddingsRequest(
|
||||||
|
vectors=data["vectors"],
|
||||||
|
limit=int(data.get("limit", 10)),
|
||||||
|
user=data.get("user", "trustgraph"),
|
||||||
|
collection=data.get("collection", "default")
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: DocumentEmbeddingsRequest) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"vectors": obj.vectors,
|
||||||
|
"limit": obj.limit,
|
||||||
|
"user": obj.user,
|
||||||
|
"collection": obj.collection
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentEmbeddingsResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for DocumentEmbeddingsResponse schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> DocumentEmbeddingsResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: DocumentEmbeddingsResponse) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.documents:
|
||||||
|
result["documents"] = [
|
||||||
|
doc.decode("utf-8") if isinstance(doc, bytes) else doc
|
||||||
|
for doc in obj.documents
|
||||||
|
]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: DocumentEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
|
|
||||||
|
|
||||||
|
class GraphEmbeddingsRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for GraphEmbeddingsRequest schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsRequest:
|
||||||
|
return GraphEmbeddingsRequest(
|
||||||
|
vectors=data["vectors"],
|
||||||
|
limit=int(data.get("limit", 10)),
|
||||||
|
user=data.get("user", "trustgraph"),
|
||||||
|
collection=data.get("collection", "default")
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: GraphEmbeddingsRequest) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"vectors": obj.vectors,
|
||||||
|
"limit": obj.limit,
|
||||||
|
"user": obj.user,
|
||||||
|
"collection": obj.collection
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class GraphEmbeddingsResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for GraphEmbeddingsResponse schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.value_translator = ValueTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> GraphEmbeddingsResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: GraphEmbeddingsResponse) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.entities:
|
||||||
|
result["entities"] = [
|
||||||
|
self.value_translator.from_pulsar(entity)
|
||||||
|
for entity in obj.entities
|
||||||
|
]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: GraphEmbeddingsResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
59
trustgraph-base/trustgraph/messaging/translators/flow.py
Normal file
59
trustgraph-base/trustgraph/messaging/translators/flow.py
Normal file
|
|
@ -0,0 +1,59 @@
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
from ...schema import FlowRequest, FlowResponse
|
||||||
|
from .base import MessageTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class FlowRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for FlowRequest schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> FlowRequest:
|
||||||
|
return FlowRequest(
|
||||||
|
operation=data.get("operation"),
|
||||||
|
class_name=data.get("class-name"),
|
||||||
|
class_definition=data.get("class-definition"),
|
||||||
|
description=data.get("description"),
|
||||||
|
flow_id=data.get("flow-id")
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: FlowRequest) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.operation:
|
||||||
|
result["operation"] = obj.operation
|
||||||
|
if obj.class_name:
|
||||||
|
result["class-name"] = obj.class_name
|
||||||
|
if obj.class_definition:
|
||||||
|
result["class-definition"] = obj.class_definition
|
||||||
|
if obj.description:
|
||||||
|
result["description"] = obj.description
|
||||||
|
if obj.flow_id:
|
||||||
|
result["flow-id"] = obj.flow_id
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class FlowResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for FlowResponse schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> FlowResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: FlowResponse) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.class_names:
|
||||||
|
result["class-names"] = obj.class_names
|
||||||
|
if obj.flow_ids:
|
||||||
|
result["flow-ids"] = obj.flow_ids
|
||||||
|
if obj.class_definition:
|
||||||
|
result["class-definition"] = obj.class_definition
|
||||||
|
if obj.flow:
|
||||||
|
result["flow"] = obj.flow
|
||||||
|
if obj.description:
|
||||||
|
result["description"] = obj.description
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: FlowResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
183
trustgraph-base/trustgraph/messaging/translators/knowledge.py
Normal file
183
trustgraph-base/trustgraph/messaging/translators/knowledge.py
Normal file
|
|
@ -0,0 +1,183 @@
|
||||||
|
from typing import Dict, Any, Tuple, Optional
|
||||||
|
from ...schema import (
|
||||||
|
KnowledgeRequest, KnowledgeResponse, Triples, GraphEmbeddings,
|
||||||
|
Metadata, EntityEmbeddings
|
||||||
|
)
|
||||||
|
from .base import MessageTranslator
|
||||||
|
from .primitives import ValueTranslator, SubgraphTranslator
|
||||||
|
from .metadata import DocumentMetadataTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for KnowledgeRequest schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.value_translator = ValueTranslator()
|
||||||
|
self.subgraph_translator = SubgraphTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeRequest:
|
||||||
|
triples = None
|
||||||
|
if "triples" in data:
|
||||||
|
triples = Triples(
|
||||||
|
metadata=Metadata(
|
||||||
|
id=data["triples"]["metadata"]["id"],
|
||||||
|
metadata=self.subgraph_translator.to_pulsar(
|
||||||
|
data["triples"]["metadata"]["metadata"]
|
||||||
|
),
|
||||||
|
user=data["triples"]["metadata"]["user"],
|
||||||
|
collection=data["triples"]["metadata"]["collection"]
|
||||||
|
),
|
||||||
|
triples=self.subgraph_translator.to_pulsar(data["triples"]["triples"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_embeddings = None
|
||||||
|
if "graph-embeddings" in data:
|
||||||
|
graph_embeddings = GraphEmbeddings(
|
||||||
|
metadata=Metadata(
|
||||||
|
id=data["graph-embeddings"]["metadata"]["id"],
|
||||||
|
metadata=self.subgraph_translator.to_pulsar(
|
||||||
|
data["graph-embeddings"]["metadata"]["metadata"]
|
||||||
|
),
|
||||||
|
user=data["graph-embeddings"]["metadata"]["user"],
|
||||||
|
collection=data["graph-embeddings"]["metadata"]["collection"]
|
||||||
|
),
|
||||||
|
entities=[
|
||||||
|
EntityEmbeddings(
|
||||||
|
entity=self.value_translator.to_pulsar(ent["entity"]),
|
||||||
|
vectors=ent["vectors"],
|
||||||
|
)
|
||||||
|
for ent in data["graph-embeddings"]["entities"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return KnowledgeRequest(
|
||||||
|
operation=data.get("operation"),
|
||||||
|
user=data.get("user"),
|
||||||
|
id=data.get("id"),
|
||||||
|
flow=data.get("flow"),
|
||||||
|
collection=data.get("collection"),
|
||||||
|
triples=triples,
|
||||||
|
graph_embeddings=graph_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: KnowledgeRequest) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.operation:
|
||||||
|
result["operation"] = obj.operation
|
||||||
|
if obj.user:
|
||||||
|
result["user"] = obj.user
|
||||||
|
if obj.id:
|
||||||
|
result["id"] = obj.id
|
||||||
|
if obj.flow:
|
||||||
|
result["flow"] = obj.flow
|
||||||
|
if obj.collection:
|
||||||
|
result["collection"] = obj.collection
|
||||||
|
|
||||||
|
if obj.triples:
|
||||||
|
result["triples"] = {
|
||||||
|
"metadata": {
|
||||||
|
"id": obj.triples.metadata.id,
|
||||||
|
"metadata": self.subgraph_translator.from_pulsar(
|
||||||
|
obj.triples.metadata.metadata
|
||||||
|
),
|
||||||
|
"user": obj.triples.metadata.user,
|
||||||
|
"collection": obj.triples.metadata.collection,
|
||||||
|
},
|
||||||
|
"triples": self.subgraph_translator.from_pulsar(obj.triples.triples),
|
||||||
|
}
|
||||||
|
|
||||||
|
if obj.graph_embeddings:
|
||||||
|
result["graph-embeddings"] = {
|
||||||
|
"metadata": {
|
||||||
|
"id": obj.graph_embeddings.metadata.id,
|
||||||
|
"metadata": self.subgraph_translator.from_pulsar(
|
||||||
|
obj.graph_embeddings.metadata.metadata
|
||||||
|
),
|
||||||
|
"user": obj.graph_embeddings.metadata.user,
|
||||||
|
"collection": obj.graph_embeddings.metadata.collection,
|
||||||
|
},
|
||||||
|
"entities": [
|
||||||
|
{
|
||||||
|
"vectors": entity.vectors,
|
||||||
|
"entity": self.value_translator.from_pulsar(entity.entity),
|
||||||
|
}
|
||||||
|
for entity in obj.graph_embeddings.entities
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for KnowledgeResponse schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.value_translator = ValueTranslator()
|
||||||
|
self.subgraph_translator = SubgraphTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> KnowledgeResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: KnowledgeResponse) -> Dict[str, Any]:
|
||||||
|
# Response to list operation
|
||||||
|
if obj.ids is not None:
|
||||||
|
return {"ids": obj.ids}
|
||||||
|
|
||||||
|
# Streaming triples response
|
||||||
|
if obj.triples:
|
||||||
|
return {
|
||||||
|
"triples": {
|
||||||
|
"metadata": {
|
||||||
|
"id": obj.triples.metadata.id,
|
||||||
|
"metadata": self.subgraph_translator.from_pulsar(
|
||||||
|
obj.triples.metadata.metadata
|
||||||
|
),
|
||||||
|
"user": obj.triples.metadata.user,
|
||||||
|
"collection": obj.triples.metadata.collection,
|
||||||
|
},
|
||||||
|
"triples": self.subgraph_translator.from_pulsar(obj.triples.triples),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Streaming graph embeddings response
|
||||||
|
if obj.graph_embeddings:
|
||||||
|
return {
|
||||||
|
"graph-embeddings": {
|
||||||
|
"metadata": {
|
||||||
|
"id": obj.graph_embeddings.metadata.id,
|
||||||
|
"metadata": self.subgraph_translator.from_pulsar(
|
||||||
|
obj.graph_embeddings.metadata.metadata
|
||||||
|
),
|
||||||
|
"user": obj.graph_embeddings.metadata.user,
|
||||||
|
"collection": obj.graph_embeddings.metadata.collection,
|
||||||
|
},
|
||||||
|
"entities": [
|
||||||
|
{
|
||||||
|
"vectors": entity.vectors,
|
||||||
|
"entity": self.value_translator.from_pulsar(entity.entity),
|
||||||
|
}
|
||||||
|
for entity in obj.graph_embeddings.entities
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# End of stream marker
|
||||||
|
if obj.eos is True:
|
||||||
|
return {"eos": True}
|
||||||
|
|
||||||
|
# Empty response (successful delete)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: KnowledgeResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
response = self.from_pulsar(obj)
|
||||||
|
|
||||||
|
# Check if this is a final response
|
||||||
|
is_final = (
|
||||||
|
obj.ids is not None or # List response
|
||||||
|
obj.eos is True or # End of stream
|
||||||
|
(not obj.triples and not obj.graph_embeddings) # Empty response
|
||||||
|
)
|
||||||
|
|
||||||
|
return response, is_final
|
||||||
124
trustgraph-base/trustgraph/messaging/translators/library.py
Normal file
124
trustgraph-base/trustgraph/messaging/translators/library.py
Normal file
|
|
@ -0,0 +1,124 @@
|
||||||
|
from typing import Dict, Any, Tuple, Optional
|
||||||
|
from ...schema import LibrarianRequest, LibrarianResponse, DocumentMetadata, ProcessingMetadata, Criteria
|
||||||
|
from .base import MessageTranslator
|
||||||
|
from .metadata import DocumentMetadataTranslator, ProcessingMetadataTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class LibraryRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for LibrarianRequest schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.doc_metadata_translator = DocumentMetadataTranslator()
|
||||||
|
self.proc_metadata_translator = ProcessingMetadataTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> LibrarianRequest:
|
||||||
|
# Document metadata
|
||||||
|
doc_metadata = None
|
||||||
|
if "document-metadata" in data:
|
||||||
|
doc_metadata = self.doc_metadata_translator.to_pulsar(data["document-metadata"])
|
||||||
|
|
||||||
|
# Processing metadata
|
||||||
|
proc_metadata = None
|
||||||
|
if "processing-metadata" in data:
|
||||||
|
proc_metadata = self.proc_metadata_translator.to_pulsar(data["processing-metadata"])
|
||||||
|
|
||||||
|
# Criteria
|
||||||
|
criteria = []
|
||||||
|
if "criteria" in data:
|
||||||
|
criteria = [
|
||||||
|
Criteria(
|
||||||
|
key=c["key"],
|
||||||
|
value=c["value"],
|
||||||
|
operator=c["operator"]
|
||||||
|
)
|
||||||
|
for c in data["criteria"]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Content as bytes
|
||||||
|
content = None
|
||||||
|
if "content" in data:
|
||||||
|
if isinstance(data["content"], str):
|
||||||
|
content = data["content"].encode("utf-8")
|
||||||
|
else:
|
||||||
|
content = data["content"]
|
||||||
|
|
||||||
|
return LibrarianRequest(
|
||||||
|
operation=data.get("operation"),
|
||||||
|
document_id=data.get("document-id"),
|
||||||
|
processing_id=data.get("processing-id"),
|
||||||
|
document_metadata=doc_metadata,
|
||||||
|
processing_metadata=proc_metadata,
|
||||||
|
content=content,
|
||||||
|
user=data.get("user"),
|
||||||
|
collection=data.get("collection"),
|
||||||
|
criteria=criteria
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: LibrarianRequest) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.operation:
|
||||||
|
result["operation"] = obj.operation
|
||||||
|
if obj.document_id:
|
||||||
|
result["document-id"] = obj.document_id
|
||||||
|
if obj.processing_id:
|
||||||
|
result["processing-id"] = obj.processing_id
|
||||||
|
if obj.document_metadata:
|
||||||
|
result["document-metadata"] = self.doc_metadata_translator.from_pulsar(obj.document_metadata)
|
||||||
|
if obj.processing_metadata:
|
||||||
|
result["processing-metadata"] = self.proc_metadata_translator.from_pulsar(obj.processing_metadata)
|
||||||
|
if obj.content:
|
||||||
|
result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content
|
||||||
|
if obj.user:
|
||||||
|
result["user"] = obj.user
|
||||||
|
if obj.collection:
|
||||||
|
result["collection"] = obj.collection
|
||||||
|
if obj.criteria is not None:
|
||||||
|
result["criteria"] = [
|
||||||
|
{
|
||||||
|
"key": c.key,
|
||||||
|
"value": c.value,
|
||||||
|
"operator": c.operator
|
||||||
|
}
|
||||||
|
for c in obj.criteria
|
||||||
|
]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class LibraryResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for LibrarianResponse schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.doc_metadata_translator = DocumentMetadataTranslator()
|
||||||
|
self.proc_metadata_translator = ProcessingMetadataTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> LibrarianResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: LibrarianResponse) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.document_metadata:
|
||||||
|
result["document-metadata"] = self.doc_metadata_translator.from_pulsar(obj.document_metadata)
|
||||||
|
|
||||||
|
if obj.content:
|
||||||
|
result["content"] = obj.content.decode("utf-8") if isinstance(obj.content, bytes) else obj.content
|
||||||
|
|
||||||
|
if obj.document_metadatas is not None:
|
||||||
|
result["document-metadatas"] = [
|
||||||
|
self.doc_metadata_translator.from_pulsar(dm)
|
||||||
|
for dm in obj.document_metadatas
|
||||||
|
]
|
||||||
|
|
||||||
|
if obj.processing_metadatas is not None:
|
||||||
|
result["processing-metadatas"] = [
|
||||||
|
self.proc_metadata_translator.from_pulsar(pm)
|
||||||
|
for pm in obj.processing_metadatas
|
||||||
|
]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: LibrarianResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
81
trustgraph-base/trustgraph/messaging/translators/metadata.py
Normal file
81
trustgraph-base/trustgraph/messaging/translators/metadata.py
Normal file
|
|
@ -0,0 +1,81 @@
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from ...schema import DocumentMetadata, ProcessingMetadata
|
||||||
|
from .base import Translator
|
||||||
|
from .primitives import SubgraphTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentMetadataTranslator(Translator):
|
||||||
|
"""Translator for DocumentMetadata schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.subgraph_translator = SubgraphTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> DocumentMetadata:
|
||||||
|
metadata = data.get("metadata", [])
|
||||||
|
return DocumentMetadata(
|
||||||
|
id=data.get("id"),
|
||||||
|
time=data.get("time"),
|
||||||
|
kind=data.get("kind"),
|
||||||
|
title=data.get("title"),
|
||||||
|
comments=data.get("comments"),
|
||||||
|
metadata=self.subgraph_translator.to_pulsar(metadata) if metadata is not None else [],
|
||||||
|
user=data.get("user"),
|
||||||
|
tags=data.get("tags")
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: DocumentMetadata) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.id:
|
||||||
|
result["id"] = obj.id
|
||||||
|
if obj.time:
|
||||||
|
result["time"] = obj.time
|
||||||
|
if obj.kind:
|
||||||
|
result["kind"] = obj.kind
|
||||||
|
if obj.title:
|
||||||
|
result["title"] = obj.title
|
||||||
|
if obj.comments:
|
||||||
|
result["comments"] = obj.comments
|
||||||
|
if obj.metadata is not None:
|
||||||
|
result["metadata"] = self.subgraph_translator.from_pulsar(obj.metadata)
|
||||||
|
if obj.user:
|
||||||
|
result["user"] = obj.user
|
||||||
|
if obj.tags is not None:
|
||||||
|
result["tags"] = obj.tags
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessingMetadataTranslator(Translator):
|
||||||
|
"""Translator for ProcessingMetadata schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> ProcessingMetadata:
|
||||||
|
return ProcessingMetadata(
|
||||||
|
id=data.get("id"),
|
||||||
|
document_id=data.get("document-id"),
|
||||||
|
time=data.get("time"),
|
||||||
|
flow=data.get("flow"),
|
||||||
|
user=data.get("user"),
|
||||||
|
collection=data.get("collection"),
|
||||||
|
tags=data.get("tags")
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: ProcessingMetadata) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.id:
|
||||||
|
result["id"] = obj.id
|
||||||
|
if obj.document_id:
|
||||||
|
result["document-id"] = obj.document_id
|
||||||
|
if obj.time:
|
||||||
|
result["time"] = obj.time
|
||||||
|
if obj.flow:
|
||||||
|
result["flow"] = obj.flow
|
||||||
|
if obj.user:
|
||||||
|
result["user"] = obj.user
|
||||||
|
if obj.collection:
|
||||||
|
result["collection"] = obj.collection
|
||||||
|
if obj.tags is not None:
|
||||||
|
result["tags"] = obj.tags
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
from ...schema import Value, Triple
|
||||||
|
from .base import Translator
|
||||||
|
|
||||||
|
|
||||||
|
class ValueTranslator(Translator):
|
||||||
|
"""Translator for Value schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> Value:
|
||||||
|
return Value(value=data["v"], is_uri=data["e"])
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: Value) -> Dict[str, Any]:
|
||||||
|
return {"v": obj.value, "e": obj.is_uri}
|
||||||
|
|
||||||
|
|
||||||
|
class TripleTranslator(Translator):
|
||||||
|
"""Translator for Triple schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.value_translator = ValueTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> Triple:
|
||||||
|
return Triple(
|
||||||
|
s=self.value_translator.to_pulsar(data["s"]),
|
||||||
|
p=self.value_translator.to_pulsar(data["p"]),
|
||||||
|
o=self.value_translator.to_pulsar(data["o"])
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: Triple) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"s": self.value_translator.from_pulsar(obj.s),
|
||||||
|
"p": self.value_translator.from_pulsar(obj.p),
|
||||||
|
"o": self.value_translator.from_pulsar(obj.o)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SubgraphTranslator(Translator):
|
||||||
|
"""Translator for lists of Triple objects (subgraphs)"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.triple_translator = TripleTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: List[Dict[str, Any]]) -> List[Triple]:
|
||||||
|
return [self.triple_translator.to_pulsar(t) for t in data]
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: List[Triple]) -> List[Dict[str, Any]]:
|
||||||
|
return [self.triple_translator.from_pulsar(t) for t in obj]
|
||||||
54
trustgraph-base/trustgraph/messaging/translators/prompt.py
Normal file
54
trustgraph-base/trustgraph/messaging/translators/prompt.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
from ...schema import PromptRequest, PromptResponse
|
||||||
|
from .base import MessageTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class PromptRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for PromptRequest schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> PromptRequest:
|
||||||
|
# Handle both "terms" and "variables" input keys
|
||||||
|
terms = data.get("terms", {})
|
||||||
|
if "variables" in data:
|
||||||
|
# Convert variables to JSON strings as expected by the service
|
||||||
|
terms = {
|
||||||
|
k: json.dumps(v)
|
||||||
|
for k, v in data["variables"].items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return PromptRequest(
|
||||||
|
id=data.get("id"),
|
||||||
|
terms=terms
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: PromptRequest) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.id:
|
||||||
|
result["id"] = obj.id
|
||||||
|
if obj.terms:
|
||||||
|
result["terms"] = obj.terms
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class PromptResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for PromptResponse schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> PromptResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: PromptResponse) -> Dict[str, Any]:
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
if obj.text:
|
||||||
|
result["text"] = obj.text
|
||||||
|
if obj.object:
|
||||||
|
result["object"] = obj.object
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: PromptResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
from ...schema import DocumentRagQuery, DocumentRagResponse, GraphRagQuery, GraphRagResponse
|
||||||
|
from .base import MessageTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentRagRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for DocumentRagQuery schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagQuery:
|
||||||
|
return DocumentRagQuery(
|
||||||
|
query=data["query"],
|
||||||
|
user=data.get("user", "trustgraph"),
|
||||||
|
collection=data.get("collection", "default"),
|
||||||
|
doc_limit=int(data.get("doc-limit", 20))
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: DocumentRagQuery) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"query": obj.query,
|
||||||
|
"user": obj.user,
|
||||||
|
"collection": obj.collection,
|
||||||
|
"doc-limit": obj.doc_limit
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentRagResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for DocumentRagResponse schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> DocumentRagResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: DocumentRagResponse) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"response": obj.response
|
||||||
|
}
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: DocumentRagResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
|
|
||||||
|
|
||||||
|
class GraphRagRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for GraphRagQuery schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> GraphRagQuery:
|
||||||
|
return GraphRagQuery(
|
||||||
|
query=data["query"],
|
||||||
|
user=data.get("user", "trustgraph"),
|
||||||
|
collection=data.get("collection", "default"),
|
||||||
|
entity_limit=int(data.get("entity-limit", 50)),
|
||||||
|
triple_limit=int(data.get("triple-limit", 30)),
|
||||||
|
max_subgraph_size=int(data.get("max-subgraph-size", 1000)),
|
||||||
|
max_path_length=int(data.get("max-path-length", 2))
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: GraphRagQuery) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"query": obj.query,
|
||||||
|
"user": obj.user,
|
||||||
|
"collection": obj.collection,
|
||||||
|
"entity-limit": obj.entity_limit,
|
||||||
|
"triple-limit": obj.triple_limit,
|
||||||
|
"max-subgraph-size": obj.max_subgraph_size,
|
||||||
|
"max-path-length": obj.max_path_length
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class GraphRagResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for GraphRagResponse schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> GraphRagResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: GraphRagResponse) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"response": obj.response
|
||||||
|
}
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: GraphRagResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
from typing import Dict, Any, Tuple
|
||||||
|
from ...schema import TextCompletionRequest, TextCompletionResponse
|
||||||
|
from .base import MessageTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class TextCompletionRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for TextCompletionRequest schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionRequest:
|
||||||
|
return TextCompletionRequest(
|
||||||
|
system=data["system"],
|
||||||
|
prompt=data["prompt"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: TextCompletionRequest) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"system": obj.system,
|
||||||
|
"prompt": obj.prompt
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TextCompletionResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for TextCompletionResponse schema objects"""
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> TextCompletionResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: TextCompletionResponse) -> Dict[str, Any]:
|
||||||
|
result = {"response": obj.response}
|
||||||
|
|
||||||
|
if obj.in_token:
|
||||||
|
result["in_token"] = obj.in_token
|
||||||
|
if obj.out_token:
|
||||||
|
result["out_token"] = obj.out_token
|
||||||
|
if obj.model:
|
||||||
|
result["model"] = obj.model
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: TextCompletionResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
60
trustgraph-base/trustgraph/messaging/translators/triples.py
Normal file
60
trustgraph-base/trustgraph/messaging/translators/triples.py
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
from typing import Dict, Any, Tuple, Optional
|
||||||
|
from ...schema import TriplesQueryRequest, TriplesQueryResponse
|
||||||
|
from .base import MessageTranslator
|
||||||
|
from .primitives import ValueTranslator, SubgraphTranslator
|
||||||
|
|
||||||
|
|
||||||
|
class TriplesQueryRequestTranslator(MessageTranslator):
|
||||||
|
"""Translator for TriplesQueryRequest schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.value_translator = ValueTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> TriplesQueryRequest:
|
||||||
|
s = self.value_translator.to_pulsar(data["s"]) if "s" in data else None
|
||||||
|
p = self.value_translator.to_pulsar(data["p"]) if "p" in data else None
|
||||||
|
o = self.value_translator.to_pulsar(data["o"]) if "o" in data else None
|
||||||
|
|
||||||
|
return TriplesQueryRequest(
|
||||||
|
s=s,
|
||||||
|
p=p,
|
||||||
|
o=o,
|
||||||
|
limit=int(data.get("limit", 10000)),
|
||||||
|
user=data.get("user", "trustgraph"),
|
||||||
|
collection=data.get("collection", "default")
|
||||||
|
)
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: TriplesQueryRequest) -> Dict[str, Any]:
|
||||||
|
result = {
|
||||||
|
"limit": obj.limit,
|
||||||
|
"user": obj.user,
|
||||||
|
"collection": obj.collection
|
||||||
|
}
|
||||||
|
|
||||||
|
if obj.s:
|
||||||
|
result["s"] = self.value_translator.from_pulsar(obj.s)
|
||||||
|
if obj.p:
|
||||||
|
result["p"] = self.value_translator.from_pulsar(obj.p)
|
||||||
|
if obj.o:
|
||||||
|
result["o"] = self.value_translator.from_pulsar(obj.o)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class TriplesQueryResponseTranslator(MessageTranslator):
|
||||||
|
"""Translator for TriplesQueryResponse schema objects"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.subgraph_translator = SubgraphTranslator()
|
||||||
|
|
||||||
|
def to_pulsar(self, data: Dict[str, Any]) -> TriplesQueryResponse:
|
||||||
|
raise NotImplementedError("Response translation to Pulsar not typically needed")
|
||||||
|
|
||||||
|
def from_pulsar(self, obj: TriplesQueryResponse) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"response": self.subgraph_translator.from_pulsar(obj.triples)
|
||||||
|
}
|
||||||
|
|
||||||
|
def from_response_with_completion(self, obj: TriplesQueryResponse) -> Tuple[Dict[str, Any], bool]:
|
||||||
|
"""Returns (response_dict, is_final)"""
|
||||||
|
return self.from_pulsar(obj), True
|
||||||
|
|
@ -1,80 +1,68 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Loads Graph embeddings into TrustGraph processing.
|
Loads triples into the knowledge graph.
|
||||||
|
|
||||||
FIXME: This hasn't been updated following API gateway change.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pulsar
|
import asyncio
|
||||||
from pulsar.schema import JsonSchema
|
|
||||||
from trustgraph.schema import Triples, Triple, Value, Metadata
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import pyarrow as pa
|
|
||||||
import rdflib
|
import rdflib
|
||||||
|
import json
|
||||||
|
from websockets.asyncio.client import connect
|
||||||
|
|
||||||
from trustgraph.log_level import LogLevel
|
from trustgraph.log_level import LogLevel
|
||||||
|
|
||||||
|
default_url = os.getenv("TRUSTGRAPH_URL", 'ws://localhost:8088/')
|
||||||
default_user = 'trustgraph'
|
default_user = 'trustgraph'
|
||||||
default_collection = 'default'
|
default_collection = 'default'
|
||||||
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
|
|
||||||
default_pulsar_api_key = os.getenv("PULSAR_API_KEY", None)
|
|
||||||
|
|
||||||
default_output_queue = triples_store_queue
|
|
||||||
|
|
||||||
class Loader:
|
class Loader:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pulsar_host,
|
|
||||||
output_queue,
|
|
||||||
log_level,
|
|
||||||
files,
|
files,
|
||||||
|
flow,
|
||||||
user,
|
user,
|
||||||
collection,
|
collection,
|
||||||
pulsar_api_key=None,
|
document_id,
|
||||||
|
url = default_url,
|
||||||
):
|
):
|
||||||
|
|
||||||
if pulsar_api_key:
|
if not url.endswith("/"):
|
||||||
auth = pulsar.AuthenticationToken(pulsar_api_key)
|
url += "/"
|
||||||
self.client = pulsar.Client(
|
|
||||||
pulsar_host,
|
|
||||||
authentication=auth,
|
|
||||||
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.client = pulsar.Client(
|
|
||||||
pulsar_host,
|
|
||||||
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
|
|
||||||
)
|
|
||||||
|
|
||||||
self.producer = self.client.create_producer(
|
url = url + f"api/v1/flow/{flow}/import/triples"
|
||||||
topic=output_queue,
|
|
||||||
schema=JsonSchema(Triples),
|
self.url = url
|
||||||
chunking_enabled=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.files = files
|
self.files = files
|
||||||
self.user = user
|
self.user = user
|
||||||
self.collection = collection
|
self.collection = collection
|
||||||
|
self.document_id = document_id
|
||||||
|
|
||||||
def run(self):
|
async def run(self):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
for file in self.files:
|
async with connect(self.url) as ws:
|
||||||
self.load_file(file)
|
for file in self.files:
|
||||||
|
await self.load_file(file, ws)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e, flush=True)
|
print(e, flush=True)
|
||||||
|
|
||||||
def load_file(self, file):
|
async def load_file(self, file, ws):
|
||||||
|
|
||||||
g = rdflib.Graph()
|
g = rdflib.Graph()
|
||||||
g.parse(file, format="turtle")
|
g.parse(file, format="turtle")
|
||||||
|
|
||||||
|
def Value(value, is_uri):
|
||||||
|
return { "v": value, "e": is_uri }
|
||||||
|
|
||||||
|
triples = []
|
||||||
|
|
||||||
for e in g:
|
for e in g:
|
||||||
s = Value(value=str(e[0]), is_uri=True)
|
s = Value(value=str(e[0]), is_uri=True)
|
||||||
p = Value(value=str(e[1]), is_uri=True)
|
p = Value(value=str(e[1]), is_uri=True)
|
||||||
|
|
@ -83,20 +71,23 @@ class Loader:
|
||||||
else:
|
else:
|
||||||
o = Value(value=str(e[2]), is_uri=False)
|
o = Value(value=str(e[2]), is_uri=False)
|
||||||
|
|
||||||
r = Triples(
|
req = {
|
||||||
metadata=Metadata(
|
"metadata": {
|
||||||
id=None,
|
"id": self.document_id,
|
||||||
metadata=[],
|
"metadata": [],
|
||||||
user=self.user,
|
"user": self.user,
|
||||||
collection=self.collection,
|
"collection": self.collection
|
||||||
),
|
},
|
||||||
triples=[ Triple(s=s, p=p, o=o) ]
|
"triples": [
|
||||||
)
|
{
|
||||||
|
"s": s,
|
||||||
|
"p": p,
|
||||||
|
"o": o,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
self.producer.send(r)
|
await ws.send(json.dumps(req))
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
self.client.close()
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
|
|
@ -106,9 +97,15 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-p', '--pulsar-host',
|
'-u', '--api-url',
|
||||||
default=default_pulsar_host,
|
default=default_url,
|
||||||
help=f'Pulsar host (default: {default_pulsar_host})',
|
help=f'API URL (default: {default_url})',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-i', '--document-id',
|
||||||
|
required=True,
|
||||||
|
help=f'Document ID)',
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -116,39 +113,19 @@ def main():
|
||||||
default="default",
|
default="default",
|
||||||
help=f'Flow ID (default: default)'
|
help=f'Flow ID (default: default)'
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--pulsar-api-key',
|
|
||||||
default=default_pulsar_api_key,
|
|
||||||
help=f'Pulsar API key',
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-o', '--output-queue',
|
'-U', '--user',
|
||||||
default=default_output_queue,
|
|
||||||
help=f'Output queue (default: {default_output_queue})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-u', '--user',
|
|
||||||
default=default_user,
|
default=default_user,
|
||||||
help=f'User ID (default: {default_user})'
|
help=f'User ID (default: {default_user})'
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'-c', '--collection',
|
'-C', '--collection',
|
||||||
default=default_collection,
|
default=default_collection,
|
||||||
help=f'Collection ID (default: {default_collection})'
|
help=f'Collection ID (default: {default_collection})'
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-l', '--log-level',
|
|
||||||
type=LogLevel,
|
|
||||||
default=LogLevel.ERROR,
|
|
||||||
choices=list(LogLevel),
|
|
||||||
help=f'Output queue (default: info)'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'files', nargs='+',
|
'files', nargs='+',
|
||||||
help=f'Turtle files to load'
|
help=f'Turtle files to load'
|
||||||
|
|
@ -160,16 +137,15 @@ def main():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
p = Loader(
|
p = Loader(
|
||||||
pulsar_host=args.pulsar_host,
|
document_id = args.document_id,
|
||||||
pulsar_api_key=args.pulsar_api_key,
|
url = args.api_url,
|
||||||
output_queue=args.output_queue,
|
flow = args.flow_id,
|
||||||
log_level=args.log_level,
|
files = args.files,
|
||||||
files=args.files,
|
user = args.user,
|
||||||
user=args.user,
|
collection = args.collection,
|
||||||
collection=args.collection,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
p.run()
|
asyncio.run(p.run())
|
||||||
|
|
||||||
print("File loaded.")
|
print("File loaded.")
|
||||||
break
|
break
|
||||||
|
|
@ -181,6 +157,5 @@ def main():
|
||||||
|
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
|
|
||||||
print("Not implemented.")
|
main()
|
||||||
#main()
|
|
||||||
|
|
||||||
|
|
|
||||||
109
trustgraph-cli/scripts/tg-show-token-rate
Executable file
109
trustgraph-cli/scripts/tg-show-token-rate
Executable file
|
|
@ -0,0 +1,109 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
Dump out a stream of token rates, input, output and total. This is averaged
|
||||||
|
across the time since tg-show-token-rate is started.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
default_metrics_url = "http://localhost:8088/api/metrics"
|
||||||
|
|
||||||
|
class Collate:
|
||||||
|
|
||||||
|
def look(self, data):
|
||||||
|
return sum(
|
||||||
|
[
|
||||||
|
float(x["value"][1])
|
||||||
|
for x in data["data"]["result"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, data):
|
||||||
|
self.last = self.look(data)
|
||||||
|
self.total = 0
|
||||||
|
self.time = 0
|
||||||
|
|
||||||
|
def record(self, data, time):
|
||||||
|
|
||||||
|
value = self.look(data)
|
||||||
|
delta = value - self.last
|
||||||
|
self.last = value
|
||||||
|
|
||||||
|
self.total += delta
|
||||||
|
self.time += time
|
||||||
|
|
||||||
|
return delta/time, self.total/self.time
|
||||||
|
|
||||||
|
def dump_status(metrics_url, number_samples, period):
|
||||||
|
|
||||||
|
input_url = f"{metrics_url}/query?query=input_tokens_total"
|
||||||
|
output_url = f"{metrics_url}/query?query=output_tokens_total"
|
||||||
|
|
||||||
|
resp = requests.get(input_url)
|
||||||
|
obj = resp.json()
|
||||||
|
input = Collate(obj)
|
||||||
|
|
||||||
|
resp = requests.get(output_url)
|
||||||
|
obj = resp.json()
|
||||||
|
output = Collate(obj)
|
||||||
|
|
||||||
|
print(f"{'Input':>10s} {'Output':>10s} {'Total':>10s}")
|
||||||
|
print(f"{'-----':>10s} {'------':>10s} {'-----':>10s}")
|
||||||
|
|
||||||
|
for i in range(number_samples):
|
||||||
|
|
||||||
|
time.sleep(period)
|
||||||
|
|
||||||
|
resp = requests.get(input_url)
|
||||||
|
obj = resp.json()
|
||||||
|
inr, inl = input.record(obj, period)
|
||||||
|
|
||||||
|
resp = requests.get(output_url)
|
||||||
|
obj = resp.json()
|
||||||
|
outr, outl = output.record(obj, period)
|
||||||
|
|
||||||
|
print(f"{inl:10.1f} {outl:10.1f} {inl+outl:10.1f}")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog='tg-show-processor-state',
|
||||||
|
description=__doc__,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-m', '--metrics-url',
|
||||||
|
default=default_metrics_url,
|
||||||
|
help=f'Metrics URL (default: {default_metrics_url})',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-p', '--period',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help=f'Metrics period (default: 1)',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-n', '--number-samples',
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help=f'Metrics period (default: 100)',
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
dump_status(**vars(args))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
|
||||||
|
print("Exception:", e, flush=True)
|
||||||
|
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
@ -80,6 +80,7 @@ setuptools.setup(
|
||||||
"scripts/tg-show-processor-state",
|
"scripts/tg-show-processor-state",
|
||||||
"scripts/tg-show-prompts",
|
"scripts/tg-show-prompts",
|
||||||
"scripts/tg-show-token-costs",
|
"scripts/tg-show-token-costs",
|
||||||
|
"scripts/tg-show-token-rate",
|
||||||
"scripts/tg-show-tools",
|
"scripts/tg-show-tools",
|
||||||
"scripts/tg-start-flow",
|
"scripts/tg-start-flow",
|
||||||
"scripts/tg-unload-kg-core",
|
"scripts/tg-unload-kg-core",
|
||||||
|
|
|
||||||
6
trustgraph-flow/scripts/rev-gateway
Executable file
6
trustgraph-flow/scripts/rev-gateway
Executable file
|
|
@ -0,0 +1,6 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
from trustgraph.rev_gateway import run
|
||||||
|
|
||||||
|
run()
|
||||||
|
|
||||||
6
trustgraph-flow/scripts/text-completion-vllm
Executable file
6
trustgraph-flow/scripts/text-completion-vllm
Executable file
|
|
@ -0,0 +1,6 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
from trustgraph.model.text_completion.vllm import run
|
||||||
|
|
||||||
|
run()
|
||||||
|
|
||||||
|
|
@ -71,6 +71,7 @@ setuptools.setup(
|
||||||
scripts=[
|
scripts=[
|
||||||
"scripts/agent-manager-react",
|
"scripts/agent-manager-react",
|
||||||
"scripts/api-gateway",
|
"scripts/api-gateway",
|
||||||
|
"scripts/rev-gateway",
|
||||||
"scripts/chunker-recursive",
|
"scripts/chunker-recursive",
|
||||||
"scripts/chunker-token",
|
"scripts/chunker-token",
|
||||||
"scripts/config-svc",
|
"scripts/config-svc",
|
||||||
|
|
@ -118,6 +119,7 @@ setuptools.setup(
|
||||||
"scripts/text-completion-ollama",
|
"scripts/text-completion-ollama",
|
||||||
"scripts/text-completion-openai",
|
"scripts/text-completion-openai",
|
||||||
"scripts/text-completion-tgi",
|
"scripts/text-completion-tgi",
|
||||||
|
"scripts/text-completion-vllm",
|
||||||
"scripts/triples-query-cassandra",
|
"scripts/triples-query-cassandra",
|
||||||
"scripts/triples-query-falkordb",
|
"scripts/triples-query-falkordb",
|
||||||
"scripts/triples-query-memgraph",
|
"scripts/triples-query-memgraph",
|
||||||
|
|
|
||||||
|
|
@ -21,16 +21,19 @@ RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
||||||
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
||||||
|
|
||||||
default_ident = "kg-extract-definitions"
|
default_ident = "kg-extract-definitions"
|
||||||
|
default_concurrency = 1
|
||||||
|
|
||||||
class Processor(FlowProcessor):
|
class Processor(FlowProcessor):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
id = params.get("id")
|
id = params.get("id")
|
||||||
|
concurrency = params.get("concurrency", 1)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"id": id,
|
"id": id,
|
||||||
|
"concurrency": concurrency,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -38,7 +41,8 @@ class Processor(FlowProcessor):
|
||||||
ConsumerSpec(
|
ConsumerSpec(
|
||||||
name = "input",
|
name = "input",
|
||||||
schema = Chunk,
|
schema = Chunk,
|
||||||
handler = self.on_message
|
handler = self.on_message,
|
||||||
|
concurrency = concurrency,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -190,6 +194,13 @@ class Processor(FlowProcessor):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-c', '--concurrency',
|
||||||
|
type=int,
|
||||||
|
default=default_concurrency,
|
||||||
|
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||||
|
)
|
||||||
|
|
||||||
FlowProcessor.add_args(parser)
|
FlowProcessor.add_args(parser)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
|
||||||
|
|
@ -20,16 +20,19 @@ RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
||||||
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
||||||
|
|
||||||
default_ident = "kg-extract-relationships"
|
default_ident = "kg-extract-relationships"
|
||||||
|
default_concurrency = 1
|
||||||
|
|
||||||
class Processor(FlowProcessor):
|
class Processor(FlowProcessor):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
id = params.get("id")
|
id = params.get("id")
|
||||||
|
concurrency = params.get("concurrency", 1)
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"id": id,
|
"id": id,
|
||||||
|
"concurrency": concurrency,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -37,7 +40,8 @@ class Processor(FlowProcessor):
|
||||||
ConsumerSpec(
|
ConsumerSpec(
|
||||||
name = "input",
|
name = "input",
|
||||||
schema = Chunk,
|
schema = Chunk,
|
||||||
handler = self.on_message
|
handler = self.on_message,
|
||||||
|
concurrency = concurrency,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -192,6 +196,13 @@ class Processor(FlowProcessor):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-c', '--concurrency',
|
||||||
|
type=int,
|
||||||
|
default=default_concurrency,
|
||||||
|
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||||
|
)
|
||||||
|
|
||||||
FlowProcessor.add_args(parser)
|
FlowProcessor.add_args(parser)
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
from ... schema import AgentRequest, AgentResponse
|
from ... schema import AgentRequest, AgentResponse
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
|
|
@ -20,24 +21,12 @@ class AgentRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("agent")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("agent")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
return AgentRequest(
|
return self.request_translator.to_pulsar(body)
|
||||||
question=body["question"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
resp = {
|
return self.response_translator.from_response_with_completion(message)
|
||||||
}
|
|
||||||
|
|
||||||
if message.answer:
|
|
||||||
resp["answer"] = message.answer
|
|
||||||
|
|
||||||
if message.thought:
|
|
||||||
resp["thought"] = message.thought
|
|
||||||
|
|
||||||
if message.observation:
|
|
||||||
resp["observation"] = message.observation
|
|
||||||
|
|
||||||
# The 2nd boolean expression indicates whether we're done responding
|
|
||||||
return resp, (message.answer is not None)
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
from ... schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
|
from ... schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
|
||||||
from ... schema import config_request_queue
|
from ... schema import config_request_queue
|
||||||
from ... schema import config_response_queue
|
from ... schema import config_response_queue
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
|
|
@ -19,60 +20,12 @@ class ConfigRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("config")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("config")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
|
return self.request_translator.to_pulsar(body)
|
||||||
if "keys" in body:
|
|
||||||
keys = [
|
|
||||||
ConfigKey(
|
|
||||||
type = k["type"],
|
|
||||||
key = k["key"],
|
|
||||||
)
|
|
||||||
for k in body["keys"]
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
keys = None
|
|
||||||
|
|
||||||
if "values" in body:
|
|
||||||
values = [
|
|
||||||
ConfigValue(
|
|
||||||
type = v["type"],
|
|
||||||
key = v["key"],
|
|
||||||
value = v["value"],
|
|
||||||
)
|
|
||||||
for v in body["values"]
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
values = None
|
|
||||||
|
|
||||||
return ConfigRequest(
|
|
||||||
operation = body.get("operation", None),
|
|
||||||
keys = keys,
|
|
||||||
type = body.get("type", None),
|
|
||||||
values = values
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
|
return self.response_translator.from_response_with_completion(message)
|
||||||
response = { }
|
|
||||||
|
|
||||||
if message.version is not None:
|
|
||||||
response["version"] = message.version
|
|
||||||
|
|
||||||
if message.values is not None:
|
|
||||||
response["values"] = [
|
|
||||||
{
|
|
||||||
"type": v.type,
|
|
||||||
"key": v.key,
|
|
||||||
"value": v.value,
|
|
||||||
}
|
|
||||||
for v in message.values
|
|
||||||
]
|
|
||||||
|
|
||||||
if message.directory is not None:
|
|
||||||
response["directory"] = message.directory
|
|
||||||
|
|
||||||
if message.config is not None:
|
|
||||||
response["config"] = message.config
|
|
||||||
|
|
||||||
return response, True
|
|
||||||
|
|
||||||
|
|
|
||||||
96
trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
Normal file
96
trustgraph-flow/trustgraph/gateway/dispatch/core_export.py
Normal file
|
|
@ -0,0 +1,96 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
import msgpack
|
||||||
|
from . knowledge import KnowledgeRequestor
|
||||||
|
|
||||||
|
class CoreExport:
|
||||||
|
|
||||||
|
def __init__(self, pulsar_client):
|
||||||
|
self.pulsar_client = pulsar_client
|
||||||
|
|
||||||
|
async def process(self, data, error, ok, request):
|
||||||
|
|
||||||
|
id = request.query["id"]
|
||||||
|
user = request.query["user"]
|
||||||
|
|
||||||
|
response = await ok()
|
||||||
|
|
||||||
|
kr = KnowledgeRequestor(
|
||||||
|
pulsar_client = self.pulsar_client,
|
||||||
|
consumer = "api-gateway-core-export-" + str(uuid.uuid4()),
|
||||||
|
subscriber = "api-gateway-core-export-" + str(uuid.uuid4()),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
await kr.start()
|
||||||
|
|
||||||
|
async def responder(resp, fin):
|
||||||
|
|
||||||
|
if "graph-embeddings" in resp:
|
||||||
|
|
||||||
|
data = resp["graph-embeddings"]
|
||||||
|
|
||||||
|
msg = (
|
||||||
|
"ge",
|
||||||
|
{
|
||||||
|
"m": {
|
||||||
|
"i": data["metadata"]["id"],
|
||||||
|
"m": data["metadata"]["metadata"],
|
||||||
|
"u": data["metadata"]["user"],
|
||||||
|
"c": data["metadata"]["collection"],
|
||||||
|
},
|
||||||
|
"e": [
|
||||||
|
{
|
||||||
|
"e": ent["entity"],
|
||||||
|
"v": ent["vectors"],
|
||||||
|
}
|
||||||
|
for ent in data["entities"]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
enc = msgpack.packb(msg)
|
||||||
|
await response.write(enc)
|
||||||
|
|
||||||
|
if "triples" in resp:
|
||||||
|
|
||||||
|
data = resp["triples"]
|
||||||
|
msg = (
|
||||||
|
"t",
|
||||||
|
{
|
||||||
|
"m": {
|
||||||
|
"i": data["metadata"]["id"],
|
||||||
|
"m": data["metadata"]["metadata"],
|
||||||
|
"u": data["metadata"]["user"],
|
||||||
|
"c": data["metadata"]["collection"],
|
||||||
|
},
|
||||||
|
"t": data["triples"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
enc = msgpack.packb(msg)
|
||||||
|
await response.write(enc)
|
||||||
|
|
||||||
|
await kr.process(
|
||||||
|
{
|
||||||
|
"operation": "get-kg-core",
|
||||||
|
"user": user,
|
||||||
|
"id": id,
|
||||||
|
},
|
||||||
|
responder
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
|
||||||
|
print("Exception:", e)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
|
||||||
|
await kr.stop()
|
||||||
|
|
||||||
|
await response.write_eof()
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
94
trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
Normal file
94
trustgraph-flow/trustgraph/gateway/dispatch/core_import.py
Normal file
|
|
@ -0,0 +1,94 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import msgpack
|
||||||
|
from . knowledge import KnowledgeRequestor
|
||||||
|
|
||||||
|
class CoreImport:
|
||||||
|
|
||||||
|
def __init__(self, pulsar_client):
|
||||||
|
self.pulsar_client = pulsar_client
|
||||||
|
|
||||||
|
async def process(self, data, error, ok, request):
|
||||||
|
|
||||||
|
id = request.query["id"]
|
||||||
|
user = request.query["user"]
|
||||||
|
|
||||||
|
kr = KnowledgeRequestor(
|
||||||
|
pulsar_client = self.pulsar_client,
|
||||||
|
consumer = "api-gateway-core-import-" + str(uuid.uuid4()),
|
||||||
|
subscriber = "api-gateway-core-import-" + str(uuid.uuid4()),
|
||||||
|
)
|
||||||
|
|
||||||
|
await kr.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
unpacker = msgpack.Unpacker()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
buf = await data.read(128*1024)
|
||||||
|
if not buf: break
|
||||||
|
|
||||||
|
unpacker.feed(buf)
|
||||||
|
|
||||||
|
for unpacked in unpacker:
|
||||||
|
|
||||||
|
if unpacked[0] == "t":
|
||||||
|
msg = unpacked[1]
|
||||||
|
msg = {
|
||||||
|
"operation": "put-kg-core",
|
||||||
|
"user": user,
|
||||||
|
"id": id,
|
||||||
|
"triples": {
|
||||||
|
"metadata": {
|
||||||
|
"id": id,
|
||||||
|
"metadata": msg["m"]["m"],
|
||||||
|
"user": user,
|
||||||
|
"collection": "default", # Not used?
|
||||||
|
},
|
||||||
|
"triples": msg["t"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await kr.process(msg)
|
||||||
|
|
||||||
|
elif unpacked[0] == "ge":
|
||||||
|
msg = unpacked[1]
|
||||||
|
msg = {
|
||||||
|
"operation": "put-kg-core",
|
||||||
|
"user": user,
|
||||||
|
"id": id,
|
||||||
|
"graph-embeddings": {
|
||||||
|
"metadata": {
|
||||||
|
"id": id,
|
||||||
|
"metadata": msg["m"]["m"],
|
||||||
|
"user": user,
|
||||||
|
"collection": "default", # Not used?
|
||||||
|
},
|
||||||
|
"entities": [
|
||||||
|
{
|
||||||
|
"entity": ent["e"],
|
||||||
|
"vectors": ent["v"],
|
||||||
|
}
|
||||||
|
for ent in msg["e"]
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await kr.process(msg)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print("Exception:", e)
|
||||||
|
await error(str(e))
|
||||||
|
|
||||||
|
finally:
|
||||||
|
|
||||||
|
await kr.stop()
|
||||||
|
|
||||||
|
print("All done.")
|
||||||
|
response = await ok()
|
||||||
|
await response.write_eof()
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
@ -6,8 +6,7 @@ from aiohttp import WSMsgType
|
||||||
from ... schema import Metadata
|
from ... schema import Metadata
|
||||||
from ... schema import DocumentEmbeddings, ChunkEmbeddings
|
from ... schema import DocumentEmbeddings, ChunkEmbeddings
|
||||||
from ... base import Publisher
|
from ... base import Publisher
|
||||||
|
from ... messaging.translators.document_loading import DocumentEmbeddingsTranslator
|
||||||
from . serialize import to_subgraph
|
|
||||||
|
|
||||||
class DocumentEmbeddingsImport:
|
class DocumentEmbeddingsImport:
|
||||||
|
|
||||||
|
|
@ -17,6 +16,7 @@ class DocumentEmbeddingsImport:
|
||||||
|
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
self.running = running
|
self.running = running
|
||||||
|
self.translator = DocumentEmbeddingsTranslator()
|
||||||
|
|
||||||
self.publisher = Publisher(
|
self.publisher = Publisher(
|
||||||
pulsar_client, topic = queue, schema = DocumentEmbeddings
|
pulsar_client, topic = queue, schema = DocumentEmbeddings
|
||||||
|
|
@ -36,23 +36,7 @@ class DocumentEmbeddingsImport:
|
||||||
async def receive(self, msg):
|
async def receive(self, msg):
|
||||||
|
|
||||||
data = msg.json()
|
data = msg.json()
|
||||||
|
elt = self.translator.to_pulsar(data)
|
||||||
elt = DocumentEmbeddings(
|
|
||||||
metadata=Metadata(
|
|
||||||
id=data["metadata"]["id"],
|
|
||||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
|
||||||
user=data["metadata"]["user"],
|
|
||||||
collection=data["metadata"]["collection"],
|
|
||||||
),
|
|
||||||
chunks=[
|
|
||||||
ChunkEmbeddings(
|
|
||||||
chunk=de["chunk"].encode("utf-8"),
|
|
||||||
vectors=de["vectors"],
|
|
||||||
)
|
|
||||||
for de in data["chunks"]
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.publisher.send(None, elt)
|
await self.publisher.send(None, elt)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
from ... schema import Document, Metadata
|
from ... schema import Document, Metadata
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . sender import ServiceSender
|
from . sender import ServiceSender
|
||||||
from . serialize import to_subgraph
|
|
||||||
|
|
||||||
class DocumentLoad(ServiceSender):
|
class DocumentLoad(ServiceSender):
|
||||||
def __init__(self, pulsar_client, queue):
|
def __init__(self, pulsar_client, queue):
|
||||||
|
|
@ -15,26 +15,9 @@ class DocumentLoad(ServiceSender):
|
||||||
schema = Document,
|
schema = Document,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.translator = TranslatorRegistry.get_request_translator("document")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
|
|
||||||
if "metadata" in body:
|
|
||||||
metadata = to_subgraph(body["metadata"])
|
|
||||||
else:
|
|
||||||
metadata = []
|
|
||||||
|
|
||||||
# Doing a base64 decoe/encode here to make sure the
|
|
||||||
# content is valid base64
|
|
||||||
doc = base64.b64decode(body["data"])
|
|
||||||
|
|
||||||
print("Document received")
|
print("Document received")
|
||||||
|
return self.translator.to_pulsar(body)
|
||||||
return Document(
|
|
||||||
metadata=Metadata(
|
|
||||||
id=body.get("id"),
|
|
||||||
metadata=metadata,
|
|
||||||
user=body.get("user", "trustgraph"),
|
|
||||||
collection=body.get("collection", "default"),
|
|
||||||
),
|
|
||||||
data=base64.b64encode(doc).decode("utf-8")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
from ... schema import DocumentRagQuery, DocumentRagResponse
|
from ... schema import DocumentRagQuery, DocumentRagResponse
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
|
|
@ -20,14 +21,12 @@ class DocumentRagRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("document-rag")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("document-rag")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
return DocumentRagQuery(
|
return self.request_translator.to_pulsar(body)
|
||||||
query=body["query"],
|
|
||||||
user=body.get("user", "trustgraph"),
|
|
||||||
collection=body.get("collection", "default"),
|
|
||||||
doc_limit=int(body.get("doc-limit", 20)),
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
return { "response": message.response }, True
|
return self.response_translator.from_response_with_completion(message)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
|
|
@ -20,11 +21,12 @@ class EmbeddingsRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("embeddings")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("embeddings")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
return EmbeddingsRequest(
|
return self.request_translator.to_pulsar(body)
|
||||||
text=body["text"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
return { "vectors": message.vectors }, True
|
return self.response_translator.from_response_with_completion(message)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
from ... schema import FlowRequest, FlowResponse
|
from ... schema import FlowRequest, FlowResponse
|
||||||
from ... schema import flow_request_queue
|
from ... schema import flow_request_queue
|
||||||
from ... schema import flow_response_queue
|
from ... schema import flow_response_queue
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
|
|
@ -19,34 +20,12 @@ class FlowRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_request(self, body):
|
self.request_translator = TranslatorRegistry.get_request_translator("flow")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("flow")
|
||||||
|
|
||||||
return FlowRequest(
|
def to_request(self, body):
|
||||||
operation = body.get("operation", None),
|
return self.request_translator.to_pulsar(body)
|
||||||
class_name = body.get("class-name", None),
|
|
||||||
class_definition = body.get("class-definition", None),
|
|
||||||
description = body.get("description", None),
|
|
||||||
flow_id = body.get("flow-id", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
|
return self.response_translator.from_response_with_completion(message)
|
||||||
response = { }
|
|
||||||
|
|
||||||
if message.class_names is not None:
|
|
||||||
response["class-names"] = message.class_names
|
|
||||||
|
|
||||||
if message.flow_ids is not None:
|
|
||||||
response["flow-ids"] = message.flow_ids
|
|
||||||
|
|
||||||
if message.class_definition is not None:
|
|
||||||
response["class-definition"] = message.class_definition
|
|
||||||
|
|
||||||
if message.flow is not None:
|
|
||||||
response["flow"] = message.flow
|
|
||||||
|
|
||||||
if message.description is not None:
|
|
||||||
response["description"] = message.description
|
|
||||||
|
|
||||||
return response, True
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
|
|
||||||
from ... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
from ... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
from . serialize import serialize_value
|
|
||||||
|
|
||||||
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
|
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -21,22 +21,12 @@ class GraphEmbeddingsQueryRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("graph-embeddings-query")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("graph-embeddings-query")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
|
return self.request_translator.to_pulsar(body)
|
||||||
limit = int(body.get("limit", 20))
|
|
||||||
|
|
||||||
return GraphEmbeddingsRequest(
|
|
||||||
vectors = body["vectors"],
|
|
||||||
limit = limit,
|
|
||||||
user = body.get("user", "trustgraph"),
|
|
||||||
collection = body.get("collection", "default"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
|
return self.response_translator.from_response_with_completion(message)
|
||||||
return {
|
|
||||||
"entities": [
|
|
||||||
serialize_value(ent) for ent in message.entities
|
|
||||||
]
|
|
||||||
}, True
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
from ... schema import GraphRagQuery, GraphRagResponse
|
from ... schema import GraphRagQuery, GraphRagResponse
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
|
|
@ -20,17 +21,12 @@ class GraphRagRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("graph-rag")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("graph-rag")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
return GraphRagQuery(
|
return self.request_translator.to_pulsar(body)
|
||||||
query=body["query"],
|
|
||||||
user=body.get("user", "trustgraph"),
|
|
||||||
collection=body.get("collection", "default"),
|
|
||||||
entity_limit=int(body.get("entity-limit", 50)),
|
|
||||||
triple_limit=int(body.get("triple-limit", 30)),
|
|
||||||
max_subgraph_size=int(body.get("max-subgraph-size", 1000)),
|
|
||||||
max_path_length=int(body.get("max-path-length", 2)),
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
return { "response": message.response }, True
|
return self.response_translator.from_response_with_completion(message)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,9 @@ from ... schema import KnowledgeRequest, KnowledgeResponse, Triples
|
||||||
from ... schema import GraphEmbeddings, Metadata, EntityEmbeddings
|
from ... schema import GraphEmbeddings, Metadata, EntityEmbeddings
|
||||||
from ... schema import knowledge_request_queue
|
from ... schema import knowledge_request_queue
|
||||||
from ... schema import knowledge_response_queue
|
from ... schema import knowledge_response_queue
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
from . serialize import serialize_graph_embeddings
|
|
||||||
from . serialize import serialize_triples, to_subgraph, to_value
|
|
||||||
from . serialize import to_document_metadata, to_processing_metadata
|
|
||||||
|
|
||||||
class KnowledgeRequestor(ServiceRequestor):
|
class KnowledgeRequestor(ServiceRequestor):
|
||||||
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
|
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
|
||||||
|
|
@ -25,73 +23,12 @@ class KnowledgeRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("knowledge")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("knowledge")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
|
return self.request_translator.to_pulsar(body)
|
||||||
if "triples" in body:
|
|
||||||
triples = Triples(
|
|
||||||
metadata=Metadata(
|
|
||||||
id = body["triples"]["metadata"]["id"],
|
|
||||||
metadata = to_subgraph(body["triples"]["metadata"]["metadata"]),
|
|
||||||
user = body["triples"]["metadata"]["user"],
|
|
||||||
),
|
|
||||||
triples = to_subgraph(body["triples"]["triples"]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
triples = None
|
|
||||||
|
|
||||||
if "graph-embeddings" in body:
|
|
||||||
ge = GraphEmbeddings(
|
|
||||||
metadata = Metadata(
|
|
||||||
id = body["graph-embeddings"]["metadata"]["id"],
|
|
||||||
metadata = to_subgraph(body["graph-embeddings"]["metadata"]["metadata"]),
|
|
||||||
user = body["graph-embeddings"]["metadata"]["user"],
|
|
||||||
),
|
|
||||||
entities=[
|
|
||||||
EntityEmbeddings(
|
|
||||||
entity = to_value(ent["entity"]),
|
|
||||||
vectors = ent["vectors"],
|
|
||||||
)
|
|
||||||
for ent in body["graph-embeddings"]["entities"]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ge = None
|
|
||||||
|
|
||||||
return KnowledgeRequest(
|
|
||||||
operation = body.get("operation", None),
|
|
||||||
user = body.get("user", None),
|
|
||||||
id = body.get("id", None),
|
|
||||||
flow = body.get("flow", None),
|
|
||||||
collection = body.get("collection", None),
|
|
||||||
triples = triples,
|
|
||||||
graph_embeddings = ge,
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
|
return self.response_translator.from_response_with_completion(message)
|
||||||
# Response to list,
|
|
||||||
if message.ids is not None:
|
|
||||||
return {
|
|
||||||
"ids": message.ids
|
|
||||||
}, True
|
|
||||||
|
|
||||||
if message.triples:
|
|
||||||
return {
|
|
||||||
"triples": serialize_triples(message.triples)
|
|
||||||
}, False
|
|
||||||
|
|
||||||
if message.graph_embeddings:
|
|
||||||
return {
|
|
||||||
"graph-embeddings": serialize_graph_embeddings(
|
|
||||||
message.graph_embeddings
|
|
||||||
)
|
|
||||||
}, False
|
|
||||||
|
|
||||||
if message.eos is True:
|
|
||||||
return {
|
|
||||||
"eos": True
|
|
||||||
}, True
|
|
||||||
|
|
||||||
# Empty case, return from successful delete.
|
|
||||||
return {}, True
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,9 @@ import base64
|
||||||
from ... schema import LibrarianRequest, LibrarianResponse
|
from ... schema import LibrarianRequest, LibrarianResponse
|
||||||
from ... schema import librarian_request_queue
|
from ... schema import librarian_request_queue
|
||||||
from ... schema import librarian_response_queue
|
from ... schema import librarian_response_queue
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
from . serialize import serialize_document_metadata
|
|
||||||
from . serialize import serialize_processing_metadata
|
|
||||||
from . serialize import to_document_metadata, to_processing_metadata
|
|
||||||
from . serialize import to_criteria
|
|
||||||
|
|
||||||
class LibrarianRequestor(ServiceRequestor):
|
class LibrarianRequestor(ServiceRequestor):
|
||||||
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
|
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
|
||||||
|
|
@ -25,67 +22,20 @@ class LibrarianRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("librarian")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("librarian")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
|
# Handle base64 content processing
|
||||||
# Content gets base64 decoded & encoded again. It at least makes
|
|
||||||
# sure payload is valid base64.
|
|
||||||
|
|
||||||
if "document-metadata" in body:
|
|
||||||
dm = to_document_metadata(body["document-metadata"])
|
|
||||||
else:
|
|
||||||
dm = None
|
|
||||||
|
|
||||||
if "processing-metadata" in body:
|
|
||||||
pm = to_processing_metadata(body["processing-metadata"])
|
|
||||||
else:
|
|
||||||
pm = None
|
|
||||||
|
|
||||||
if "criteria" in body:
|
|
||||||
criteria = to_criteria(body["criteria"])
|
|
||||||
else:
|
|
||||||
criteria = None
|
|
||||||
|
|
||||||
if "content" in body:
|
if "content" in body:
|
||||||
|
# Content gets base64 decoded & encoded again to ensure valid base64
|
||||||
content = base64.b64decode(body["content"].encode("utf-8"))
|
content = base64.b64decode(body["content"].encode("utf-8"))
|
||||||
content = base64.b64encode(content).decode("utf-8")
|
content = base64.b64encode(content).decode("utf-8")
|
||||||
else:
|
body = body.copy()
|
||||||
content = None
|
body["content"] = content
|
||||||
|
|
||||||
return LibrarianRequest(
|
return self.request_translator.to_pulsar(body)
|
||||||
operation = body.get("operation", None),
|
|
||||||
document_id = body.get("document-id", None),
|
|
||||||
processing_id = body.get("processing-id", None),
|
|
||||||
document_metadata = dm,
|
|
||||||
processing_metadata = pm,
|
|
||||||
content = content,
|
|
||||||
user = body.get("user", None),
|
|
||||||
collection = body.get("collection", None),
|
|
||||||
criteria = criteria,
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
|
return self.response_translator.from_response_with_completion(message)
|
||||||
response = {}
|
|
||||||
|
|
||||||
if message.document_metadata:
|
|
||||||
response["document-metadata"] = serialize_document_metadata(
|
|
||||||
message.document_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
if message.content:
|
|
||||||
response["content"] = message.content.decode("utf-8")
|
|
||||||
|
|
||||||
if message.document_metadatas != None:
|
|
||||||
response["document-metadatas"] = [
|
|
||||||
serialize_document_metadata(v)
|
|
||||||
for v in message.document_metadatas
|
|
||||||
]
|
|
||||||
|
|
||||||
if message.processing_metadatas != None:
|
|
||||||
response["processing-metadatas"] = [
|
|
||||||
serialize_processing_metadata(v)
|
|
||||||
for v in message.processing_metadatas
|
|
||||||
]
|
|
||||||
|
|
||||||
return response, True
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from aiohttp import web
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from . config import ConfigRequestor
|
from . config import ConfigRequestor
|
||||||
|
|
@ -30,6 +31,9 @@ from . graph_embeddings_import import GraphEmbeddingsImport
|
||||||
from . document_embeddings_import import DocumentEmbeddingsImport
|
from . document_embeddings_import import DocumentEmbeddingsImport
|
||||||
from . entity_contexts_import import EntityContextsImport
|
from . entity_contexts_import import EntityContextsImport
|
||||||
|
|
||||||
|
from . core_export import CoreExport
|
||||||
|
from . core_import import CoreImport
|
||||||
|
|
||||||
from . mux import Mux
|
from . mux import Mux
|
||||||
|
|
||||||
request_response_dispatchers = {
|
request_response_dispatchers = {
|
||||||
|
|
@ -77,10 +81,11 @@ class DispatcherWrapper:
|
||||||
|
|
||||||
class DispatcherManager:
|
class DispatcherManager:
|
||||||
|
|
||||||
def __init__(self, pulsar_client, config_receiver):
|
def __init__(self, pulsar_client, config_receiver, prefix="api-gateway"):
|
||||||
self.pulsar_client = pulsar_client
|
self.pulsar_client = pulsar_client
|
||||||
self.config_receiver = config_receiver
|
self.config_receiver = config_receiver
|
||||||
self.config_receiver.add_handler(self)
|
self.config_receiver.add_handler(self)
|
||||||
|
self.prefix = prefix
|
||||||
|
|
||||||
self.flows = {}
|
self.flows = {}
|
||||||
self.dispatchers = {}
|
self.dispatchers = {}
|
||||||
|
|
@ -98,6 +103,22 @@ class DispatcherManager:
|
||||||
def dispatch_global_service(self):
|
def dispatch_global_service(self):
|
||||||
return DispatcherWrapper(self.process_global_service)
|
return DispatcherWrapper(self.process_global_service)
|
||||||
|
|
||||||
|
def dispatch_core_export(self):
|
||||||
|
return DispatcherWrapper(self.process_core_export)
|
||||||
|
|
||||||
|
def dispatch_core_import(self):
|
||||||
|
return DispatcherWrapper(self.process_core_import)
|
||||||
|
|
||||||
|
async def process_core_import(self, data, error, ok, request):
|
||||||
|
|
||||||
|
ci = CoreImport(self.pulsar_client)
|
||||||
|
return await ci.process(data, error, ok, request)
|
||||||
|
|
||||||
|
async def process_core_export(self, data, error, ok, request):
|
||||||
|
|
||||||
|
ce = CoreExport(self.pulsar_client)
|
||||||
|
return await ce.process(data, error, ok, request)
|
||||||
|
|
||||||
async def process_global_service(self, data, responder, params):
|
async def process_global_service(self, data, responder, params):
|
||||||
|
|
||||||
kind = params.get("kind")
|
kind = params.get("kind")
|
||||||
|
|
@ -113,8 +134,8 @@ class DispatcherManager:
|
||||||
dispatcher = global_dispatchers[kind](
|
dispatcher = global_dispatchers[kind](
|
||||||
pulsar_client = self.pulsar_client,
|
pulsar_client = self.pulsar_client,
|
||||||
timeout = 120,
|
timeout = 120,
|
||||||
consumer = f"api-gateway-{kind}-request",
|
consumer = f"{self.prefix}-{kind}-request",
|
||||||
subscriber = f"api-gateway-{kind}-request",
|
subscriber = f"{self.prefix}-{kind}-request",
|
||||||
)
|
)
|
||||||
|
|
||||||
await dispatcher.start()
|
await dispatcher.start()
|
||||||
|
|
@ -206,8 +227,8 @@ class DispatcherManager:
|
||||||
ws = ws,
|
ws = ws,
|
||||||
running = running,
|
running = running,
|
||||||
queue = qconfig,
|
queue = qconfig,
|
||||||
consumer = f"api-gateway-{id}",
|
consumer = f"{self.prefix}-{id}",
|
||||||
subscriber = f"api-gateway-{id}",
|
subscriber = f"{self.prefix}-{id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
return dispatcher
|
return dispatcher
|
||||||
|
|
@ -248,8 +269,8 @@ class DispatcherManager:
|
||||||
request_queue = qconfig["request"],
|
request_queue = qconfig["request"],
|
||||||
response_queue = qconfig["response"],
|
response_queue = qconfig["response"],
|
||||||
timeout = 120,
|
timeout = 120,
|
||||||
consumer = f"api-gateway-{flow}-{kind}-request",
|
consumer = f"{self.prefix}-{flow}-{kind}-request",
|
||||||
subscriber = f"api-gateway-{flow}-{kind}-request",
|
subscriber = f"{self.prefix}-{flow}-{kind}-request",
|
||||||
)
|
)
|
||||||
elif kind in sender_dispatchers:
|
elif kind in sender_dispatchers:
|
||||||
dispatcher = sender_dispatchers[kind](
|
dispatcher = sender_dispatchers[kind](
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from ... schema import PromptRequest, PromptResponse
|
from ... schema import PromptRequest, PromptResponse
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
|
|
@ -22,22 +23,12 @@ class PromptRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("prompt")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("prompt")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
return PromptRequest(
|
return self.request_translator.to_pulsar(body)
|
||||||
id=body["id"],
|
|
||||||
terms={
|
|
||||||
k: json.dumps(v)
|
|
||||||
for k, v in body["variables"].items()
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
if message.object:
|
return self.response_translator.from_response_with_completion(message)
|
||||||
return {
|
|
||||||
"object": message.object
|
|
||||||
}, True
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"text": message.text
|
|
||||||
}, True
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,13 @@ import base64
|
||||||
|
|
||||||
from ... schema import Value, Triple, DocumentMetadata, ProcessingMetadata
|
from ... schema import Value, Triple, DocumentMetadata, ProcessingMetadata
|
||||||
|
|
||||||
|
# DEPRECATED: These functions have been moved to trustgraph.... messaging.translators
|
||||||
|
# Use the new messaging translation system instead for consistency and reusability.
|
||||||
|
# Examples:
|
||||||
|
# from trustgraph.... messaging.translators.primitives import ValueTranslator
|
||||||
|
# value_translator = ValueTranslator()
|
||||||
|
# pulsar_value = value_translator.to_pulsar({"v": "example", "e": True})
|
||||||
|
|
||||||
def to_value(x):
|
def to_value(x):
|
||||||
return Value(value=x["v"], is_uri=x["e"])
|
return Value(value=x["v"], is_uri=x["e"])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
|
|
||||||
from ... schema import TextCompletionRequest, TextCompletionResponse
|
from ... schema import TextCompletionRequest, TextCompletionResponse
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
|
|
@ -20,12 +21,12 @@ class TextCompletionRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("text-completion")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("text-completion")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
return TextCompletionRequest(
|
return self.request_translator.to_pulsar(body)
|
||||||
system=body["system"],
|
|
||||||
prompt=body["prompt"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
return { "response": message.response }, True
|
return self.response_translator.from_response_with_completion(message)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
from ... schema import TextDocument, Metadata
|
from ... schema import TextDocument, Metadata
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . sender import ServiceSender
|
from . sender import ServiceSender
|
||||||
from . serialize import to_subgraph
|
|
||||||
|
|
||||||
class TextLoad(ServiceSender):
|
class TextLoad(ServiceSender):
|
||||||
def __init__(self, pulsar_client, queue):
|
def __init__(self, pulsar_client, queue):
|
||||||
|
|
@ -15,30 +15,9 @@ class TextLoad(ServiceSender):
|
||||||
schema = TextDocument,
|
schema = TextDocument,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.translator = TranslatorRegistry.get_request_translator("text-document")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
|
|
||||||
if "metadata" in body:
|
|
||||||
metadata = to_subgraph(body["metadata"])
|
|
||||||
else:
|
|
||||||
metadata = []
|
|
||||||
|
|
||||||
if "charset" in body:
|
|
||||||
charset = body["charset"]
|
|
||||||
else:
|
|
||||||
charset = "utf-8"
|
|
||||||
|
|
||||||
# Text is base64 encoded
|
|
||||||
text = base64.b64decode(body["text"]).decode(charset)
|
|
||||||
|
|
||||||
print("Text document received")
|
print("Text document received")
|
||||||
|
return self.translator.to_pulsar(body)
|
||||||
return TextDocument(
|
|
||||||
metadata=Metadata(
|
|
||||||
id=body.get("id"),
|
|
||||||
metadata=metadata,
|
|
||||||
user=body.get("user", "trustgraph"),
|
|
||||||
collection=body.get("collection", "default"),
|
|
||||||
),
|
|
||||||
text=text,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
|
|
||||||
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
|
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
|
||||||
|
from ... messaging import TranslatorRegistry
|
||||||
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
from . serialize import to_value, serialize_subgraph
|
|
||||||
|
|
||||||
class TriplesQueryRequestor(ServiceRequestor):
|
class TriplesQueryRequestor(ServiceRequestor):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -21,34 +21,12 @@ class TriplesQueryRequestor(ServiceRequestor):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.request_translator = TranslatorRegistry.get_request_translator("triples-query")
|
||||||
|
self.response_translator = TranslatorRegistry.get_response_translator("triples-query")
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
|
return self.request_translator.to_pulsar(body)
|
||||||
if "s" in body:
|
|
||||||
s = to_value(body["s"])
|
|
||||||
else:
|
|
||||||
s = None
|
|
||||||
|
|
||||||
if "p" in body:
|
|
||||||
p = to_value(body["p"])
|
|
||||||
else:
|
|
||||||
p = None
|
|
||||||
|
|
||||||
if "o" in body:
|
|
||||||
o = to_value(body["o"])
|
|
||||||
else:
|
|
||||||
o = None
|
|
||||||
|
|
||||||
limit = int(body.get("limit", 10000))
|
|
||||||
|
|
||||||
return TriplesQueryRequest(
|
|
||||||
s = s, p = p, o = o,
|
|
||||||
limit = limit,
|
|
||||||
user = body.get("user", "trustgraph"),
|
|
||||||
collection = body.get("collection", "default"),
|
|
||||||
)
|
|
||||||
|
|
||||||
def from_response(self, message):
|
def from_response(self, message):
|
||||||
return {
|
return self.response_translator.from_response_with_completion(message)
|
||||||
"response": serialize_subgraph(message.triples)
|
|
||||||
}, True
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import asyncio
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from . constant_endpoint import ConstantEndpoint
|
from . stream_endpoint import StreamEndpoint
|
||||||
from . variable_endpoint import VariableEndpoint
|
from . variable_endpoint import VariableEndpoint
|
||||||
from . socket import SocketEndpoint
|
from . socket import SocketEndpoint
|
||||||
from . metrics import MetricsEndpoint
|
from . metrics import MetricsEndpoint
|
||||||
|
|
@ -52,6 +52,18 @@ class EndpointManager:
|
||||||
auth = auth,
|
auth = auth,
|
||||||
dispatcher = dispatcher_manager.dispatch_flow_export()
|
dispatcher = dispatcher_manager.dispatch_flow_export()
|
||||||
),
|
),
|
||||||
|
StreamEndpoint(
|
||||||
|
endpoint_path = "/api/v1/import-core",
|
||||||
|
auth = auth,
|
||||||
|
method = "POST",
|
||||||
|
dispatcher = dispatcher_manager.dispatch_core_import(),
|
||||||
|
),
|
||||||
|
StreamEndpoint(
|
||||||
|
endpoint_path = "/api/v1/export-core",
|
||||||
|
auth = auth,
|
||||||
|
method = "GET",
|
||||||
|
dispatcher = dispatcher_manager.dispatch_core_export(),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def add_routes(self, app):
|
def add_routes(self, app):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from aiohttp import web
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger("endpoint")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
class StreamEndpoint:
|
||||||
|
|
||||||
|
def __init__(self, endpoint_path, auth, dispatcher, method="POST"):
|
||||||
|
|
||||||
|
self.path = endpoint_path
|
||||||
|
|
||||||
|
self.auth = auth
|
||||||
|
self.operation = "service"
|
||||||
|
self.method = method
|
||||||
|
|
||||||
|
self.dispatcher = dispatcher
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def add_routes(self, app):
|
||||||
|
|
||||||
|
if self.method == "POST":
|
||||||
|
app.add_routes([
|
||||||
|
web.post(self.path, self.handle),
|
||||||
|
])
|
||||||
|
elif self.method == "GET":
|
||||||
|
app.add_routes([
|
||||||
|
web.get(self.path, self.handle),
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Bad method" + self.method)
|
||||||
|
|
||||||
|
async def handle(self, request):
|
||||||
|
|
||||||
|
print(request.path, "...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
ht = request.headers["Authorization"]
|
||||||
|
tokens = ht.split(" ", 2)
|
||||||
|
if tokens[0] != "Bearer":
|
||||||
|
return web.HTTPUnauthorized()
|
||||||
|
token = tokens[1]
|
||||||
|
except:
|
||||||
|
token = ""
|
||||||
|
|
||||||
|
if not self.auth.permitted(token, self.operation):
|
||||||
|
return web.HTTPUnauthorized()
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
data = request.content
|
||||||
|
|
||||||
|
async def error(err):
|
||||||
|
return web.HTTPInternalServerError(text = err)
|
||||||
|
|
||||||
|
async def ok(
|
||||||
|
status=200, reason="OK", type="application/octet-stream"
|
||||||
|
):
|
||||||
|
response = web.StreamResponse(
|
||||||
|
status = status, reason = reason,
|
||||||
|
headers = {"Content-Type": type}
|
||||||
|
)
|
||||||
|
await response.prepare(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
resp = await self.dispatcher.process(
|
||||||
|
data, error, ok, request
|
||||||
|
)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Exception: {e}")
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{ "error": str(e) }
|
||||||
|
)
|
||||||
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import uuid
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger("endpoint")
|
logger = logging.getLogger("endpoint")
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,7 @@ class Api:
|
||||||
self.dispatcher_manager = DispatcherManager(
|
self.dispatcher_manager = DispatcherManager(
|
||||||
pulsar_client = self.pulsar_client,
|
pulsar_client = self.pulsar_client,
|
||||||
config_receiver = self.config_receiver,
|
config_receiver = self.config_receiver,
|
||||||
|
prefix = "gateway",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.endpoint_manager = EndpointManager(
|
self.endpoint_manager = EndpointManager(
|
||||||
|
|
|
||||||
|
|
@ -18,12 +18,14 @@ from .... base import ProducerSpec, ConsumerSpec, TextCompletionClientSpec
|
||||||
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
|
from . prompt_manager import PromptConfiguration, Prompt, PromptManager
|
||||||
|
|
||||||
default_ident = "prompt"
|
default_ident = "prompt"
|
||||||
|
default_concurrency = 1
|
||||||
|
|
||||||
class Processor(FlowProcessor):
|
class Processor(FlowProcessor):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
id = params.get("id")
|
id = params.get("id")
|
||||||
|
concurrency = params.get("concurrency", 1)
|
||||||
|
|
||||||
# Config key for prompts
|
# Config key for prompts
|
||||||
self.config_key = params.get("config_type", "prompt")
|
self.config_key = params.get("config_type", "prompt")
|
||||||
|
|
@ -31,6 +33,7 @@ class Processor(FlowProcessor):
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"id": id,
|
"id": id,
|
||||||
|
"concurrency": concurrency,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -38,7 +41,8 @@ class Processor(FlowProcessor):
|
||||||
ConsumerSpec(
|
ConsumerSpec(
|
||||||
name = "request",
|
name = "request",
|
||||||
schema = PromptRequest,
|
schema = PromptRequest,
|
||||||
handler = self.on_request
|
handler = self.on_request,
|
||||||
|
concurrency = concurrency,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -219,6 +223,13 @@ class Processor(FlowProcessor):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-c', '--concurrency',
|
||||||
|
type=int,
|
||||||
|
default=default_concurrency,
|
||||||
|
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||||
|
)
|
||||||
|
|
||||||
FlowProcessor.add_args(parser)
|
FlowProcessor.add_args(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
|
||||||
|
from . llm import *
|
||||||
|
|
||||||
7
trustgraph-flow/trustgraph/model/text_completion/vllm/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/vllm/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
from . llm import run
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
run()
|
||||||
|
|
||||||
138
trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py
Executable file
138
trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py
Executable file
|
|
@ -0,0 +1,138 @@
|
||||||
|
|
||||||
|
"""
|
||||||
|
Simple LLM service, performs text prompt completion using vLLM
|
||||||
|
Input is prompt, output is response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from .... exceptions import TooManyRequests
|
||||||
|
from .... base import LlmService, LlmResult
|
||||||
|
|
||||||
|
default_ident = "text-completion"
|
||||||
|
|
||||||
|
default_temperature = 0.0
|
||||||
|
default_max_output = 2048
|
||||||
|
default_base_url = os.getenv("VLLM_BASE_URL")
|
||||||
|
default_model = "TheBloke/Mistral-7B-v0.1-AWQ"
|
||||||
|
|
||||||
|
if default_base_url == "" or default_base_url is None:
|
||||||
|
default_base_url = "http://vllm-service:8899/v1"
|
||||||
|
|
||||||
|
class Processor(LlmService):
|
||||||
|
|
||||||
|
def __init__(self, **params):
|
||||||
|
|
||||||
|
base_url = params.get("url", default_base_url)
|
||||||
|
temperature = params.get("temperature", default_temperature)
|
||||||
|
max_output = params.get("max_output", default_max_output)
|
||||||
|
model = params.get("model", default_model)
|
||||||
|
|
||||||
|
super(Processor, self).__init__(
|
||||||
|
**params | {
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_output": max_output,
|
||||||
|
"url": base_url,
|
||||||
|
"model": model,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.base_url = base_url
|
||||||
|
self.temperature = temperature
|
||||||
|
self.max_output = max_output
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
self.session = aiohttp.ClientSession()
|
||||||
|
|
||||||
|
print("Using vLLM service at", base_url)
|
||||||
|
|
||||||
|
print("Initialised", flush=True)
|
||||||
|
|
||||||
|
async def generate_content(self, system, prompt):
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
request = {
|
||||||
|
"model": self.model,
|
||||||
|
"prompt": system + "\n\n" + prompt,
|
||||||
|
"max_tokens": self.max_output,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
url = f"{self.base_url}/completions"
|
||||||
|
|
||||||
|
async with self.session.post(
|
||||||
|
url,
|
||||||
|
headers=headers,
|
||||||
|
json=request,
|
||||||
|
) as response:
|
||||||
|
|
||||||
|
if response.status != 200:
|
||||||
|
raise RuntimeError("Bad status: " + str(response.status))
|
||||||
|
|
||||||
|
resp = await response.json()
|
||||||
|
|
||||||
|
inputtokens = resp["usage"]["prompt_tokens"]
|
||||||
|
outputtokens = resp["usage"]["completion_tokens"]
|
||||||
|
ans = resp["choices"][0]["text"]
|
||||||
|
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||||
|
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||||
|
print(ans, flush=True)
|
||||||
|
|
||||||
|
resp = LlmResult(
|
||||||
|
text = ans,
|
||||||
|
in_token = inputtokens,
|
||||||
|
out_token = outputtokens,
|
||||||
|
model = self.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
# FIXME: Assuming vLLM won't produce rate limits?
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
|
||||||
|
# Apart from rate limits, treat all exceptions as unrecoverable
|
||||||
|
|
||||||
|
print(f"Exception: {type(e)} {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_args(parser):
|
||||||
|
|
||||||
|
LlmService.add_args(parser)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-u', '--url',
|
||||||
|
default=default_base_url,
|
||||||
|
help=f'vLLM service base URL (default: {default_base_url})'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-m', '--model',
|
||||||
|
default=default_model,
|
||||||
|
help=f'LLM model (default: {default_model})'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-t', '--temperature',
|
||||||
|
type=float,
|
||||||
|
default=default_temperature,
|
||||||
|
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-x', '--max-output',
|
||||||
|
type=int,
|
||||||
|
default=default_max_output,
|
||||||
|
help=f'LLM max output tokens (default: {default_max_output})'
|
||||||
|
)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
|
||||||
|
Processor.launch(default_ident, __doc__)
|
||||||
|
|
@ -11,12 +11,14 @@ from ... base import PromptClientSpec, EmbeddingsClientSpec
|
||||||
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
|
from ... base import GraphEmbeddingsClientSpec, TriplesClientSpec
|
||||||
|
|
||||||
default_ident = "graph-rag"
|
default_ident = "graph-rag"
|
||||||
|
default_concurrency = 1
|
||||||
|
|
||||||
class Processor(FlowProcessor):
|
class Processor(FlowProcessor):
|
||||||
|
|
||||||
def __init__(self, **params):
|
def __init__(self, **params):
|
||||||
|
|
||||||
id = params.get("id", default_ident)
|
id = params.get("id", default_ident)
|
||||||
|
concurrency = params.get("concurrency", 1)
|
||||||
|
|
||||||
entity_limit = params.get("entity_limit", 50)
|
entity_limit = params.get("entity_limit", 50)
|
||||||
triple_limit = params.get("triple_limit", 30)
|
triple_limit = params.get("triple_limit", 30)
|
||||||
|
|
@ -26,6 +28,7 @@ class Processor(FlowProcessor):
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
"id": id,
|
"id": id,
|
||||||
|
"concurrency": concurrency,
|
||||||
"entity_limit": entity_limit,
|
"entity_limit": entity_limit,
|
||||||
"triple_limit": triple_limit,
|
"triple_limit": triple_limit,
|
||||||
"max_subgraph_size": max_subgraph_size,
|
"max_subgraph_size": max_subgraph_size,
|
||||||
|
|
@ -43,6 +46,7 @@ class Processor(FlowProcessor):
|
||||||
name = "request",
|
name = "request",
|
||||||
schema = GraphRagQuery,
|
schema = GraphRagQuery,
|
||||||
handler = self.on_request,
|
handler = self.on_request,
|
||||||
|
concurrency = concurrency,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -157,6 +161,13 @@ class Processor(FlowProcessor):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-c', '--concurrency',
|
||||||
|
type=int,
|
||||||
|
default=default_concurrency,
|
||||||
|
help=f'Concurrent processing threads (default: {default_concurrency})'
|
||||||
|
)
|
||||||
|
|
||||||
FlowProcessor.add_args(parser)
|
FlowProcessor.add_args(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
|
||||||
1
trustgraph-flow/trustgraph/rev_gateway/__init__.py
Normal file
1
trustgraph-flow/trustgraph/rev_gateway/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
from . service import run
|
||||||
11
trustgraph-flow/trustgraph/rev_gateway/__main__.py
Normal file
11
trustgraph-flow/trustgraph/rev_gateway/__main__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
import logging
|
||||||
|
from .service import run
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
|
|
||||||
130
trustgraph-flow/trustgraph/rev_gateway/dispatcher.py
Normal file
130
trustgraph-flow/trustgraph/rev_gateway/dispatcher.py
Normal file
|
|
@ -0,0 +1,130 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from trustgraph.messaging import TranslatorRegistry
|
||||||
|
from ..gateway.dispatch.manager import DispatcherManager
|
||||||
|
|
||||||
|
logger = logging.getLogger("dispatcher")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
class WebSocketResponder:
|
||||||
|
"""Simple responder that captures response for websocket return"""
|
||||||
|
def __init__(self):
|
||||||
|
self.response = None
|
||||||
|
self.completed = False
|
||||||
|
|
||||||
|
async def send(self, data):
|
||||||
|
"""Capture the response data"""
|
||||||
|
self.response = data
|
||||||
|
self.completed = True
|
||||||
|
|
||||||
|
async def __call__(self, data, final=False):
|
||||||
|
"""Make the responder callable for compatibility with requestor"""
|
||||||
|
await self.send(data)
|
||||||
|
if final:
|
||||||
|
self.completed = True
|
||||||
|
|
||||||
|
class MessageDispatcher:
|
||||||
|
|
||||||
|
def __init__(self, max_workers: int = 10, config_receiver=None, pulsar_client=None):
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self.semaphore = asyncio.Semaphore(max_workers)
|
||||||
|
self.active_tasks = set()
|
||||||
|
self.pulsar_client = pulsar_client
|
||||||
|
|
||||||
|
# Use DispatcherManager for flow and service management
|
||||||
|
if pulsar_client and config_receiver:
|
||||||
|
self.dispatcher_manager = DispatcherManager(pulsar_client, config_receiver, prefix="rev-gateway")
|
||||||
|
else:
|
||||||
|
self.dispatcher_manager = None
|
||||||
|
logger.warning("No pulsar_client or config_receiver provided - using fallback mode")
|
||||||
|
|
||||||
|
# Service name mapping from websocket protocol to translator registry
|
||||||
|
self.service_mapping = {
|
||||||
|
"text-completion": "text-completion",
|
||||||
|
"graph-rag": "graph-rag",
|
||||||
|
"agent": "agent",
|
||||||
|
"embeddings": "embeddings",
|
||||||
|
"graph-embeddings": "graph-embeddings",
|
||||||
|
"triples": "triples",
|
||||||
|
"document-load": "document",
|
||||||
|
"text-load": "text-document",
|
||||||
|
"flow": "flow",
|
||||||
|
"knowledge": "knowledge",
|
||||||
|
"config": "config",
|
||||||
|
"librarian": "librarian",
|
||||||
|
"document-rag": "document-rag"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def handle_message(self, message: Dict[Any, Any]) -> Optional[Dict[Any, Any]]:
|
||||||
|
async with self.semaphore:
|
||||||
|
task = asyncio.create_task(self._process_message(message))
|
||||||
|
self.active_tasks.add(task)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await task
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
self.active_tasks.discard(task)
|
||||||
|
|
||||||
|
async def _process_message(self, message: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||||
|
request_id = message.get('id', str(uuid.uuid4()))
|
||||||
|
service = message.get('service')
|
||||||
|
request_data = message.get('request', {})
|
||||||
|
flow_id = message.get('flow', 'default') # Default flow
|
||||||
|
|
||||||
|
logger.info(f"Processing message {request_id} for service {service} on flow {flow_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.dispatcher_manager:
|
||||||
|
raise RuntimeError("DispatcherManager not available - pulsar_client and config_receiver required")
|
||||||
|
|
||||||
|
# Use DispatcherManager for flow-based processing
|
||||||
|
responder = WebSocketResponder()
|
||||||
|
|
||||||
|
# Map websocket service name to dispatcher service name
|
||||||
|
dispatcher_service = self.service_mapping.get(service, service)
|
||||||
|
|
||||||
|
# Check if this is a global service or flow service
|
||||||
|
from ..gateway.dispatch.manager import global_dispatchers
|
||||||
|
if dispatcher_service in global_dispatchers:
|
||||||
|
# Use global service dispatcher
|
||||||
|
await self.dispatcher_manager.invoke_global_service(
|
||||||
|
request_data, responder, dispatcher_service
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use DispatcherManager to process the request through Pulsar queues
|
||||||
|
await self.dispatcher_manager.invoke_flow_service(
|
||||||
|
request_data, responder, flow_id, dispatcher_service
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the response from the responder
|
||||||
|
if responder.completed:
|
||||||
|
response_data = responder.response
|
||||||
|
else:
|
||||||
|
response_data = {'error': 'No response received'}
|
||||||
|
|
||||||
|
response = {
|
||||||
|
'id': request_id,
|
||||||
|
'response': response_data
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing message {request_id}: {e}")
|
||||||
|
response = {
|
||||||
|
'id': request_id,
|
||||||
|
'response': {'error': str(e)}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Completed processing message {request_id}")
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
if self.active_tasks:
|
||||||
|
logger.info(f"Waiting for {len(self.active_tasks)} active tasks to complete")
|
||||||
|
await asyncio.gather(*self.active_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# DispatcherManager handles its own cleanup
|
||||||
|
logger.info("Dispatcher shutdown complete")
|
||||||
242
trustgraph-flow/trustgraph/rev_gateway/service.py
Normal file
242
trustgraph-flow/trustgraph/rev_gateway/service.py
Normal file
|
|
@ -0,0 +1,242 @@
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from aiohttp import ClientSession, WSMsgType, ClientWebSocketResponse
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
import pulsar
|
||||||
|
|
||||||
|
from .dispatcher import MessageDispatcher
|
||||||
|
from ..gateway.config.receiver import ConfigReceiver
|
||||||
|
|
||||||
|
logger = logging.getLogger("rev_gateway")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
default_websocket = "ws://localhost:7650/out"
|
||||||
|
|
||||||
|
class ReverseGateway:
|
||||||
|
|
||||||
|
def __init__(self, websocket_uri: str = None, max_workers: int = 10,
|
||||||
|
pulsar_host: str = None, pulsar_api_key: str = None,
|
||||||
|
pulsar_listener: str = None):
|
||||||
|
# Set default WebSocket URI with environment variable support
|
||||||
|
if websocket_uri is None:
|
||||||
|
websocket_uri = os.getenv("WEBSOCKET_URI", default_websocket)
|
||||||
|
|
||||||
|
# Parse and validate the WebSocket URI
|
||||||
|
parsed_uri = urlparse(websocket_uri)
|
||||||
|
if parsed_uri.scheme not in ('ws', 'wss'):
|
||||||
|
raise ValueError(f"WebSocket URI must use ws:// or wss:// scheme, got: {parsed_uri.scheme}")
|
||||||
|
if not parsed_uri.netloc:
|
||||||
|
raise ValueError(f"WebSocket URI must include hostname, got: {websocket_uri}")
|
||||||
|
|
||||||
|
# Store parsed components for debugging/logging
|
||||||
|
self.websocket_uri = websocket_uri
|
||||||
|
self.host = parsed_uri.hostname
|
||||||
|
self.port = parsed_uri.port
|
||||||
|
self.scheme = parsed_uri.scheme
|
||||||
|
self.path = parsed_uri.path or "/ws"
|
||||||
|
|
||||||
|
# Construct the full URL (in case path was missing)
|
||||||
|
if not parsed_uri.path:
|
||||||
|
self.url = f"{self.scheme}://{parsed_uri.netloc}/ws"
|
||||||
|
else:
|
||||||
|
self.url = websocket_uri
|
||||||
|
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self.ws: Optional[ClientWebSocketResponse] = None
|
||||||
|
self.session: Optional[ClientSession] = None
|
||||||
|
self.running = False
|
||||||
|
self.reconnect_delay = 3.0
|
||||||
|
|
||||||
|
# Pulsar configuration
|
||||||
|
self.pulsar_host = pulsar_host or os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
||||||
|
self.pulsar_api_key = pulsar_api_key or os.getenv("PULSAR_API_KEY", None)
|
||||||
|
self.pulsar_listener = pulsar_listener
|
||||||
|
|
||||||
|
# Initialize Pulsar client
|
||||||
|
if self.pulsar_api_key:
|
||||||
|
self.pulsar_client = pulsar.Client(
|
||||||
|
self.pulsar_host,
|
||||||
|
listener_name=self.pulsar_listener,
|
||||||
|
authentication=pulsar.AuthenticationToken(self.pulsar_api_key)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.pulsar_client = pulsar.Client(
|
||||||
|
self.pulsar_host,
|
||||||
|
listener_name=self.pulsar_listener
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize config receiver
|
||||||
|
self.config_receiver = ConfigReceiver(self.pulsar_client)
|
||||||
|
|
||||||
|
# Initialize dispatcher with config_receiver and pulsar_client - must be created after config_receiver
|
||||||
|
self.dispatcher = MessageDispatcher(max_workers, self.config_receiver, self.pulsar_client)
|
||||||
|
|
||||||
|
async def connect(self) -> bool:
|
||||||
|
try:
|
||||||
|
if self.session is None:
|
||||||
|
self.session = ClientSession()
|
||||||
|
|
||||||
|
logger.info(f"Connecting to {self.url}")
|
||||||
|
self.ws = await self.session.ws_connect(self.url)
|
||||||
|
logger.info(f"WebSocket connection established to {self.host}:{self.port or 'default'}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to {self.url}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
if self.ws and not self.ws.closed:
|
||||||
|
await self.ws.close()
|
||||||
|
if self.session and not self.session.closed:
|
||||||
|
await self.session.close()
|
||||||
|
self.ws = None
|
||||||
|
self.session = None
|
||||||
|
|
||||||
|
async def send_message(self, message: dict):
|
||||||
|
if self.ws and not self.ws.closed:
|
||||||
|
try:
|
||||||
|
await self.ws.send_str(json.dumps(message))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send message: {e}")
|
||||||
|
|
||||||
|
async def handle_message(self, message: str):
|
||||||
|
try:
|
||||||
|
print(f"Received: {message}", flush=True)
|
||||||
|
|
||||||
|
msg_data = json.loads(message)
|
||||||
|
response = await self.dispatcher.handle_message(msg_data)
|
||||||
|
|
||||||
|
if response:
|
||||||
|
await self.send_message(response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling message: {e}")
|
||||||
|
|
||||||
|
async def listen(self):
|
||||||
|
while self.running and self.ws and not self.ws.closed:
|
||||||
|
try:
|
||||||
|
msg = await self.ws.receive()
|
||||||
|
|
||||||
|
if msg.type == WSMsgType.TEXT:
|
||||||
|
await self.handle_message(msg.data)
|
||||||
|
elif msg.type == WSMsgType.BINARY:
|
||||||
|
await self.handle_message(msg.data.decode('utf-8'))
|
||||||
|
elif msg.type in (WSMsgType.CLOSE, WSMsgType.ERROR):
|
||||||
|
logger.warning("WebSocket closed or error occurred")
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in listen loop: {e}")
|
||||||
|
break
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
self.running = True
|
||||||
|
logger.info("Starting reverse gateway")
|
||||||
|
|
||||||
|
# Start config receiver
|
||||||
|
logger.info("Starting config receiver")
|
||||||
|
await self.config_receiver.start()
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
if await self.connect():
|
||||||
|
await self.listen()
|
||||||
|
else:
|
||||||
|
logger.warning(f"Connection failed, retrying in {self.reconnect_delay} seconds")
|
||||||
|
|
||||||
|
await self.disconnect()
|
||||||
|
|
||||||
|
if self.running:
|
||||||
|
await asyncio.sleep(self.reconnect_delay)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Shutdown requested")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error: {e}")
|
||||||
|
if self.running:
|
||||||
|
await asyncio.sleep(self.reconnect_delay)
|
||||||
|
|
||||||
|
await self.shutdown()
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
logger.info("Shutting down reverse gateway")
|
||||||
|
self.running = False
|
||||||
|
await self.dispatcher.shutdown()
|
||||||
|
await self.disconnect()
|
||||||
|
|
||||||
|
# Close Pulsar client
|
||||||
|
if hasattr(self, 'pulsar_client'):
|
||||||
|
self.pulsar_client.close()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="reverse-gateway",
|
||||||
|
description="TrustGraph Reverse Gateway - WebSocket to Pulsar bridge"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--websocket-uri',
|
||||||
|
default=None,
|
||||||
|
help=f'WebSocket URI to connect to (default: {default_websocket} or WEBSOCKET_URI env var)'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--max-workers',
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help='Maximum concurrent message handlers (default: 10)'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-p', '--pulsar-host',
|
||||||
|
default=None,
|
||||||
|
help='Pulsar host URL (default: pulsar://pulsar:6650 or PULSAR_HOST env var)'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--pulsar-api-key',
|
||||||
|
default=None,
|
||||||
|
help='Pulsar API key for authentication (default: PULSAR_API_KEY env var)'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--pulsar-listener',
|
||||||
|
default=None,
|
||||||
|
help='Pulsar listener name'
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
def run():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
gateway = ReverseGateway(
|
||||||
|
websocket_uri=args.websocket_uri,
|
||||||
|
max_workers=args.max_workers,
|
||||||
|
pulsar_host=args.pulsar_host,
|
||||||
|
pulsar_api_key=args.pulsar_api_key,
|
||||||
|
pulsar_listener=args.pulsar_listener
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Starting reverse gateway:")
|
||||||
|
print(f" WebSocket URI: {gateway.url}")
|
||||||
|
print(f" Max workers: {args.max_workers}")
|
||||||
|
print(f" Pulsar host: {gateway.pulsar_host}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
asyncio.run(gateway.run())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nShutdown requested by user")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Fatal error: {e}")
|
||||||
|
sys.exit(1)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue