mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 17:06:22 +02:00
Flow API - update gateway (#357)
* Altered API to incorporate Flow IDs, refactored for dynamic start/stop of flows * Gateway: Split endpoint / dispatcher for maintainability
This commit is contained in:
parent
450f664b1b
commit
a70ae9793a
52 changed files with 1206 additions and 907 deletions
33
test-api/test-llm2-api
Executable file
33
test-api/test-llm2-api
Executable file
|
|
@ -0,0 +1,33 @@
|
||||||
|
#!/usr/bin/env python3

"""Smoke test: send one text-completion request to the API gateway.

Prints the completion text on success.  Exits with status 1 if the gateway
returns an error object, and raises RuntimeError on a non-200 HTTP status.
"""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

# Upper bound, in seconds, on how long to wait for the gateway.  Without a
# timeout requests.post() blocks indefinitely if the service never answers.
timeout = 300

############################################################################

# Named 'request' rather than 'input' so the builtin input() is not shadowed.
request = {
    "system": "",
    "prompt": "Add 2 and 3"
}

resp = requests.post(
    f"{url}text-completion",
    json=request,
    timeout=timeout,
)

if resp.status_code != 200:
    raise RuntimeError(f"Status code: {resp.status_code}")

# Keep the decoded body separate from the HTTP response object.
body = resp.json()

if "error" in body:
    print(f"Error: {body['error']}")
    sys.exit(1)

print(body["response"])

############################################################################
|
||||||
|
|
||||||
|
|
@ -5,7 +5,7 @@ import json
|
||||||
import sys
|
import sys
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
url = "http://localhost:8088/api/v1/"
|
url = "http://localhost:8088/api/v1/flow/0000/document-load"
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
|
||||||
|
|
@ -88,10 +88,7 @@ input = {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = requests.post(
|
resp = requests.post(url, json=input)
|
||||||
f"{url}load/document",
|
|
||||||
json=input,
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = resp.json()
|
resp = resp.json()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import json
|
||||||
import sys
|
import sys
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
url = "http://localhost:8088/api/v1/"
|
url = "http://localhost:8088/api/v1/flow/0000/service/text-load"
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
|
||||||
|
|
@ -85,10 +85,7 @@ input = {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = requests.post(
|
resp = requests.post(url, json=input)
|
||||||
f"{url}load/text",
|
|
||||||
json=input,
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = resp.json()
|
resp = resp.json()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from trustgraph.schema import Document, Metadata
|
||||||
client = pulsar.Client("pulsar://localhost:6650", listener_name="localhost")
|
client = pulsar.Client("pulsar://localhost:6650", listener_name="localhost")
|
||||||
|
|
||||||
prod = client.create_producer(
|
prod = client.create_producer(
|
||||||
topic="persistent://tg/flow/document-load:0002",
|
topic="persistent://tg/flow/document-load:0000",
|
||||||
schema=JsonSchema(Document),
|
schema=JsonSchema(Document),
|
||||||
chunking_enabled=True,
|
chunking_enabled=True,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ prod = client.create_producer(
|
||||||
chunking_enabled=True,
|
chunking_enabled=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
path = "docs/README.cats"
|
path = "../trustgraph/docs/README.cats"
|
||||||
|
|
||||||
with open(path, "r") as f:
|
with open(path, "r") as f:
|
||||||
# blob = base64.b64encode(f.read()).decode("utf-8")
|
# blob = base64.b64encode(f.read()).decode("utf-8")
|
||||||
|
|
|
||||||
|
|
@ -15,17 +15,22 @@ class Publisher:
|
||||||
self.q = asyncio.Queue(maxsize=max_size)
|
self.q = asyncio.Queue(maxsize=max_size)
|
||||||
self.chunking_enabled = chunking_enabled
|
self.chunking_enabled = chunking_enabled
|
||||||
self.running = True
|
self.running = True
|
||||||
|
self.task = None
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.task = asyncio.create_task(self.run())
|
self.task = asyncio.create_task(self.run())
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
await self.task
|
|
||||||
|
if self.task:
|
||||||
|
await self.task
|
||||||
|
|
||||||
async def join(self):
|
async def join(self):
|
||||||
await self.stop()
|
await self.stop()
|
||||||
await self.task
|
|
||||||
|
if self.task:
|
||||||
|
await self.task
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@ class Subscriber:
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
self.running = True
|
self.running = True
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
|
self.task = None
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
@ -28,11 +29,15 @@ class Subscriber:
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
await self.task
|
|
||||||
|
if self.task:
|
||||||
|
await self.task
|
||||||
|
|
||||||
async def join(self):
|
async def join(self):
|
||||||
await self.stop()
|
await self.stop()
|
||||||
await self.task
|
|
||||||
|
if self.task:
|
||||||
|
await self.task
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
|
|
||||||
|
|
@ -45,6 +50,8 @@ class Subscriber:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
# FIXME: Create consumer in start method so we know
|
||||||
|
# it is definitely running when start completes
|
||||||
consumer = self.client.subscribe(
|
consumer = self.client.subscribe(
|
||||||
topic = self.topic,
|
topic = self.topic,
|
||||||
subscription_name = self.subscription,
|
subscription_name = self.subscription,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
|
|
||||||
from pulsar.schema import JsonSchema
|
from pulsar.schema import JsonSchema
|
||||||
from .. schema import EmbeddingsRequest, EmbeddingsResponse
|
from .. schema import EmbeddingsRequest, EmbeddingsResponse
|
||||||
from .. schema import embeddings_request_queue, embeddings_response_queue
|
|
||||||
from . base import BaseClient
|
from . base import BaseClient
|
||||||
|
|
||||||
import _pulsar
|
import _pulsar
|
||||||
|
|
@ -23,12 +22,6 @@ class EmbeddingsClient(BaseClient):
|
||||||
pulsar_api_key=None,
|
pulsar_api_key=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if input_queue == None:
|
|
||||||
input_queue=embeddings_request_queue
|
|
||||||
|
|
||||||
if output_queue == None:
|
|
||||||
output_queue=embeddings_response_queue
|
|
||||||
|
|
||||||
super(EmbeddingsClient, self).__init__(
|
super(EmbeddingsClient, self).__init__(
|
||||||
log_level=log_level,
|
log_level=log_level,
|
||||||
subscriber=subscriber,
|
subscriber=subscriber,
|
||||||
|
|
@ -43,4 +36,3 @@ class EmbeddingsClient(BaseClient):
|
||||||
def request(self, text, timeout=300):
|
def request(self, text, timeout=300):
|
||||||
return self.call(text=text, timeout=timeout).vectors
|
return self.call(text=text, timeout=timeout).vectors
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ async def load_ge(running, queue, url):
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
|
|
||||||
async with session.ws_connect(f"{url}load/graph-embeddings") as ws:
|
async with session.ws_connect(url) as ws:
|
||||||
|
|
||||||
while running.get():
|
while running.get():
|
||||||
|
|
||||||
|
|
@ -73,7 +73,7 @@ async def load_triples(running, queue, url):
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
|
|
||||||
async with session.ws_connect(f"{url}load/triples") as ws:
|
async with session.ws_connect(url) as ws:
|
||||||
|
|
||||||
while running.get():
|
while running.get():
|
||||||
|
|
||||||
|
|
@ -200,6 +200,9 @@ async def run(running, **args):
|
||||||
ge_q = asyncio.Queue(maxsize=10)
|
ge_q = asyncio.Queue(maxsize=10)
|
||||||
t_q = asyncio.Queue(maxsize=10)
|
t_q = asyncio.Queue(maxsize=10)
|
||||||
|
|
||||||
|
flow_id = args["flow_id"]
|
||||||
|
url = args["url"]
|
||||||
|
|
||||||
load_task = asyncio.create_task(
|
load_task = asyncio.create_task(
|
||||||
loader(
|
loader(
|
||||||
running=running,
|
running=running,
|
||||||
|
|
@ -212,15 +215,17 @@ async def run(running, **args):
|
||||||
|
|
||||||
ge_task = asyncio.create_task(
|
ge_task = asyncio.create_task(
|
||||||
load_ge(
|
load_ge(
|
||||||
running=running,
|
running = running,
|
||||||
queue=ge_q, url=args["url"] + "api/v1/"
|
queue = ge_q,
|
||||||
|
url = f"{url}api/v1/flow/{flow_id}/import/graph-embeddings"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
triples_task = asyncio.create_task(
|
triples_task = asyncio.create_task(
|
||||||
load_triples(
|
load_triples(
|
||||||
running=running,
|
running = running,
|
||||||
queue=t_q, url=args["url"] + "api/v1/"
|
queue = t_q,
|
||||||
|
url = f"{url}api/v1/flow/{flow_id}/import/triples"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -258,6 +263,12 @@ async def main(running):
|
||||||
help=f'Output file'
|
help=f'Output file'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-f', '--flow-id',
|
||||||
|
default="0000",
|
||||||
|
help=f'Flow ID (default: 0000)'
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--format',
|
'--format',
|
||||||
default="msgpack",
|
default="msgpack",
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ async def fetch_ge(running, queue, user, collection, url):
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
|
|
||||||
async with session.ws_connect(f"{url}stream/graph-embeddings") as ws:
|
async with session.ws_connect(url) as ws:
|
||||||
|
|
||||||
while running.get():
|
while running.get():
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ async def fetch_triples(running, queue, user, collection, url):
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
|
|
||||||
async with session.ws_connect(f"{url}stream/triples") as ws:
|
async with session.ws_connect(url) as ws:
|
||||||
|
|
||||||
while running.get():
|
while running.get():
|
||||||
|
|
||||||
|
|
@ -160,11 +160,14 @@ async def run(running, **args):
|
||||||
|
|
||||||
q = asyncio.Queue()
|
q = asyncio.Queue()
|
||||||
|
|
||||||
|
flow_id = args["flow_id"]
|
||||||
|
url = args["url"]
|
||||||
|
|
||||||
ge_task = asyncio.create_task(
|
ge_task = asyncio.create_task(
|
||||||
fetch_ge(
|
fetch_ge(
|
||||||
running=running,
|
running=running,
|
||||||
queue=q, user=args["user"], collection=args["collection"],
|
queue=q, user=args["user"], collection=args["collection"],
|
||||||
url=args["url"] + "api/v1/"
|
url = f"{url}api/v1/flow/{flow_id}/export/graph-embeddings"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -172,7 +175,7 @@ async def run(running, **args):
|
||||||
fetch_triples(
|
fetch_triples(
|
||||||
running=running, queue=q,
|
running=running, queue=q,
|
||||||
user=args["user"], collection=args["collection"],
|
user=args["user"], collection=args["collection"],
|
||||||
url=args["url"] + "api/v1/"
|
url = f"{url}api/v1/flow/{flow_id}/export/triples"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -224,6 +227,12 @@ async def main(running):
|
||||||
help=f'Output format (default: msgpack)',
|
help=f'Output format (default: msgpack)',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-f', '--flow-id',
|
||||||
|
default="0000",
|
||||||
|
help=f'Flow ID (default: 0000)'
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--user',
|
'--user',
|
||||||
help=f'User ID to filter on (default: no filter)'
|
help=f'User ID to filter on (default: no filter)'
|
||||||
|
|
|
||||||
121
trustgraph-flow/trustgraph/gateway/config/receiver.py
Executable file
121
trustgraph-flow/trustgraph/gateway/config/receiver.py
Executable file
|
|
@ -0,0 +1,121 @@
|
||||||
|
"""
|
||||||
|
API gateway. Offers HTTP services which are translated to interaction on the
|
||||||
|
Pulsar bus.
|
||||||
|
"""
|
||||||
|
|
||||||
|
module = "api-gateway"
|
||||||
|
|
||||||
|
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
|
||||||
|
# are active listeners
|
||||||
|
|
||||||
|
# FIXME: Connection errors in publishers / subscribers cause those threads
|
||||||
|
# to fail and are not failed or retried
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
from aiohttp import web
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
import uuid
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pulsar
|
||||||
|
from prometheus_client import start_http_server
|
||||||
|
|
||||||
|
from ... schema import ConfigPush, config_push_queue
|
||||||
|
from ... base import Consumer
|
||||||
|
|
||||||
|
logger = logging.getLogger("config.receiver")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
class ConfigReceiver:
    """Track the configured flows and tell handlers when flows start or stop.

    Configuration pushes are consumed from ``config_push_queue``.  For each
    push, the "flows" section is compared with the flows already known:
    new flow IDs are started on every registered handler, and flow IDs which
    have disappeared are stopped.
    """

    def __init__(self, pulsar_client):
        """Create a receiver which will consume using *pulsar_client*."""

        self.pulsar_client = pulsar_client

        # Handlers notified of flow changes; each must provide awaitable
        # start_flow(id, flow) and stop_flow(id, flow) methods.
        self.flow_handlers = []

        # Flow ID -> decoded flow definition, for flows currently started.
        self.flows = {}

        # Background task running config_loader(); set by start().
        self.config_task = None

    def add_handler(self, h):
        """Register *h* to be notified when flows are started or stopped."""
        self.flow_handlers.append(h)

    async def on_config(self, msg, proc, flow):
        """Handle one configuration push.

        Args:
            msg: Message whose ``value()`` has ``version`` and ``config``
                attributes; ``config["flows"]`` maps flow ID to a JSON string.
            proc: Unused; part of the consumer handler signature.
            flow: Unused; part of the consumer handler signature.

        Exceptions are reported on stdout and never propagate to the caller.
        """

        try:

            v = msg.value()

            print("Config version", v.version)

            if "flows" not in v.config:
                return

            flows = v.config["flows"]

            wanted = list(flows.keys())
            current = list(self.flows.keys())

            # Start flows which are configured but not yet known.
            for k in wanted:

                if k in current:
                    continue

                try:
                    definition = json.loads(flows[k])
                except Exception as e:
                    # A single undecodable flow must not prevent the other
                    # flows from being started or stale flows being stopped.
                    print(f"Exception: {e}", flush=True)
                    continue

                self.flows[k] = definition
                await self.start_flow(k, definition)

            # Stop flows which are known but no longer configured.
            for k in current:
                if k not in wanted:
                    await self.stop_flow(k, self.flows[k])
                    del self.flows[k]

        except Exception as e:
            print(f"Exception: {e}", flush=True)

    async def start_flow(self, id, flow):
        """Notify every handler that flow *id* has started.

        A failure in one handler is reported and does not stop the
        remaining handlers from being notified.
        """

        print("Start flow", id)

        for handler in self.flow_handlers:

            try:
                await handler.start_flow(id, flow)
            except Exception as e:
                print(f"Exception: {e}", flush=True)

    async def stop_flow(self, id, flow):
        """Notify every handler that flow *id* has stopped.

        A failure in one handler is reported and does not stop the
        remaining handlers from being notified.
        """

        print("Stop flow", id)

        for handler in self.flow_handlers:

            try:
                await handler.stop_flow(id, flow)
            except Exception as e:
                print(f"Exception: {e}", flush=True)

    async def config_loader(self):
        """Run the configuration consumer until its task group finishes."""

        async with asyncio.TaskGroup() as tg:

            # Unique suffix so each gateway instance gets its own
            # subscriber name.
            subscriber_id = str(uuid.uuid4())

            # NOTE(review): start_of_messages presumably replays the topic
            # from the beginning so existing config is seen - confirm
            # against Consumer.
            self.config_cons = Consumer(
                taskgroup=tg,
                flow=None,
                client=self.pulsar_client,
                subscriber=f"gateway-{subscriber_id}",
                topic=config_push_queue,
                schema=ConfigPush,
                handler=self.on_config,
                start_of_messages=True,
            )

            await self.config_cons.start()

            print("Waiting...")

        # The task group context only exits once all its tasks are done.
        print("Config consumer done. :/")

    async def start(self):
        """Launch config_loader() as a background task."""

        # Keep a reference: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-run.
        self.config_task = asyncio.create_task(self.config_loader())
|
||||||
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
|
|
||||||
from .. schema import AgentRequest, AgentResponse
|
from ... schema import AgentRequest, AgentResponse
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
class AgentRequestor(ServiceRequestor):
|
class AgentRequestor(ServiceRequestor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
self, pulsar_client, request_queue, response_queue, timeout,
|
||||||
consumer, subscriber,
|
consumer, subscriber,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
|
|
||||||
from .. schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
|
from ... schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
|
||||||
from .. schema import config_request_queue
|
from ... schema import config_request_queue
|
||||||
from .. schema import config_response_queue
|
from ... schema import config_response_queue
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
class ConfigRequestor(ServiceRequestor):
|
class ConfigRequestor(ServiceRequestor):
|
||||||
def __init__(self, pulsar_client, timeout, auth):
|
def __init__(self, pulsar_client, timeout=120):
|
||||||
|
|
||||||
super(ConfigRequestor, self).__init__(
|
super(ConfigRequestor, self).__init__(
|
||||||
pulsar_client=pulsar_client,
|
pulsar_client=pulsar_client,
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import queue
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from ... schema import DocumentEmbeddings
|
||||||
|
from ... base import Subscriber
|
||||||
|
|
||||||
|
from . serialize import serialize_document_embeddings
|
||||||
|
|
||||||
|
class DocumentEmbeddingsExport:
    """Forward DocumentEmbeddings messages from a Pulsar topic to a websocket.

    Args:
        ws: Websocket to write to; must provide awaitable ``send_json()``
            and ``close()``.
        running: Shared run flag providing ``get()`` and ``stop()``.
        pulsar_client: Pulsar client handed to the Subscriber.
        queue: Topic the Subscriber reads from.
        consumer: Consumer name handed to the Subscriber.
        subscriber: Subscription name handed to the Subscriber.
    """

    def __init__(
            self, ws, running, pulsar_client, queue, consumer, subscriber
    ):

        self.ws = ws
        self.running = running
        self.pulsar_client = pulsar_client
        self.queue = queue
        self.consumer = consumer
        self.subscriber = subscriber

    async def destroy(self):
        """Clear the run flag and close the websocket."""
        self.running.stop()
        await self.ws.close()

    async def receive(self, msg):
        """Ignore anything the client sends; this socket is output-only."""
        # Ignore incoming info from websocket
        pass

    async def _forward(self, q):
        """Copy messages from *q* to the websocket until told to stop."""

        while self.running.get():

            try:

                # Short timeout so the run flag is re-checked regularly.
                resp = await asyncio.wait_for(q.get(), timeout=0.5)
                await self.ws.send_json(serialize_document_embeddings(resp))

            except (asyncio.TimeoutError, TimeoutError):
                # asyncio.TimeoutError is only an alias of the builtin from
                # Python 3.11; naming both keeps idle polls from being
                # treated as fatal on older interpreters.
                continue

            except queue.Empty:
                continue

            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break

    async def run(self):
        """Subscribe to the topic and stream messages to the websocket.

        Returns when the run flag is cleared or forwarding fails.  The
        subscription, subscriber and websocket are released on every exit
        path, including cancellation of the surrounding task.
        """

        subs = Subscriber(
            client=self.pulsar_client, topic=self.queue,
            consumer_name=self.consumer, subscription=self.subscriber,
            schema=DocumentEmbeddings,
        )

        await subs.start()

        try:

            sub_id = str(uuid.uuid4())
            q = await subs.subscribe_all(sub_id)

            try:
                await self._forward(q)
            finally:
                # Only unsubscribe an ID which was actually subscribed.
                await subs.unsubscribe_all(sub_id)

        finally:

            try:
                await subs.stop()
            finally:
                await self.ws.close()
                self.running.stop()
|
||||||
|
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from aiohttp import WSMsgType
|
||||||
|
|
||||||
|
from ... schema import Metadata
|
||||||
|
from ... schema import DocumentEmbeddings, ChunkEmbeddings
|
||||||
|
from ... base import Publisher
|
||||||
|
|
||||||
|
from . serialize import to_subgraph
|
||||||
|
|
||||||
|
class DocumentEmbeddingsImport:
    """Publish document embeddings received over a websocket to a topic.

    Args:
        ws: Websocket the data arrives on; closed when the import ends.
        running: Shared run flag providing ``get()`` and ``stop()``.
        pulsar_client: Pulsar client handed to the Publisher.
        queue: Topic the Publisher writes DocumentEmbeddings to.
    """

    def __init__(self, ws, running, pulsar_client, queue):

        self.ws, self.running = ws, running

        # NOTE(review): the publisher is created here but not started in
        # this class - presumably the owner starts it; confirm.
        self.publisher = Publisher(
            pulsar_client, topic=queue, schema=DocumentEmbeddings,
        )

    async def destroy(self):
        """Clear the run flag, close the websocket, stop the publisher."""

        self.running.stop()

        if self.ws:
            await self.ws.close()

        await self.publisher.stop()

    async def receive(self, msg):
        """Decode one websocket JSON message and publish it.

        The message must carry "metadata" (with "id", "metadata", "user"
        and "collection") and "chunks" (each with "chunk" and "vectors").
        """

        payload = msg.json()
        meta = payload["metadata"]

        header = Metadata(
            id=meta["id"],
            metadata=to_subgraph(meta["metadata"]),
            user=meta["user"],
            collection=meta["collection"],
        )

        # Chunk text arrives as str and is sent as UTF-8 bytes.
        chunks = [
            ChunkEmbeddings(
                chunk=item["chunk"].encode("utf-8"),
                vectors=item["vectors"],
            )
            for item in payload["chunks"]
        ]

        message = DocumentEmbeddings(metadata=header, chunks=chunks)

        await self.publisher.send(None, message)

    async def run(self):
        """Idle until the run flag is cleared, then close the websocket."""

        while True:
            if not self.running.get():
                break
            await asyncio.sleep(0.5)

        if self.ws:
            await self.ws.close()

        self.ws = None
|
||||||
|
|
||||||
|
|
@ -1,19 +1,18 @@
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
from .. schema import Document, Metadata
|
from ... schema import Document, Metadata
|
||||||
from .. schema import document_ingest_queue
|
|
||||||
|
|
||||||
from . sender import ServiceSender
|
from . sender import ServiceSender
|
||||||
from . serialize import to_subgraph
|
from . serialize import to_subgraph
|
||||||
|
|
||||||
class DocumentLoadSender(ServiceSender):
|
class DocumentLoad(ServiceSender):
|
||||||
def __init__(self, pulsar_client):
|
def __init__(self, pulsar_client, queue):
|
||||||
|
|
||||||
super(DocumentLoadSender, self).__init__(
|
super(DocumentLoad, self).__init__(
|
||||||
pulsar_client=pulsar_client,
|
pulsar_client = pulsar_client,
|
||||||
request_queue=document_ingest_queue,
|
queue = queue,
|
||||||
request_schema=Document,
|
schema = Document,
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
|
|
@ -1,20 +1,22 @@
|
||||||
|
|
||||||
from .. schema import DocumentRagQuery, DocumentRagResponse
|
from ... schema import DocumentRagQuery, DocumentRagResponse
|
||||||
from .. schema import document_rag_request_queue
|
|
||||||
from .. schema import document_rag_response_queue
|
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
class DocumentRagRequestor(ServiceRequestor):
|
class DocumentRagRequestor(ServiceRequestor):
|
||||||
def __init__(self, pulsar_client, timeout, auth):
|
def __init__(
|
||||||
|
self, pulsar_client, request_queue, response_queue, timeout,
|
||||||
|
consumer, subscriber,
|
||||||
|
):
|
||||||
|
|
||||||
super(DocumentRagRequestor, self).__init__(
|
super(DocumentRagRequestor, self).__init__(
|
||||||
pulsar_client=pulsar_client,
|
pulsar_client=pulsar_client,
|
||||||
request_queue=document_rag_request_queue,
|
request_queue=request_queue,
|
||||||
response_queue=document_rag_response_queue,
|
response_queue=response_queue,
|
||||||
request_schema=DocumentRagQuery,
|
request_schema=DocumentRagQuery,
|
||||||
response_schema=DocumentRagResponse,
|
response_schema=DocumentRagResponse,
|
||||||
|
subscription = subscriber,
|
||||||
|
consumer_name = consumer,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
|
|
||||||
from .. schema import EmbeddingsRequest, EmbeddingsResponse
|
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
class EmbeddingsRequestor(ServiceRequestor):
|
class EmbeddingsRequestor(ServiceRequestor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
self, pulsar_client, request_queue, response_queue, timeout,
|
||||||
consumer, subscriber,
|
consumer, subscriber,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
|
|
||||||
from .. schema import FlowRequest, FlowResponse, ConfigKey, ConfigValue
|
from ... schema import FlowRequest, FlowResponse
|
||||||
from .. schema import flow_request_queue
|
from ... schema import flow_request_queue
|
||||||
from .. schema import flow_response_queue
|
from ... schema import flow_response_queue
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
class FlowRequestor(ServiceRequestor):
|
class FlowRequestor(ServiceRequestor):
|
||||||
def __init__(self, pulsar_client, timeout, auth):
|
def __init__(self, pulsar_client, timeout=120):
|
||||||
|
|
||||||
super(FlowRequestor, self).__init__(
|
super(FlowRequestor, self).__init__(
|
||||||
pulsar_client=pulsar_client,
|
pulsar_client=pulsar_client,
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import queue
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from ... schema import GraphEmbeddings
|
||||||
|
from ... base import Subscriber
|
||||||
|
|
||||||
|
from . serialize import serialize_graph_embeddings
|
||||||
|
|
||||||
|
class GraphEmbeddingsExport:
    """Forward GraphEmbeddings messages from a Pulsar topic to a websocket.

    Args:
        ws: Websocket to write to; must provide awaitable ``send_json()``
            and ``close()``.
        running: Shared run flag providing ``get()`` and ``stop()``.
        pulsar_client: Pulsar client handed to the Subscriber.
        queue: Topic the Subscriber reads from.
        consumer: Consumer name handed to the Subscriber.
        subscriber: Subscription name handed to the Subscriber.
    """

    def __init__(
            self, ws, running, pulsar_client, queue, consumer, subscriber
    ):

        self.ws = ws
        self.running = running
        self.pulsar_client = pulsar_client
        self.queue = queue
        self.consumer = consumer
        self.subscriber = subscriber

    async def destroy(self):
        """Clear the run flag and close the websocket."""
        self.running.stop()
        await self.ws.close()

    async def receive(self, msg):
        """Ignore anything the client sends; this socket is output-only."""
        # Ignore incoming info from websocket
        pass

    async def _forward(self, q):
        """Copy messages from *q* to the websocket until told to stop."""

        while self.running.get():

            try:

                # Short timeout so the run flag is re-checked regularly.
                resp = await asyncio.wait_for(q.get(), timeout=0.5)
                await self.ws.send_json(serialize_graph_embeddings(resp))

            except (asyncio.TimeoutError, TimeoutError):
                # asyncio.TimeoutError is only an alias of the builtin from
                # Python 3.11; naming both keeps idle polls from being
                # treated as fatal on older interpreters.
                continue

            except queue.Empty:
                continue

            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break

    async def run(self):
        """Subscribe to the topic and stream messages to the websocket.

        Returns when the run flag is cleared or forwarding fails.  The
        subscription, subscriber and websocket are released on every exit
        path, including cancellation of the surrounding task.
        """

        subs = Subscriber(
            client=self.pulsar_client, topic=self.queue,
            consumer_name=self.consumer, subscription=self.subscriber,
            schema=GraphEmbeddings,
        )

        await subs.start()

        try:

            sub_id = str(uuid.uuid4())
            q = await subs.subscribe_all(sub_id)

            try:
                await self._forward(q)
            finally:
                # Only unsubscribe an ID which was actually subscribed.
                await subs.unsubscribe_all(sub_id)

        finally:

            try:
                await subs.stop()
            finally:
                await self.ws.close()
                self.running.stop()
|
||||||
|
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from aiohttp import WSMsgType
|
||||||
|
|
||||||
|
from ... schema import Metadata
|
||||||
|
from ... schema import GraphEmbeddings, EntityEmbeddings
|
||||||
|
from ... base import Publisher
|
||||||
|
|
||||||
|
from . serialize import to_subgraph, to_value
|
||||||
|
|
||||||
|
class GraphEmbeddingsImport:
    """Publish graph embeddings received over a websocket to a topic.

    Args:
        ws: Websocket the data arrives on; closed when the import ends.
        running: Shared run flag providing ``get()`` and ``stop()``.
        pulsar_client: Pulsar client handed to the Publisher.
        queue: Topic the Publisher writes GraphEmbeddings to.
    """

    def __init__(self, ws, running, pulsar_client, queue):

        self.ws, self.running = ws, running

        # NOTE(review): the publisher is created here but not started in
        # this class - presumably the owner starts it; confirm.
        self.publisher = Publisher(
            pulsar_client, topic=queue, schema=GraphEmbeddings,
        )

    async def destroy(self):
        """Clear the run flag, close the websocket, stop the publisher."""

        self.running.stop()

        if self.ws:
            await self.ws.close()

        await self.publisher.stop()

    async def receive(self, msg):
        """Decode one websocket JSON message and publish it.

        The message must carry "metadata" (with "id", "metadata", "user"
        and "collection") and "entities" (each with "entity" and "vectors").
        """

        payload = msg.json()
        meta = payload["metadata"]

        header = Metadata(
            id=meta["id"],
            metadata=to_subgraph(meta["metadata"]),
            user=meta["user"],
            collection=meta["collection"],
        )

        entities = [
            EntityEmbeddings(
                entity=to_value(item["entity"]),
                vectors=item["vectors"],
            )
            for item in payload["entities"]
        ]

        message = GraphEmbeddings(metadata=header, entities=entities)

        await self.publisher.send(None, message)

    async def run(self):
        """Idle until the run flag is cleared, then close the websocket."""

        while True:
            if not self.running.get():
                break
            await asyncio.sleep(0.5)

        if self.ws:
            await self.ws.close()

        self.ws = None
|
||||||
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
|
|
||||||
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
from ... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
from . serialize import serialize_value
|
from . serialize import serialize_value
|
||||||
|
|
||||||
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
|
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
self, pulsar_client, request_queue, response_queue, timeout,
|
||||||
consumer, subscriber,
|
consumer, subscriber,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
|
|
||||||
from .. schema import GraphRagQuery, GraphRagResponse
|
from ... schema import GraphRagQuery, GraphRagResponse
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
class GraphRagRequestor(ServiceRequestor):
|
class GraphRagRequestor(ServiceRequestor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
self, pulsar_client, request_queue, response_queue, timeout,
|
||||||
consumer, subscriber,
|
consumer, subscriber,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
@ -1,15 +1,14 @@
|
||||||
|
|
||||||
from .. schema import LibrarianRequest, LibrarianResponse, Triples
|
from ... schema import LibrarianRequest, LibrarianResponse
|
||||||
from .. schema import librarian_request_queue
|
from ... schema import librarian_request_queue
|
||||||
from .. schema import librarian_response_queue
|
from ... schema import librarian_response_queue
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
from . serialize import serialize_document_package, serialize_document_info
|
from . serialize import serialize_document_package, serialize_document_info
|
||||||
from . serialize import to_document_package, to_document_info, to_criteria
|
from . serialize import to_document_package, to_document_info, to_criteria
|
||||||
|
|
||||||
class LibrarianRequestor(ServiceRequestor):
|
class LibrarianRequestor(ServiceRequestor):
|
||||||
def __init__(self, pulsar_client, timeout, auth):
|
def __init__(self, pulsar_client, timeout=120):
|
||||||
|
|
||||||
super(LibrarianRequestor, self).__init__(
|
super(LibrarianRequestor, self).__init__(
|
||||||
pulsar_client=pulsar_client,
|
pulsar_client=pulsar_client,
|
||||||
|
|
@ -22,20 +21,16 @@ class LibrarianRequestor(ServiceRequestor):
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
|
|
||||||
print("TRR")
|
|
||||||
if "document" in body:
|
if "document" in body:
|
||||||
dp = to_document_package(body["document"])
|
dp = to_document_package(body["document"])
|
||||||
else:
|
else:
|
||||||
dp = None
|
dp = None
|
||||||
|
|
||||||
print("GOT")
|
|
||||||
if "criteria" in body:
|
if "criteria" in body:
|
||||||
criteria = to_criteria(body["criteria"])
|
criteria = to_criteria(body["criteria"])
|
||||||
else:
|
else:
|
||||||
criteria = None
|
criteria = None
|
||||||
|
|
||||||
print("ASLDKJ")
|
|
||||||
|
|
||||||
return LibrarianRequest(
|
return LibrarianRequest(
|
||||||
operation = body.get("operation", None),
|
operation = body.get("operation", None),
|
||||||
id = body.get("id", None),
|
id = body.get("id", None),
|
||||||
229
trustgraph-flow/trustgraph/gateway/dispatch/manager.py
Normal file
229
trustgraph-flow/trustgraph/gateway/dispatch/manager.py
Normal file
|
|
@ -0,0 +1,229 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from . config import ConfigRequestor
|
||||||
|
from . flow import FlowRequestor
|
||||||
|
from . librarian import LibrarianRequestor
|
||||||
|
|
||||||
|
from . embeddings import EmbeddingsRequestor
|
||||||
|
from . agent import AgentRequestor
|
||||||
|
from . text_completion import TextCompletionRequestor
|
||||||
|
from . prompt import PromptRequestor
|
||||||
|
from . graph_rag import GraphRagRequestor
|
||||||
|
from . document_rag import DocumentRagRequestor
|
||||||
|
from . triples_query import TriplesQueryRequestor
|
||||||
|
from . embeddings import EmbeddingsRequestor
|
||||||
|
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||||
|
from . prompt import PromptRequestor
|
||||||
|
from . text_load import TextLoad
|
||||||
|
from . document_load import DocumentLoad
|
||||||
|
|
||||||
|
from . triples_export import TriplesExport
|
||||||
|
from . graph_embeddings_export import GraphEmbeddingsExport
|
||||||
|
from . document_embeddings_export import DocumentEmbeddingsExport
|
||||||
|
|
||||||
|
from . triples_import import TriplesImport
|
||||||
|
from . graph_embeddings_import import GraphEmbeddingsImport
|
||||||
|
from . document_embeddings_import import DocumentEmbeddingsImport
|
||||||
|
|
||||||
|
# Lookup tables mapping the {kind} element of a gateway URL to the dispatcher
# class which serves it.  DispatcherManager consults these when a request or
# websocket connection arrives.

# Request/response services: constructed with a request queue and a response
# queue taken from the flow's interface definition.
request_response_dispatchers = {
    "agent": AgentRequestor,
    "text-completion": TextCompletionRequestor,
    "prompt": PromptRequestor,
    "graph-rag": GraphRagRequestor,
    "document-rag": DocumentRagRequestor,
    "embeddings": EmbeddingsRequestor,
    "graph-embeddings": GraphEmbeddingsQueryRequestor,
    "triples-query": TriplesQueryRequestor,
}

# Fire-and-forget senders: constructed with a single queue.
sender_dispatchers = {
    "text-load": TextLoad,
    "document-load": DocumentLoad,
}

# Websocket exporters, selected by DispatcherManager.invoke_export.
export_dispatchers = {
    "triples": TriplesExport,
    "graph-embeddings": GraphEmbeddingsExport,
    "document-embeddings": DocumentEmbeddingsExport,
}

# Websocket importers, selected by DispatcherManager.invoke_import.
import_dispatchers = {
    "triples": TriplesImport,
    "graph-embeddings": GraphEmbeddingsImport,
    "document-embeddings": DocumentEmbeddingsImport,
}
|
||||||
|
|
||||||
|
class DispatcherWrapper:
    """Bind a service name and implementation class to a manager.

    The wrapper exposes the two-argument ``process(data, responder)`` call
    and forwards it to the manager's ``process_impl`` together with the name
    and implementation it was created with.
    """

    def __init__(self, mgr, name, impl):
        self.mgr, self.name, self.impl = mgr, name, impl

    async def process(self, data, responder):
        """Forward the request to the manager and return its result."""
        manager = self.mgr
        result = await manager.process_impl(
            data, responder, self.name, self.impl
        )
        return result
|
||||||
|
|
||||||
|
class DispatcherManager:
|
||||||
|
|
||||||
|
def __init__(self, pulsar_client, config_receiver):
|
||||||
|
self.pulsar_client = pulsar_client
|
||||||
|
self.config_receiver = config_receiver
|
||||||
|
self.config_receiver.add_handler(self)
|
||||||
|
|
||||||
|
self.flows = {}
|
||||||
|
self.dispatchers = {}
|
||||||
|
|
||||||
|
async def start_flow(self, id, flow):
|
||||||
|
print("Start flow", id)
|
||||||
|
self.flows[id] = flow
|
||||||
|
return
|
||||||
|
|
||||||
|
async def stop_flow(self, id, flow):
|
||||||
|
print("Stop flow", id)
|
||||||
|
del self.flows[id]
|
||||||
|
return
|
||||||
|
|
||||||
|
def dispatch_config(self):
|
||||||
|
return DispatcherWrapper(self, "config", ConfigRequestor)
|
||||||
|
|
||||||
|
def dispatch_flow(self):
|
||||||
|
return DispatcherWrapper(self, "flow", FlowRequestor)
|
||||||
|
|
||||||
|
def dispatch_librarian(self):
|
||||||
|
return DispatcherWrapper(self, "librarian", LibrarianRequestor)
|
||||||
|
|
||||||
|
async def process_impl(self, data, responder, name, impl):
|
||||||
|
|
||||||
|
key = (None, name)
|
||||||
|
|
||||||
|
if key in self.dispatchers:
|
||||||
|
return await self.dispatchers[key].process(data, responder)
|
||||||
|
|
||||||
|
dispatcher = impl(
|
||||||
|
pulsar_client = self.pulsar_client
|
||||||
|
)
|
||||||
|
|
||||||
|
await dispatcher.start()
|
||||||
|
|
||||||
|
self.dispatchers[key] = dispatcher
|
||||||
|
|
||||||
|
return await dispatcher.process(data, responder)
|
||||||
|
|
||||||
|
def dispatch_service(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def dispatch_import(self):
|
||||||
|
return self.invoke_import
|
||||||
|
|
||||||
|
def dispatch_export(self):
|
||||||
|
return self.invoke_export
|
||||||
|
|
||||||
|
async def invoke_import(self, ws, running, params):
|
||||||
|
|
||||||
|
flow = params.get("flow")
|
||||||
|
kind = params.get("kind")
|
||||||
|
|
||||||
|
if flow not in self.flows:
|
||||||
|
raise RuntimeError("Invalid flow")
|
||||||
|
|
||||||
|
if kind not in import_dispatchers:
|
||||||
|
raise RuntimeError("Invalid kind")
|
||||||
|
|
||||||
|
key = (flow, kind)
|
||||||
|
|
||||||
|
intf_defs = self.flows[flow]["interfaces"]
|
||||||
|
|
||||||
|
if kind not in intf_defs:
|
||||||
|
raise RuntimeError("This kind not supported by flow")
|
||||||
|
|
||||||
|
# FIXME: The -store bit, does it make sense?
|
||||||
|
qconfig = intf_defs[kind + "-store"]
|
||||||
|
|
||||||
|
id = str(uuid.uuid4())
|
||||||
|
dispatcher = import_dispatchers[kind](
|
||||||
|
pulsar_client = self.pulsar_client,
|
||||||
|
ws = ws,
|
||||||
|
running = running,
|
||||||
|
queue = qconfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
return dispatcher
|
||||||
|
|
||||||
|
async def invoke_export(self, ws, running, params):
|
||||||
|
|
||||||
|
flow = params.get("flow")
|
||||||
|
kind = params.get("kind")
|
||||||
|
|
||||||
|
if flow not in self.flows:
|
||||||
|
raise RuntimeError("Invalid flow")
|
||||||
|
|
||||||
|
if kind not in export_dispatchers:
|
||||||
|
raise RuntimeError("Invalid kind")
|
||||||
|
|
||||||
|
key = (flow, kind)
|
||||||
|
|
||||||
|
intf_defs = self.flows[flow]["interfaces"]
|
||||||
|
|
||||||
|
if kind not in intf_defs:
|
||||||
|
raise RuntimeError("This kind not supported by flow")
|
||||||
|
|
||||||
|
# FIXME: The -store bit, does it make sense?
|
||||||
|
qconfig = intf_defs[kind + "-store"]
|
||||||
|
|
||||||
|
id = str(uuid.uuid4())
|
||||||
|
dispatcher = export_dispatchers[kind](
|
||||||
|
pulsar_client = self.pulsar_client,
|
||||||
|
ws = ws,
|
||||||
|
running = running,
|
||||||
|
queue = qconfig,
|
||||||
|
consumer = f"api-gateway-{id}",
|
||||||
|
subscriber = f"api-gateway-{id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
return dispatcher
|
||||||
|
|
||||||
|
async def process(self, data, responder, params):
|
||||||
|
|
||||||
|
flow = params.get("flow")
|
||||||
|
kind = params.get("kind")
|
||||||
|
|
||||||
|
if flow not in self.flows:
|
||||||
|
raise RuntimeError("Invalid flow")
|
||||||
|
|
||||||
|
key = (flow, kind)
|
||||||
|
|
||||||
|
if key in self.dispatchers:
|
||||||
|
return await self.dispatchers[key].process(data, responder)
|
||||||
|
|
||||||
|
intf_defs = self.flows[flow]["interfaces"]
|
||||||
|
|
||||||
|
if kind not in intf_defs:
|
||||||
|
raise RuntimeError("This kind not supported by flow")
|
||||||
|
|
||||||
|
qconfig = intf_defs[kind]
|
||||||
|
|
||||||
|
if kind in request_response_dispatchers:
|
||||||
|
dispatcher = request_response_dispatchers[kind](
|
||||||
|
pulsar_client = self.pulsar_client,
|
||||||
|
request_queue = qconfig["request"],
|
||||||
|
response_queue = qconfig["response"],
|
||||||
|
timeout = 120,
|
||||||
|
consumer = f"api-gateway-{flow}-{kind}-request",
|
||||||
|
subscriber = f"api-gateway-{flow}-{kind}-request",
|
||||||
|
)
|
||||||
|
elif kind in sender_dispatchers:
|
||||||
|
dispatcher = sender_dispatchers[kind](
|
||||||
|
pulsar_client = self.pulsar_client,
|
||||||
|
queue = qconfig,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Invalid kind")
|
||||||
|
|
||||||
|
await dispatcher.start()
|
||||||
|
|
||||||
|
self.dispatchers[key] = dispatcher
|
||||||
|
|
||||||
|
return await dispatcher.process(data, responder)
|
||||||
|
|
||||||
|
|
@ -1,14 +1,13 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from .. schema import PromptRequest, PromptResponse
|
from ... schema import PromptRequest, PromptResponse
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
class PromptRequestor(ServiceRequestor):
|
class PromptRequestor(ServiceRequestor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
self, pulsar_client, request_queue, response_queue, timeout,
|
||||||
consumer, subscriber,
|
consumer, subscriber,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
@ -3,8 +3,8 @@ import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .. base import Publisher
|
from ... base import Publisher
|
||||||
from .. base import Subscriber
|
from ... base import Subscriber
|
||||||
|
|
||||||
logger = logging.getLogger("requestor")
|
logger = logging.getLogger("requestor")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
@ -33,13 +33,17 @@ class ServiceRequestor:
|
||||||
|
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
await self.pub.start()
|
self.running = True
|
||||||
await self.sub.start()
|
await self.sub.start()
|
||||||
|
await self.pub.start()
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
await self.pub.stop()
|
await self.pub.stop()
|
||||||
await self.sub.stop()
|
await self.sub.stop()
|
||||||
|
self.running = False
|
||||||
|
|
||||||
def to_request(self, request):
|
def to_request(self, request):
|
||||||
raise RuntimeError("Not defined")
|
raise RuntimeError("Not defined")
|
||||||
|
|
@ -57,13 +61,14 @@ class ServiceRequestor:
|
||||||
|
|
||||||
await self.pub.send(id, self.to_request(request))
|
await self.pub.send(id, self.to_request(request))
|
||||||
|
|
||||||
while True:
|
while self.running:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await asyncio.wait_for(
|
resp = await asyncio.wait_for(
|
||||||
q.get(), timeout=self.timeout
|
q.get(), timeout=self.timeout
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print("Exception", e)
|
||||||
raise RuntimeError("Timeout")
|
raise RuntimeError("Timeout")
|
||||||
|
|
||||||
if resp.error:
|
if resp.error:
|
||||||
|
|
@ -5,7 +5,7 @@ import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .. base import Publisher
|
from ... base import Publisher
|
||||||
|
|
||||||
logger = logging.getLogger("sender")
|
logger = logging.getLogger("sender")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
@ -15,18 +15,20 @@ class ServiceSender:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
pulsar_client,
|
pulsar_client,
|
||||||
request_queue, request_schema,
|
queue, schema,
|
||||||
):
|
):
|
||||||
|
|
||||||
self.pub = Publisher(
|
self.pub = Publisher(
|
||||||
pulsar_client, request_queue,
|
pulsar_client, queue,
|
||||||
schema=request_schema,
|
schema=schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
|
||||||
await self.pub.start()
|
await self.pub.start()
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
await self.pub.stop()
|
||||||
|
|
||||||
def to_request(self, request):
|
def to_request(self, request):
|
||||||
raise RuntimeError("Not defined")
|
raise RuntimeError("Not defined")
|
||||||
|
|
||||||
|
|
@ -39,6 +41,8 @@ class ServiceSender:
|
||||||
if responder:
|
if responder:
|
||||||
await responder({}, True)
|
await responder({}, True)
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
logging.error(f"Exception: {e}")
|
logging.error(f"Exception: {e}")
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
from .. schema import Value, Triple, DocumentPackage, DocumentInfo
|
from ... schema import Value, Triple, DocumentPackage, DocumentInfo
|
||||||
|
|
||||||
def to_value(x):
|
def to_value(x):
|
||||||
return Value(value=x["v"], is_uri=x["e"])
|
return Value(value=x["v"], is_uri=x["e"])
|
||||||
99
trustgraph-flow/trustgraph/gateway/dispatch/streamer.py
Normal file
99
trustgraph-flow/trustgraph/gateway/dispatch/streamer.py
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from ... base import Publisher
|
||||||
|
from ... base import Subscriber
|
||||||
|
|
||||||
|
logger = logging.getLogger("requestor")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
class ServiceRequestor:
    """Stream messages from a queue subscription to an async handler.

    A background task reads from the subscription, converts each message
    with ``from_inbound`` (which subclasses must override) and passes the
    result to the handler as ``handler(payload, fin)``.
    """

    def __init__(
            self,
            pulsar_client,
            queue, schema,
            handler,
            subscription="api-gateway", consumer_name="api-gateway",
            timeout=600,
    ):

        self.sub = Subscriber(
            pulsar_client, queue,
            subscription, consumer_name,
            schema
        )

        # Seconds to wait for each message before giving up
        self.timeout = timeout

        self.running = True

        # Async callable invoked as receiver(payload, fin)
        self.receiver = handler

    async def start(self):
        """Start the subscriber and launch the streaming task."""
        await self.sub.start()

        # Set the flag before the task is created so the stream loop can
        # never observe a stale False left by an earlier stop().
        self.running = True
        self.streamer = asyncio.create_task(self.stream())

    async def stop(self):
        """Stop the subscriber and ask the streaming loop to finish."""
        await self.sub.stop()
        self.running = False

    def from_inbound(self, response):
        """Convert an inbound message to a ``(payload, fin)`` pair."""
        raise RuntimeError("Not defined")

    async def stream(self):
        """Deliver messages to the handler until finished or stopped.

        Returns None on normal completion.  On failure a gateway-error
        payload is delivered to the handler with fin=True and returned.
        """

        id = str(uuid.uuid4())

        try:

            q = await self.sub.subscribe(id)

            while self.running:

                try:
                    resp = await asyncio.wait_for(
                        q.get(), timeout=self.timeout
                    )
                except asyncio.TimeoutError as e:
                    # Only a real timeout is reported as one; any other
                    # failure keeps its own message in the handler below.
                    raise RuntimeError("Timeout") from e

                if resp.error:

                    err = { "error": {
                        "type": resp.error.type,
                        "message": resp.error.message,
                    } }

                    # A service-level error does not end the stream
                    fin = False

                    await self.receiver(err, fin)

                else:

                    resp, fin = self.from_inbound(resp)

                    await self.receiver(resp, fin)

                if fin: break

        except Exception as e:

            logging.error(f"Exception: {e}")

            err = { "error": {
                "type": "gateway-error",
                "message": str(e),
            } }

            # There is no per-request responder in a streamer; failures go
            # to the same handler that receives the data.
            try:
                await self.receiver(err, True)
            except Exception as e2:
                logging.error(f"Exception: {e2}")

            return err

        finally:
            await self.sub.unsubscribe(id)
|
||||||
|
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
|
|
||||||
from .. schema import TextCompletionRequest, TextCompletionResponse
|
from ... schema import TextCompletionRequest, TextCompletionResponse
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
|
|
||||||
class TextCompletionRequestor(ServiceRequestor):
|
class TextCompletionRequestor(ServiceRequestor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
self, pulsar_client, request_queue, response_queue, timeout,
|
||||||
consumer, subscriber,
|
consumer, subscriber,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
@ -1,19 +1,18 @@
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
from .. schema import TextDocument, Metadata
|
from ... schema import TextDocument, Metadata
|
||||||
from .. schema import text_ingest_queue
|
|
||||||
|
|
||||||
from . sender import ServiceSender
|
from . sender import ServiceSender
|
||||||
from . serialize import to_subgraph
|
from . serialize import to_subgraph
|
||||||
|
|
||||||
class TextLoadSender(ServiceSender):
|
class TextLoad(ServiceSender):
|
||||||
def __init__(self, pulsar_client):
|
def __init__(self, pulsar_client, queue):
|
||||||
|
|
||||||
super(TextLoadSender, self).__init__(
|
super(TextLoad, self).__init__(
|
||||||
pulsar_client=pulsar_client,
|
pulsar_client = pulsar_client,
|
||||||
request_queue=text_ingest_queue,
|
queue = queue,
|
||||||
request_schema=TextDocument,
|
schema = TextDocument,
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_request(self, body):
|
def to_request(self, body):
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import queue
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from ... schema import Triples
|
||||||
|
from ... base import Subscriber
|
||||||
|
|
||||||
|
from . serialize import serialize_triples
|
||||||
|
|
||||||
|
class TriplesExport:
    """Forward Triples messages from a queue to a websocket."""

    def __init__(
            self, ws, running, pulsar_client, queue, consumer, subscriber
    ):

        self.ws = ws
        # Shared run flag; read with get() and cleared with stop()
        self.running = running
        self.pulsar_client = pulsar_client
        self.queue = queue
        self.consumer = consumer
        self.subscriber = subscriber

    async def destroy(self):
        """Clear the run flag and close the websocket."""
        self.running.stop()
        await self.ws.close()

    async def receive(self, msg):
        """Handle a message from the websocket."""
        # Ignore incoming info from websocket
        pass

    async def run(self):
        """Send serialized triples to the websocket until stopped.

        The loop polls the subscription with a short timeout so the run
        flag is re-checked regularly.  The subscription, the subscriber and
        the websocket are released even if this coroutine is cancelled or
        fails outside the send loop.
        """

        subs = Subscriber(
            client = self.pulsar_client, topic = self.queue,
            consumer_name = self.consumer, subscription = self.subscriber,
            schema = Triples
        )

        await subs.start()

        sub_id = str(uuid.uuid4())

        try:

            q = await subs.subscribe_all(sub_id)

            try:

                while self.running.get():

                    try:
                        resp = await asyncio.wait_for(q.get(), timeout=0.5)
                        await self.ws.send_json(serialize_triples(resp))

                    except (asyncio.TimeoutError, queue.Empty):
                        # Nothing arrived in this poll interval
                        continue

                    except Exception as e:
                        print(f"Exception: {str(e)}", flush=True)
                        break

            finally:
                await subs.unsubscribe_all(sub_id)

        finally:
            await subs.stop()
            await self.ws.close()
            self.running.stop()
|
||||||
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from aiohttp import WSMsgType
|
||||||
|
|
||||||
|
from ... schema import Metadata
|
||||||
|
from ... schema import Triples
|
||||||
|
from ... base import Publisher
|
||||||
|
|
||||||
|
from . serialize import to_subgraph
|
||||||
|
|
||||||
|
class TriplesImport:
    """Publish triples received on a websocket to a queue."""

    # Set once the publisher has been started, so start() is idempotent.
    _started = False

    def __init__(
            self, ws, running, pulsar_client, queue
    ):

        self.ws = ws
        # Shared run flag; read with get() and cleared with stop()
        self.running = running

        self.publisher = Publisher(
            pulsar_client, topic = queue, schema = Triples
        )

    async def start(self):
        """Start the publisher; later calls do nothing."""
        if self._started:
            return
        self._started = True
        await self.publisher.start()

    async def destroy(self):
        """Clear the run flag, close the websocket, stop the publisher."""
        self.running.stop()

        if self.ws:
            await self.ws.close()

        await self.publisher.stop()

    async def receive(self, msg):
        """Decode one websocket message as JSON and publish it as Triples.

        Raises KeyError if an expected field is missing from the message.
        """

        data = msg.json()

        elt = Triples(
            metadata=Metadata(
                id=data["metadata"]["id"],
                metadata=to_subgraph(data["metadata"]["metadata"]),
                user=data["metadata"]["user"],
                collection=data["metadata"]["collection"],
            ),
            triples=to_subgraph(data["triples"]),
        )

        await self.publisher.send(None, elt)

    async def run(self):
        """Keep the import alive until the run flag clears.

        The publisher is started here because nothing else starts it, while
        destroy() does stop it.
        """

        await self.start()

        while self.running.get():
            await asyncio.sleep(0.5)

        if self.ws:
            await self.ws.close()

        self.ws = None
|
||||||
|
|
||||||
|
|
@ -1,13 +1,12 @@
|
||||||
|
|
||||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Triples
|
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . requestor import ServiceRequestor
|
from . requestor import ServiceRequestor
|
||||||
from . serialize import to_value, serialize_subgraph
|
from . serialize import to_value, serialize_subgraph
|
||||||
|
|
||||||
class TriplesQueryRequestor(ServiceRequestor):
|
class TriplesQueryRequestor(ServiceRequestor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
self, pulsar_client, request_queue, response_queue, timeout,
|
||||||
consumer, subscriber,
|
consumer, subscriber,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|
@ -1,63 +0,0 @@
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import uuid
|
|
||||||
from aiohttp import WSMsgType
|
|
||||||
|
|
||||||
from .. schema import Metadata
|
|
||||||
from .. schema import DocumentEmbeddings, ChunkEmbeddings
|
|
||||||
from .. schema import document_embeddings_store_queue
|
|
||||||
from .. base import Publisher
|
|
||||||
|
|
||||||
from . socket import SocketEndpoint
|
|
||||||
from . serialize import to_subgraph
|
|
||||||
|
|
||||||
class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
    """Websocket endpoint which publishes received document embeddings."""

    def __init__(
            self, pulsar_client, auth, path="/api/v1/load/document-embeddings",
    ):

        super().__init__(endpoint_path=path, auth=auth)

        self.pulsar_client = pulsar_client

        self.publisher = Publisher(
            self.pulsar_client, document_embeddings_store_queue,
            schema=DocumentEmbeddings
        )

    async def start(self):
        """Start the publisher."""
        await self.publisher.start()

    async def listener(self, ws, running):
        """Publish each websocket message as a DocumentEmbeddings record.

        Reading ends when the socket iteration ends or an ERROR message is
        seen; the run flag is then cleared.
        """

        async for msg in ws:

            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break

            payload = msg.json()
            header = payload["metadata"]

            meta = Metadata(
                id=header["id"],
                metadata=to_subgraph(header["metadata"]),
                user=header["user"],
                collection=header["collection"],
            )

            chunks = []
            for de in payload["chunks"]:
                chunks.append(
                    ChunkEmbeddings(
                        chunk=de["chunk"].encode("utf-8"),
                        vectors=de["vectors"],
                    )
                )

            record = DocumentEmbeddings(metadata=meta, chunks=chunks)

            await self.publisher.send(None, record)

        running.stop()
|
|
||||||
|
|
@ -1,72 +0,0 @@
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import queue
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from .. schema import DocumentEmbeddings
|
|
||||||
from .. schema import document_embeddings_store_queue
|
|
||||||
from .. base import Subscriber
|
|
||||||
|
|
||||||
from . socket import SocketEndpoint
|
|
||||||
from . serialize import serialize_document_embeddings
|
|
||||||
|
|
||||||
class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
    """Websocket endpoint which streams document embeddings to clients."""

    def __init__(
            self, pulsar_client, auth,
            path="/api/v1/stream/document-embeddings"
    ):

        super().__init__(endpoint_path=path, auth=auth)

        self.pulsar_client = pulsar_client

        self.subscriber = Subscriber(
            self.pulsar_client, document_embeddings_store_queue,
            "api-gateway", "api-gateway",
            schema=DocumentEmbeddings,
        )

    async def listener(self, ws, running):
        """Run the sender task alongside the base class listener."""

        sender = asyncio.create_task(self.async_thread(ws, running))

        await super().listener(ws, running)

        await sender

    async def start(self):
        """Start the subscriber."""
        await self.subscriber.start()

    async def async_thread(self, ws, running):
        """Send serialized embeddings to the websocket until stopped."""

        sub_id = str(uuid.uuid4())

        inbound = await self.subscriber.subscribe_all(sub_id)

        while running.get():

            try:
                item = await asyncio.wait_for(inbound.get(), timeout=0.5)
                await ws.send_json(serialize_document_embeddings(item))

            except (TimeoutError, queue.Empty):
                # Nothing arrived in this poll interval; re-check the flag
                continue

            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break

        await self.subscriber.unsubscribe_all(sub_id)

        running.stop()
|
|
||||||
|
|
||||||
|
|
@ -7,19 +7,19 @@ import logging
|
||||||
logger = logging.getLogger("endpoint")
|
logger = logging.getLogger("endpoint")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
class ServiceEndpoint:
|
class ConstantEndpoint:
|
||||||
|
|
||||||
def __init__(self, endpoint_path, auth, requestor):
|
def __init__(self, endpoint_path, auth, dispatcher):
|
||||||
|
|
||||||
self.path = endpoint_path
|
self.path = endpoint_path
|
||||||
|
|
||||||
self.auth = auth
|
self.auth = auth
|
||||||
self.operation = "service"
|
self.operation = "service"
|
||||||
|
|
||||||
self.requestor = requestor
|
self.dispatcher = dispatcher
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
await self.requestor.start()
|
pass
|
||||||
|
|
||||||
def add_routes(self, app):
|
def add_routes(self, app):
|
||||||
|
|
||||||
|
|
@ -52,7 +52,7 @@ class ServiceEndpoint:
|
||||||
async def responder(x, fin):
|
async def responder(x, fin):
|
||||||
print(x)
|
print(x)
|
||||||
|
|
||||||
resp = await self.requestor.process(data, responder)
|
resp = await self.dispatcher.process(data, responder)
|
||||||
|
|
||||||
return web.json_response(resp)
|
return web.json_response(resp)
|
||||||
|
|
||||||
67
trustgraph-flow/trustgraph/gateway/endpoint/manager.py
Normal file
67
trustgraph-flow/trustgraph/gateway/endpoint/manager.py
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from . constant_endpoint import ConstantEndpoint
|
||||||
|
from . variable_endpoint import VariableEndpoint
|
||||||
|
from . socket import SocketEndpoint
|
||||||
|
from . metrics import MetricsEndpoint
|
||||||
|
|
||||||
|
from .. dispatch.manager import DispatcherManager
|
||||||
|
|
||||||
|
class EndpointManager:
    """Own the gateway's HTTP and websocket endpoints.

    The endpoints are built once in the constructor, each bound to a
    dispatcher obtained from the dispatcher manager.
    """

    def __init__(
            self, dispatcher_manager, auth, prometheus_url, timeout=600
    ):

        self.dispatcher_manager = dispatcher_manager
        self.timeout = timeout
        self.services = {}

        dm = dispatcher_manager

        # Fixed-path services which are not tied to a flow
        librarian = ConstantEndpoint(
            endpoint_path = "/api/v1/librarian", auth = auth,
            dispatcher = dm.dispatch_librarian(),
        )
        config = ConstantEndpoint(
            endpoint_path = "/api/v1/config", auth = auth,
            dispatcher = dm.dispatch_config(),
        )
        flow = ConstantEndpoint(
            endpoint_path = "/api/v1/flow", auth = auth,
            dispatcher = dm.dispatch_flow(),
        )

        metrics = MetricsEndpoint(
            endpoint_path = "/api/v1/metrics",
            prometheus_url = prometheus_url,
            auth = auth,
        )

        # Per-flow endpoints; {flow} and {kind} are taken from the URL
        service = VariableEndpoint(
            endpoint_path = "/api/v1/flow/{flow}/service/{kind}",
            auth = auth,
            dispatcher = dm.dispatch_service(),
        )
        importer = SocketEndpoint(
            endpoint_path = "/api/v1/flow/{flow}/import/{kind}",
            auth = auth,
            dispatcher = dm.dispatch_import()
        )
        exporter = SocketEndpoint(
            endpoint_path = "/api/v1/flow/{flow}/export/{kind}",
            auth = auth,
            dispatcher = dm.dispatch_export()
        )

        self.endpoints = [
            librarian, config, flow, metrics, service, importer, exporter,
        ]

    def add_routes(self, app):
        """Register the routes of every endpoint on the application."""
        for endpoint in self.endpoints:
            endpoint.add_routes(app)

    async def start(self):
        """Start every endpoint, one after another."""
        for endpoint in self.endpoints:
            await endpoint.start()
|
||||||
|
|
||||||
111
trustgraph-flow/trustgraph/gateway/endpoint/socket.py
Normal file
111
trustgraph-flow/trustgraph/gateway/endpoint/socket.py
Normal file
|
|
@ -0,0 +1,111 @@
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from aiohttp import web, WSMsgType
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from .. running import Running
|
||||||
|
|
||||||
|
logger = logging.getLogger("socket")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
class SocketEndpoint:
    """Websocket endpoint which hands each connection to a dispatcher.

    ``dispatcher`` is an async factory called as
    ``dispatcher(ws, running, match_info)``.  The object it returns must
    provide ``run()``, ``receive(msg)`` and ``destroy()``.
    """

    def __init__(
            self, endpoint_path, auth, dispatcher,
    ):

        self.path = endpoint_path
        self.auth = auth
        self.operation = "socket"

        self.dispatcher = dispatcher

        # Run flags of the connections currently being served, so stop()
        # has something to act on.
        self._active = []

    async def worker(self, ws, dispatcher, running):
        """Run the dispatcher's own processing loop."""
        await dispatcher.run()

    async def listener(self, ws, dispatcher, running):
        """Pass text and binary websocket messages to the dispatcher.

        Any other message type ends the loop; the run flag is then cleared
        and the socket closed.
        """

        async for msg in ws:
            if msg.type not in (WSMsgType.TEXT, WSMsgType.BINARY):
                break
            await dispatcher.receive(msg)

        running.stop()
        await ws.close()

    async def handle(self, request):
        """Serve one websocket connection.

        The token is read from the ``token`` query parameter; a missing
        token is treated as an empty string.  Returns an unauthorized
        response if the token is not permitted, otherwise the websocket
        response once the connection has finished.
        """

        token = request.query.get("token", "")

        if not self.auth.permitted(token, self.operation):
            return web.HTTPUnauthorized()

        # 50MB max message size
        ws = web.WebSocketResponse(max_msg_size=52428800)

        await ws.prepare(request)

        running = Running()
        self._active.append(running)

        dispatcher = None

        try:

            dispatcher = await self.dispatcher(
                ws, running, request.match_info
            )

            async with asyncio.TaskGroup() as tg:

                tg.create_task(
                    self.worker(ws, dispatcher, running)
                )

                tg.create_task(
                    self.listener(ws, dispatcher, running)
                )

                print("Created taskgroup, waiting...")

                # Leaving the block waits for both tasks to complete

            print("Task group closed")

        except ExceptionGroup as e:

            print("Exception group:", flush=True)

            for se in e.exceptions:
                print("  Type:", type(se), flush=True)
                print(f"  Exception: {se}", flush=True)

        except Exception as e:
            print("Socket exception:", e, flush=True)

        finally:

            # Compare by identity; Running may define its own equality
            self._active = [r for r in self._active if r is not running]

            # Release the dispatcher's resources on failure as well as on
            # success.
            if dispatcher is not None:
                try:
                    await dispatcher.destroy()
                except Exception as e:
                    print("Socket exception:", e, flush=True)

        await ws.close()

        return ws

    async def start(self):
        """Nothing to start; connections are served on demand."""
        pass

    async def stop(self):
        """Clear the run flag of every connection being served."""
        for running in list(self._active):
            running.stop()

    def add_routes(self, app):
        """Register the websocket handler on the application."""
        app.add_routes([
            web.get(self.path, self.handle),
        ])
|
||||||
|
|
||||||
|
|
@ -4,26 +4,25 @@ from aiohttp import web
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger("flow-endpoint")
|
logger = logging.getLogger("endpoint")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
class FlowEndpoint:
|
class VariableEndpoint:
|
||||||
|
|
||||||
def __init__(self, endpoint_path, auth, requestors):
|
def __init__(self, endpoint_path, auth, dispatcher):
|
||||||
|
|
||||||
self.path = endpoint_path
|
self.path = endpoint_path
|
||||||
|
|
||||||
self.auth = auth
|
self.auth = auth
|
||||||
self.operation = "service"
|
self.operation = "service"
|
||||||
|
|
||||||
self.requestors = requestors
|
self.dispatcher = dispatcher
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def add_routes(self, app):
|
def add_routes(self, app):
|
||||||
|
|
||||||
pass
|
|
||||||
app.add_routes([
|
app.add_routes([
|
||||||
web.post(self.path, self.handle),
|
web.post(self.path, self.handle),
|
||||||
])
|
])
|
||||||
|
|
@ -32,15 +31,6 @@ class FlowEndpoint:
|
||||||
|
|
||||||
print(request.path, "...")
|
print(request.path, "...")
|
||||||
|
|
||||||
flow_id = request.match_info['flow']
|
|
||||||
kind = request.match_info['kind']
|
|
||||||
k = (flow_id, kind)
|
|
||||||
|
|
||||||
if k not in self.requestors:
|
|
||||||
raise web.HTTPBadRequest()
|
|
||||||
|
|
||||||
requestor = self.requestors[k]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ht = request.headers["Authorization"]
|
ht = request.headers["Authorization"]
|
||||||
tokens = ht.split(" ", 2)
|
tokens = ht.split(" ", 2)
|
||||||
|
|
@ -62,7 +52,9 @@ class FlowEndpoint:
|
||||||
async def responder(x, fin):
|
async def responder(x, fin):
|
||||||
print(x)
|
print(x)
|
||||||
|
|
||||||
resp = await requestor.process(data, responder)
|
resp = await self.dispatcher.process(
|
||||||
|
data, responder, request.match_info
|
||||||
|
)
|
||||||
|
|
||||||
return web.json_response(resp)
|
return web.json_response(resp)
|
||||||
|
|
||||||
|
|
@ -1,64 +0,0 @@
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import uuid
|
|
||||||
from aiohttp import WSMsgType
|
|
||||||
|
|
||||||
from .. schema import Metadata
|
|
||||||
from .. schema import GraphEmbeddings, EntityEmbeddings
|
|
||||||
from .. schema import graph_embeddings_store_queue
|
|
||||||
from .. base import Publisher
|
|
||||||
|
|
||||||
from . socket import SocketEndpoint
|
|
||||||
from . serialize import to_subgraph, to_value
|
|
||||||
|
|
||||||
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
    """WebSocket endpoint that publishes incoming graph embeddings.

    Every message received on the socket is parsed as JSON, converted to a
    ``GraphEmbeddings`` record and sent through a ``Publisher`` bound to
    ``graph_embeddings_store_queue``.
    """

    def __init__(
            self, pulsar_client, auth, path="/api/v1/load/graph-embeddings",
    ):
        """Set up the socket route and the publisher for the store queue."""

        super().__init__(endpoint_path=path, auth=auth)

        self.pulsar_client = pulsar_client

        self.publisher = Publisher(
            self.pulsar_client, graph_embeddings_store_queue,
            schema=GraphEmbeddings
        )

    async def start(self):
        """Start the publisher."""

        await self.publisher.start()

    async def listener(self, ws, running):
        """Publish one ``GraphEmbeddings`` record per received message.

        The loop ends on an ERROR message or when the socket iterator is
        exhausted; the running flag is stopped afterwards.
        """

        async for msg in ws:

            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break

            data = msg.json()

            metadata = Metadata(
                id=data["metadata"]["id"],
                metadata=to_subgraph(data["metadata"]["metadata"]),
                user=data["metadata"]["user"],
                collection=data["metadata"]["collection"],
            )

            entities = [
                EntityEmbeddings(
                    entity=to_value(item["entity"]),
                    vectors=item["vectors"],
                )
                for item in data["entities"]
            ]

            record = GraphEmbeddings(metadata=metadata, entities=entities)

            await self.publisher.send(None, record)

        running.stop()
|
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import queue
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from .. schema import GraphEmbeddings
|
|
||||||
from .. schema import graph_embeddings_store_queue
|
|
||||||
from .. base import Subscriber
|
|
||||||
|
|
||||||
from . socket import SocketEndpoint
|
|
||||||
from . serialize import serialize_graph_embeddings
|
|
||||||
|
|
||||||
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
    """WebSocket endpoint that streams graph embeddings to the client.

    A ``Subscriber`` on ``graph_embeddings_store_queue`` feeds a per-socket
    queue; each item taken from it is serialised with
    ``serialize_graph_embeddings`` and sent to the socket as JSON.
    """

    def __init__(
            self, pulsar_client, auth, path="/api/v1/stream/graph-embeddings"
    ):
        """Set up the socket route and the subscriber for the store queue."""

        super(GraphEmbeddingsStreamEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )

        self.pulsar_client = pulsar_client

        self.subscriber = Subscriber(
            self.pulsar_client, graph_embeddings_store_queue,
            "api-gateway", "api-gateway",
            schema=GraphEmbeddings
        )

    async def listener(self, ws, running):
        """Run the sender task alongside the base-class listener.

        The sender is started first, then the inherited listener consumes
        incoming messages until the socket ends; finally the sender task is
        awaited.
        """

        worker = asyncio.create_task(
            self.async_thread(ws, running)
        )

        await super(GraphEmbeddingsStreamEndpoint, self).listener(ws, running)

        await worker

    async def start(self):
        """Start the subscriber."""

        await self.subscriber.start()

    async def async_thread(self, ws, running):
        """Forward queued records to the socket while ``running`` is set.

        Subscribes under a fresh UUID, polls the queue with a 0.5 second
        timeout so the running flag is re-checked regularly, and always
        unsubscribes and stops the flag on the way out.
        """

        sub_id = str(uuid.uuid4())

        q = await self.subscriber.subscribe_all(sub_id)

        try:

            while running.get():

                try:

                    # q.get must be *called*: wait_for needs an awaitable,
                    # not the bound method.  Passing the method made every
                    # iteration fail, so nothing was ever streamed.
                    # NOTE(review): assumes q is an asyncio-style queue
                    # whose get() returns an awaitable - confirm against
                    # Subscriber.subscribe_all.
                    resp = await asyncio.wait_for(q.get(), timeout=0.5)
                    await ws.send_json(serialize_graph_embeddings(resp))

                except TimeoutError:
                    continue

                except queue.Empty:
                    continue

                except Exception as e:
                    print(f"Exception: {str(e)}", flush=True)
                    break

        finally:

            # Runs on cancellation as well, so the subscription is not
            # left registered after the socket has gone away.
            await self.subscriber.unsubscribe_all(sub_id)

            running.stop()
|
|
||||||
|
|
||||||
|
|
@ -3,63 +3,22 @@ API gateway. Offers HTTP services which are translated to interaction on the
|
||||||
Pulsar bus.
|
Pulsar bus.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
module = "api-gateway"
|
|
||||||
|
|
||||||
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
|
|
||||||
# are active listeners
|
|
||||||
|
|
||||||
# FIXME: Connection errors in publishers / subscribers cause those threads
|
|
||||||
# to fail and are not failed or retried
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import argparse
|
import argparse
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import base64
|
|
||||||
import uuid
|
|
||||||
import json
|
|
||||||
|
|
||||||
import pulsar
|
|
||||||
from prometheus_client import start_http_server
|
|
||||||
|
|
||||||
from .. log_level import LogLevel
|
from .. log_level import LogLevel
|
||||||
|
|
||||||
from . serialize import to_subgraph
|
|
||||||
from . running import Running
|
|
||||||
|
|
||||||
from .. schema import ConfigPush, config_push_queue
|
|
||||||
|
|
||||||
from . text_completion import TextCompletionRequestor
|
|
||||||
from . prompt import PromptRequestor
|
|
||||||
from . graph_rag import GraphRagRequestor
|
|
||||||
#from . document_rag import DocumentRagRequestor
|
|
||||||
from . triples_query import TriplesQueryRequestor
|
|
||||||
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
|
||||||
from . embeddings import EmbeddingsRequestor
|
|
||||||
#from . encyclopedia import EncyclopediaRequestor
|
|
||||||
from . agent import AgentRequestor
|
|
||||||
#from . dbpedia import DbpediaRequestor
|
|
||||||
#from . internet_search import InternetSearchRequestor
|
|
||||||
#from . librarian import LibrarianRequestor
|
|
||||||
from . config import ConfigRequestor
|
|
||||||
from . flow import FlowRequestor
|
|
||||||
#from . triples_stream import TriplesStreamEndpoint
|
|
||||||
#from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
|
|
||||||
#from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint
|
|
||||||
#from . triples_load import TriplesLoadEndpoint
|
|
||||||
#from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
|
|
||||||
#from . document_embeddings_load import DocumentEmbeddingsLoadEndpoint
|
|
||||||
from . mux import MuxEndpoint
|
|
||||||
#from . document_load import DocumentLoadSender
|
|
||||||
#from . text_load import TextLoadSender
|
|
||||||
from . metrics import MetricsEndpoint
|
|
||||||
|
|
||||||
from . endpoint import ServiceEndpoint
|
|
||||||
from . flow_endpoint import FlowEndpoint
|
|
||||||
from . auth import Authenticator
|
from . auth import Authenticator
|
||||||
from .. base import Subscriber
|
from . config.receiver import ConfigReceiver
|
||||||
from .. base import Consumer
|
from . dispatch.manager import DispatcherManager
|
||||||
|
|
||||||
|
from . endpoint.manager import EndpointManager
|
||||||
|
|
||||||
|
import pulsar
|
||||||
|
from prometheus_client import start_http_server
|
||||||
|
|
||||||
logger = logging.getLogger("api")
|
logger = logging.getLogger("api")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
@ -81,6 +40,7 @@ class Api:
|
||||||
self.pulsar_api_key = config.get(
|
self.pulsar_api_key = config.get(
|
||||||
"pulsar_api_key", default_pulsar_api_key
|
"pulsar_api_key", default_pulsar_api_key
|
||||||
)
|
)
|
||||||
|
|
||||||
self.pulsar_listener = config.get("pulsar_listener", None)
|
self.pulsar_listener = config.get("pulsar_listener", None)
|
||||||
|
|
||||||
if self.pulsar_api_key:
|
if self.pulsar_api_key:
|
||||||
|
|
@ -108,278 +68,24 @@ class Api:
|
||||||
else:
|
else:
|
||||||
self.auth = Authenticator(allow_all=True)
|
self.auth = Authenticator(allow_all=True)
|
||||||
|
|
||||||
self.services = {
|
self.config_receiver = ConfigReceiver(self.pulsar_client)
|
||||||
# "text-completion": TextCompletionRequestor(
|
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
self.dispatcher_manager = DispatcherManager(
|
||||||
# auth = self.auth,
|
pulsar_client = self.pulsar_client,
|
||||||
# ),
|
config_receiver = self.config_receiver,
|
||||||
# "prompt": PromptRequestor(
|
)
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
# auth = self.auth,
|
self.endpoint_manager = EndpointManager(
|
||||||
# ),
|
dispatcher_manager = self.dispatcher_manager,
|
||||||
# "graph-rag": GraphRagRequestor(
|
auth = self.auth,
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
prometheus_url = self.prometheus_url,
|
||||||
# auth = self.auth,
|
timeout = self.timeout,
|
||||||
# ),
|
|
||||||
# "document-rag": DocumentRagRequestor(
|
)
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# "triples-query": TriplesQueryRequestor(
|
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# "graph-embeddings-query": GraphEmbeddingsQueryRequestor(
|
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# "embeddings": EmbeddingsRequestor(
|
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# "agent": AgentRequestor(
|
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# "librarian": LibrarianRequestor(
|
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
(None, "config"): ConfigRequestor(
|
|
||||||
pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
auth = self.auth,
|
|
||||||
),
|
|
||||||
(None, "flow"): FlowRequestor(
|
|
||||||
pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
auth = self.auth,
|
|
||||||
),
|
|
||||||
# "encyclopedia": EncyclopediaRequestor(
|
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# "dbpedia": DbpediaRequestor(
|
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# "internet-search": InternetSearchRequestor(
|
|
||||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# "document-load": DocumentLoadSender(
|
|
||||||
# pulsar_client=self.pulsar_client,
|
|
||||||
# ),
|
|
||||||
# "text-load": TextLoadSender(
|
|
||||||
# pulsar_client=self.pulsar_client,
|
|
||||||
# ),
|
|
||||||
}
|
|
||||||
|
|
||||||
self.endpoints = [
|
self.endpoints = [
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/text-completion", auth=self.auth,
|
|
||||||
# requestor = self.services["text-completion"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/prompt", auth=self.auth,
|
|
||||||
# requestor = self.services["prompt"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/graph-rag", auth=self.auth,
|
|
||||||
# requestor = self.services["graph-rag"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/document-rag", auth=self.auth,
|
|
||||||
# requestor = self.services["document-rag"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/triples-query", auth=self.auth,
|
|
||||||
# requestor = self.services["triples-query"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/graph-embeddings-query",
|
|
||||||
# auth=self.auth,
|
|
||||||
# requestor = self.services["graph-embeddings-query"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/embeddings", auth=self.auth,
|
|
||||||
# requestor = self.services["embeddings"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/agent", auth=self.auth,
|
|
||||||
# requestor = self.services["agent"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/librarian", auth=self.auth,
|
|
||||||
# requestor = self.services["librarian"],
|
|
||||||
# ),
|
|
||||||
ServiceEndpoint(
|
|
||||||
endpoint_path = "/api/v1/config", auth=self.auth,
|
|
||||||
requestor = self.services[(None, "config")],
|
|
||||||
),
|
|
||||||
ServiceEndpoint(
|
|
||||||
endpoint_path = "/api/v1/flow", auth=self.auth,
|
|
||||||
requestor = self.services[(None, "flow")],
|
|
||||||
),
|
|
||||||
FlowEndpoint(
|
|
||||||
endpoint_path = "/api/v1/flow/{flow}/{kind}",
|
|
||||||
auth=self.auth,
|
|
||||||
requestors = self.services,
|
|
||||||
),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
|
|
||||||
# requestor = self.services["encyclopedia"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/dbpedia", auth=self.auth,
|
|
||||||
# requestor = self.services["dbpedia"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/internet-search", auth=self.auth,
|
|
||||||
# requestor = self.services["internet-search"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/load/document", auth=self.auth,
|
|
||||||
# requestor = self.services["document-load"],
|
|
||||||
# ),
|
|
||||||
# ServiceEndpoint(
|
|
||||||
# endpoint_path = "/api/v1/load/text", auth=self.auth,
|
|
||||||
# requestor = self.services["text-load"],
|
|
||||||
# ),
|
|
||||||
# TriplesStreamEndpoint(
|
|
||||||
# pulsar_client=self.pulsar_client,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# GraphEmbeddingsStreamEndpoint(
|
|
||||||
# pulsar_client=self.pulsar_client,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# DocumentEmbeddingsStreamEndpoint(
|
|
||||||
# pulsar_client=self.pulsar_client,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# TriplesLoadEndpoint(
|
|
||||||
# pulsar_client=self.pulsar_client,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# GraphEmbeddingsLoadEndpoint(
|
|
||||||
# pulsar_client=self.pulsar_client,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# DocumentEmbeddingsLoadEndpoint(
|
|
||||||
# pulsar_client=self.pulsar_client,
|
|
||||||
# auth = self.auth,
|
|
||||||
# ),
|
|
||||||
# MuxEndpoint(
|
|
||||||
# pulsar_client=self.pulsar_client,
|
|
||||||
# auth = self.auth,
|
|
||||||
# services = self.services,
|
|
||||||
# ),
|
|
||||||
MetricsEndpoint(
|
|
||||||
endpoint_path = "/api/v1/metrics",
|
|
||||||
prometheus_url = self.prometheus_url,
|
|
||||||
auth = self.auth,
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
self.flows = {}
|
|
||||||
|
|
||||||
async def on_config(self, msg, proc, flow):
    """React to a pushed configuration by starting and stopping flows.

    Flows named in the pushed ``flows`` section but not yet known are
    JSON-decoded, recorded in ``self.flows`` and started; flows known
    locally but absent from the push are stopped and forgotten.  Starts
    are processed before stops.  Any exception is printed and swallowed.
    """

    try:

        value = msg.value()

        print(f"Config version", value.version)

        if "flows" not in value.config:
            return

        flows = value.config["flows"]

        # Snapshot both key sets before mutating self.flows
        wanted = list(flows.keys())
        current = list(self.flows.keys())

        to_start = [name for name in wanted if name not in current]
        to_stop = [name for name in current if name not in wanted]

        for name in to_start:
            self.flows[name] = json.loads(flows[name])
            await self.start_flow(name, self.flows[name])

        for name in to_stop:
            await self.stop_flow(name, self.flows[name])
            del self.flows[name]

    except Exception as e:
        print(f"Exception: {e}", flush=True)
|
|
||||||
|
|
||||||
async def start_flow(self, id, flow):
    """Create and start the requestors for one flow.

    For each supported API kind that appears in ``flow["interfaces"]``, a
    requestor is built from that interface's ``request`` / ``response``
    queues, stored in ``self.services`` under ``(id, kind)`` and started.
    A requestor already stored under the same key is stopped and removed
    first.
    """

    print("Start flow", id)
    intf = flow["interfaces"]

    # API kind name -> requestor class serving it
    kinds = {
        "agent": AgentRequestor,
        "text-completion": TextCompletionRequestor,
        "prompt": PromptRequestor,
        "graph-rag": GraphRagRequestor,
        "embeddings": EmbeddingsRequestor,
        "graph-embeddings": GraphEmbeddingsQueryRequestor,
        "triples-query": TriplesQueryRequestor,
    }

    for api_kind, requestor in kinds.items():

        if api_kind not in intf:
            continue

        key = (id, api_kind)

        if key in self.services:
            await self.services[key].stop()
            del self.services[key]

        # The same name is used for both the consumer and the subscriber
        name = f"api-gateway-{id}-{api_kind}-request"

        self.services[key] = requestor(
            pulsar_client=self.pulsar_client, timeout=self.timeout,
            request_queue = intf[api_kind]["request"],
            response_queue = intf[api_kind]["response"],
            consumer = name,
            subscriber = name,
            auth = self.auth,
        )

        await self.services[key].start()
|
|
||||||
|
|
||||||
async def stop_flow(self, id, flow):
    """Stop and remove every service registered for flow ``id``.

    ``self.services`` is keyed by ``(flow id, kind)``; each entry whose
    flow id equals ``id`` is stopped and then deleted.
    """

    print("Stop flow", id)

    # Not used below, but the lookup is kept so that a flow without this
    # key still fails here exactly as before.
    intf = flow["interfaces"]

    # Iterate over a snapshot: entries are deleted while looping
    for key in tuple(self.services.keys()):

        owner, _kind = key

        if owner != id:
            continue

        await self.services[key].stop()
        del self.services[key]
|
|
||||||
|
|
||||||
async def config_loader(self):
    """Consume configuration pushes and feed them to ``on_config``.

    A ``Consumer`` on ``config_push_queue`` is created inside a task group
    with a unique ``gateway-<uuid>`` subscriber name, reading from the
    start of the topic, and started.  The method returns only once the
    task group has finished.
    """

    subscriber_name = f"gateway-{str(uuid.uuid4())}"

    async with asyncio.TaskGroup() as tg:

        consumer = Consumer(
            taskgroup = tg,
            flow = None,
            client = self.pulsar_client,
            subscriber = subscriber_name,
            topic = config_push_queue,
            schema = ConfigPush,
            handler = self.on_config,
            start_of_messages = True,
        )
        self.config_cons = consumer

        await self.config_cons.start()

        print("Waiting...")

    print("Config consumer done. :/")
|
|
||||||
|
|
||||||
async def app_factory(self):
|
async def app_factory(self):
|
||||||
|
|
||||||
self.app = web.Application(
|
self.app = web.Application(
|
||||||
|
|
@ -387,7 +93,8 @@ class Api:
|
||||||
client_max_size=256 * 1024 * 1024
|
client_max_size=256 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
asyncio.create_task(self.config_loader())
|
await self.config_receiver.start()
|
||||||
|
|
||||||
|
|
||||||
for ep in self.endpoints:
|
for ep in self.endpoints:
|
||||||
ep.add_routes(self.app)
|
ep.add_routes(self.app)
|
||||||
|
|
@ -395,6 +102,9 @@ class Api:
|
||||||
for ep in self.endpoints:
|
for ep in self.endpoints:
|
||||||
await ep.start()
|
await ep.start()
|
||||||
|
|
||||||
|
self.endpoint_manager.add_routes(self.app)
|
||||||
|
await self.endpoint_manager.start()
|
||||||
|
|
||||||
return self.app
|
return self.app
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
|
||||||
|
|
@ -1,72 +0,0 @@
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from aiohttp import web, WSMsgType
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from . running import Running
|
|
||||||
|
|
||||||
logger = logging.getLogger("socket")
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
class SocketEndpoint:
    """Base WebSocket endpoint: authenticate, then listen until close.

    The default listener discards incoming text and binary messages;
    subclasses override ``listener`` (and ``start``) to do real work.
    """

    def __init__(
            self, endpoint_path, auth,
    ):
        """Remember the route path and the authenticator."""

        self.path = endpoint_path
        self.auth = auth
        self.operation = "socket"

    async def listener(self, ws, running):
        """Drain the socket, ignoring text and binary messages.

        Any other message type ends the loop.  The running flag is stopped
        once the loop is over.
        """

        async for msg in ws:

            if msg.type in (WSMsgType.TEXT, WSMsgType.BINARY):
                # Ignore incoming message
                continue

            # On error (or any other message type), finish
            break

        running.stop()

    async def handle(self, request):
        """Serve one WebSocket connection.

        The ``token`` query parameter (empty string when absent) is checked
        with the authenticator; an unauthorised request raises
        ``web.HTTPUnauthorized``.  Otherwise the socket is prepared and the
        listener runs until it returns or fails.  Returns the WebSocket
        response.
        """

        # A missing token is the only expected failure, so use a default
        # instead of a bare except that would hide any other error.
        token = request.query.get("token", "")

        if not self.auth.permitted(token, self.operation):
            # Raise rather than return: returning an HTTP exception object
            # from a handler is deprecated in aiohttp.
            raise web.HTTPUnauthorized()

        running = Running()

        # 50MB max message size
        ws = web.WebSocketResponse(max_msg_size=52428800)

        await ws.prepare(request)

        try:
            await self.listener(ws, running)
        except Exception as e:
            print("Socket exception:", e, flush=True)
        finally:
            # Also reached on cancellation, so anything watching the flag
            # is told to finish when the handler goes away.
            running.stop()

        await ws.close()

        return ws

    async def start(self):
        """Nothing to start in the base class."""
        pass

    def add_routes(self, app):
        """Register the WebSocket handler as a GET route at ``self.path``."""

        app.add_routes([
            web.get(self.path, self.handle),
        ])
|
|
||||||
|
|
||||||
|
|
@ -1,56 +0,0 @@
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import uuid
|
|
||||||
from aiohttp import WSMsgType
|
|
||||||
|
|
||||||
from .. schema import Metadata
|
|
||||||
from .. schema import Triples
|
|
||||||
from .. schema import triples_store_queue
|
|
||||||
from .. base import Publisher
|
|
||||||
|
|
||||||
from . socket import SocketEndpoint
|
|
||||||
from . serialize import to_subgraph
|
|
||||||
|
|
||||||
class TriplesLoadEndpoint(SocketEndpoint):
    """WebSocket endpoint that publishes incoming triples.

    Every message received on the socket is parsed as JSON, converted to a
    ``Triples`` record and sent through a ``Publisher`` bound to
    ``triples_store_queue``.
    """

    def __init__(self, pulsar_client, auth, path="/api/v1/load/triples"):
        """Set up the socket route and the publisher for the store queue."""

        super().__init__(endpoint_path=path, auth=auth)

        self.pulsar_client = pulsar_client

        self.publisher = Publisher(
            self.pulsar_client, triples_store_queue,
            schema=Triples
        )

    async def start(self):
        """Start the publisher."""

        await self.publisher.start()

    async def listener(self, ws, running):
        """Publish one ``Triples`` record per received message.

        The loop ends on an ERROR message or when the socket iterator is
        exhausted; the running flag is stopped afterwards.
        """

        async for msg in ws:

            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break

            data = msg.json()

            metadata = Metadata(
                id=data["metadata"]["id"],
                metadata=to_subgraph(data["metadata"]["metadata"]),
                user=data["metadata"]["user"],
                collection=data["metadata"]["collection"],
            )

            record = Triples(
                metadata=metadata,
                triples=to_subgraph(data["triples"]),
            )

            await self.publisher.send(None, record)

        running.stop()
|
|
||||||
|
|
@ -1,67 +0,0 @@
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import queue
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from .. schema import Triples
|
|
||||||
from .. schema import triples_store_queue
|
|
||||||
from .. base import Subscriber
|
|
||||||
|
|
||||||
from . socket import SocketEndpoint
|
|
||||||
from . serialize import serialize_triples
|
|
||||||
|
|
||||||
class TriplesStreamEndpoint(SocketEndpoint):
    """WebSocket endpoint that streams triples to the client.

    A ``Subscriber`` on ``triples_store_queue`` feeds a per-socket queue;
    each item taken from it is serialised with ``serialize_triples`` and
    sent to the socket as JSON.
    """

    def __init__(self, pulsar_client, auth, path="/api/v1/stream/triples"):
        """Set up the socket route and the subscriber for the store queue."""

        super().__init__(endpoint_path=path, auth=auth)

        self.pulsar_client = pulsar_client

        self.subscriber = Subscriber(
            self.pulsar_client, triples_store_queue,
            "api-gateway", "api-gateway",
            schema=Triples
        )

    async def listener(self, ws, running):
        """Run the sender task alongside the base-class listener.

        The sender is started first, then the inherited listener consumes
        incoming messages until the socket ends; finally the sender task is
        awaited.
        """

        sender = asyncio.create_task(self.async_thread(ws, running))

        await super().listener(ws, running)

        await sender

    async def start(self):
        """Start the subscriber."""

        await self.subscriber.start()

    async def async_thread(self, ws, running):
        """Forward queued records to the socket while ``running`` is set.

        Subscribes under a fresh UUID and polls the queue in a worker
        thread with a 0.5 second timeout, so the running flag is re-checked
        regularly.  Any unexpected exception is printed and ends the loop;
        afterwards the subscription is removed and the flag stopped.
        """

        sub_id = str(uuid.uuid4())

        q = self.subscriber.subscribe_all(sub_id)

        while running.get():

            try:
                # The blocking get runs off the event loop
                item = await asyncio.to_thread(q.get, timeout=0.5)
                await ws.send_json(serialize_triples(item))

            except (TimeoutError, queue.Empty):
                # Nothing arrived in time; check the flag again
                continue

            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break

        self.subscriber.unsubscribe_all(sub_id)

        running.stop()
|
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue