mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Flow API - update gateway (#357)
* Altered API to incorporate Flow IDs, refactored for dynamic start/stop of flows * Gateway: Split endpoint / dispatcher for maintainability
This commit is contained in:
parent
450f664b1b
commit
a70ae9793a
52 changed files with 1206 additions and 907 deletions
33
test-api/test-llm2-api
Executable file
33
test-api/test-llm2-api
Executable file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
|
||||
url = "http://localhost:8088/api/v1/"
|
||||
|
||||
############################################################################
|
||||
|
||||
input = {
|
||||
"system": "",
|
||||
"prompt": "Add 2 and 3"
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}text-completion",
|
||||
json=input,
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError(f"Status code: {resp.status_code}")
|
||||
|
||||
resp = resp.json()
|
||||
|
||||
if "error" in resp:
|
||||
print(f"Error: {resp['error']}")
|
||||
sys.exit(1)
|
||||
|
||||
print(resp["response"])
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
@ -5,7 +5,7 @@ import json
|
|||
import sys
|
||||
import base64
|
||||
|
||||
url = "http://localhost:8088/api/v1/"
|
||||
url = "http://localhost:8088/api/v1/flow/0000/document-load"
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
@ -88,10 +88,7 @@ input = {
|
|||
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}load/document",
|
||||
json=input,
|
||||
)
|
||||
resp = requests.post(url, json=input)
|
||||
|
||||
resp = resp.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import json
|
|||
import sys
|
||||
import base64
|
||||
|
||||
url = "http://localhost:8088/api/v1/"
|
||||
url = "http://localhost:8088/api/v1/flow/0000/service/text-load"
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
@ -85,10 +85,7 @@ input = {
|
|||
|
||||
}
|
||||
|
||||
resp = requests.post(
|
||||
f"{url}load/text",
|
||||
json=input,
|
||||
)
|
||||
resp = requests.post(url, json=input)
|
||||
|
||||
resp = resp.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from trustgraph.schema import Document, Metadata
|
|||
client = pulsar.Client("pulsar://localhost:6650", listener_name="localhost")
|
||||
|
||||
prod = client.create_producer(
|
||||
topic="persistent://tg/flow/document-load:0002",
|
||||
topic="persistent://tg/flow/document-load:0000",
|
||||
schema=JsonSchema(Document),
|
||||
chunking_enabled=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ prod = client.create_producer(
|
|||
chunking_enabled=True,
|
||||
)
|
||||
|
||||
path = "docs/README.cats"
|
||||
path = "../trustgraph/docs/README.cats"
|
||||
|
||||
with open(path, "r") as f:
|
||||
# blob = base64.b64encode(f.read()).decode("utf-8")
|
||||
|
|
|
|||
|
|
@ -15,16 +15,21 @@ class Publisher:
|
|||
self.q = asyncio.Queue(maxsize=max_size)
|
||||
self.chunking_enabled = chunking_enabled
|
||||
self.running = True
|
||||
self.task = None
|
||||
|
||||
async def start(self):
|
||||
self.task = asyncio.create_task(self.run())
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
|
||||
if self.task:
|
||||
await self.task
|
||||
|
||||
async def join(self):
|
||||
await self.stop()
|
||||
|
||||
if self.task:
|
||||
await self.task
|
||||
|
||||
async def run(self):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class Subscriber:
|
|||
self.lock = asyncio.Lock()
|
||||
self.running = True
|
||||
self.metrics = metrics
|
||||
self.task = None
|
||||
|
||||
def __del__(self):
|
||||
self.running = False
|
||||
|
|
@ -28,10 +29,14 @@ class Subscriber:
|
|||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
|
||||
if self.task:
|
||||
await self.task
|
||||
|
||||
async def join(self):
|
||||
await self.stop()
|
||||
|
||||
if self.task:
|
||||
await self.task
|
||||
|
||||
async def run(self):
|
||||
|
|
@ -45,6 +50,8 @@ class Subscriber:
|
|||
|
||||
try:
|
||||
|
||||
# FIXME: Create consumer in start method so we know
|
||||
# it is definitely running when start completes
|
||||
consumer = self.client.subscribe(
|
||||
topic = self.topic,
|
||||
subscription_name = self.subscription,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
|
||||
from pulsar.schema import JsonSchema
|
||||
from .. schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from .. schema import embeddings_request_queue, embeddings_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
import _pulsar
|
||||
|
|
@ -23,12 +22,6 @@ class EmbeddingsClient(BaseClient):
|
|||
pulsar_api_key=None,
|
||||
):
|
||||
|
||||
if input_queue == None:
|
||||
input_queue=embeddings_request_queue
|
||||
|
||||
if output_queue == None:
|
||||
output_queue=embeddings_response_queue
|
||||
|
||||
super(EmbeddingsClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
|
|
@ -43,4 +36,3 @@ class EmbeddingsClient(BaseClient):
|
|||
def request(self, text, timeout=300):
|
||||
return self.call(text=text, timeout=timeout).vectors
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ async def load_ge(running, queue, url):
|
|||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
||||
async with session.ws_connect(f"{url}load/graph-embeddings") as ws:
|
||||
async with session.ws_connect(url) as ws:
|
||||
|
||||
while running.get():
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ async def load_triples(running, queue, url):
|
|||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
||||
async with session.ws_connect(f"{url}load/triples") as ws:
|
||||
async with session.ws_connect(url) as ws:
|
||||
|
||||
while running.get():
|
||||
|
||||
|
|
@ -200,6 +200,9 @@ async def run(running, **args):
|
|||
ge_q = asyncio.Queue(maxsize=10)
|
||||
t_q = asyncio.Queue(maxsize=10)
|
||||
|
||||
flow_id = args["flow_id"]
|
||||
url = args["url"]
|
||||
|
||||
load_task = asyncio.create_task(
|
||||
loader(
|
||||
running=running,
|
||||
|
|
@ -213,14 +216,16 @@ async def run(running, **args):
|
|||
ge_task = asyncio.create_task(
|
||||
load_ge(
|
||||
running = running,
|
||||
queue=ge_q, url=args["url"] + "api/v1/"
|
||||
queue = ge_q,
|
||||
url = f"{url}api/v1/flow/{flow_id}/import/graph-embeddings"
|
||||
)
|
||||
)
|
||||
|
||||
triples_task = asyncio.create_task(
|
||||
load_triples(
|
||||
running = running,
|
||||
queue=t_q, url=args["url"] + "api/v1/"
|
||||
queue = t_q,
|
||||
url = f"{url}api/v1/flow/{flow_id}/import/triples"
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -258,6 +263,12 @@ async def main(running):
|
|||
help=f'Output file'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-f', '--flow-id',
|
||||
default="0000",
|
||||
help=f'Flow ID (default: 0000)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--format',
|
||||
default="msgpack",
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ async def fetch_ge(running, queue, user, collection, url):
|
|||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
||||
async with session.ws_connect(f"{url}stream/graph-embeddings") as ws:
|
||||
async with session.ws_connect(url) as ws:
|
||||
|
||||
while running.get():
|
||||
|
||||
|
|
@ -74,7 +74,7 @@ async def fetch_triples(running, queue, user, collection, url):
|
|||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
||||
async with session.ws_connect(f"{url}stream/triples") as ws:
|
||||
async with session.ws_connect(url) as ws:
|
||||
|
||||
while running.get():
|
||||
|
||||
|
|
@ -160,11 +160,14 @@ async def run(running, **args):
|
|||
|
||||
q = asyncio.Queue()
|
||||
|
||||
flow_id = args["flow_id"]
|
||||
url = args["url"]
|
||||
|
||||
ge_task = asyncio.create_task(
|
||||
fetch_ge(
|
||||
running=running,
|
||||
queue=q, user=args["user"], collection=args["collection"],
|
||||
url=args["url"] + "api/v1/"
|
||||
url = f"{url}api/v1/flow/{flow_id}/export/graph-embeddings"
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -172,7 +175,7 @@ async def run(running, **args):
|
|||
fetch_triples(
|
||||
running=running, queue=q,
|
||||
user=args["user"], collection=args["collection"],
|
||||
url=args["url"] + "api/v1/"
|
||||
url = f"{url}api/v1/flow/{flow_id}/export/triples"
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -224,6 +227,12 @@ async def main(running):
|
|||
help=f'Output format (default: msgpack)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-f', '--flow-id',
|
||||
default="0000",
|
||||
help=f'Flow ID (default: 0000)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--user',
|
||||
help=f'User ID to filter on (default: no filter)'
|
||||
|
|
|
|||
121
trustgraph-flow/trustgraph/gateway/config/receiver.py
Executable file
121
trustgraph-flow/trustgraph/gateway/config/receiver.py
Executable file
|
|
@ -0,0 +1,121 @@
|
|||
"""
|
||||
API gateway. Offers HTTP services which are translated to interaction on the
|
||||
Pulsar bus.
|
||||
"""
|
||||
|
||||
module = "api-gateway"
|
||||
|
||||
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
|
||||
# are active listeners
|
||||
|
||||
# FIXME: Connection errors in publishers / subscribers cause those threads
|
||||
# to fail and are not failed or retried
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
from aiohttp import web
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
import uuid
|
||||
import json
|
||||
|
||||
import pulsar
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from ... schema import ConfigPush, config_push_queue
|
||||
from ... base import Consumer
|
||||
|
||||
logger = logging.getLogger("config.receiver")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class ConfigReceiver:
|
||||
|
||||
def __init__(self, pulsar_client):
|
||||
|
||||
self.pulsar_client = pulsar_client
|
||||
|
||||
self.flow_handlers = []
|
||||
|
||||
self.flows = {}
|
||||
|
||||
def add_handler(self, h):
|
||||
self.flow_handlers.append(h)
|
||||
|
||||
async def on_config(self, msg, proc, flow):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
print(f"Config version", v.version)
|
||||
|
||||
if "flows" in v.config:
|
||||
|
||||
flows = v.config["flows"]
|
||||
|
||||
wanted = list(flows.keys())
|
||||
current = list(self.flows.keys())
|
||||
|
||||
for k in wanted:
|
||||
if k not in current:
|
||||
self.flows[k] = json.loads(flows[k])
|
||||
await self.start_flow(k, self.flows[k])
|
||||
|
||||
for k in current:
|
||||
if k not in wanted:
|
||||
await self.stop_flow(k, self.flows[k])
|
||||
del self.flows[k]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}", flush=True)
|
||||
|
||||
async def start_flow(self, id, flow):
|
||||
|
||||
print("Start flow", id)
|
||||
|
||||
for handler in self.flow_handlers:
|
||||
|
||||
try:
|
||||
await handler.start_flow(id, flow)
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}", flush=True)
|
||||
|
||||
async def stop_flow(self, id, flow):
|
||||
|
||||
print("Stop flow", id)
|
||||
|
||||
for handler in self.flow_handlers:
|
||||
|
||||
try:
|
||||
await handler.stop_flow(id, flow)
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}", flush=True)
|
||||
|
||||
async def config_loader(self):
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
self.config_cons = Consumer(
|
||||
taskgroup = tg,
|
||||
flow = None,
|
||||
client = self.pulsar_client,
|
||||
subscriber = f"gateway-{id}",
|
||||
topic = config_push_queue,
|
||||
schema = ConfigPush,
|
||||
handler = self.on_config,
|
||||
start_of_messages = True,
|
||||
)
|
||||
|
||||
await self.config_cons.start()
|
||||
|
||||
print("Waiting...")
|
||||
|
||||
print("Config consumer done. :/")
|
||||
|
||||
async def start(self):
|
||||
|
||||
asyncio.create_task(self.config_loader())
|
||||
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
|
||||
from .. schema import AgentRequest, AgentResponse
|
||||
from ... schema import AgentRequest, AgentResponse
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class AgentRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
self, pulsar_client, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
|
|
@ -1,13 +1,12 @@
|
|||
|
||||
from .. schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
|
||||
from .. schema import config_request_queue
|
||||
from .. schema import config_response_queue
|
||||
from ... schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
|
||||
from ... schema import config_request_queue
|
||||
from ... schema import config_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class ConfigRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(self, pulsar_client, timeout=120):
|
||||
|
||||
super(ConfigRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import uuid
|
||||
|
||||
from ... schema import DocumentEmbeddings
|
||||
from ... base import Subscriber
|
||||
|
||||
from . serialize import serialize_document_embeddings
|
||||
|
||||
class DocumentEmbeddingsExport:
|
||||
|
||||
def __init__(
|
||||
self, ws, running, pulsar_client, queue, consumer, subscriber
|
||||
):
|
||||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
self.pulsar_client = pulsar_client
|
||||
self.queue = queue
|
||||
self.consumer = consumer
|
||||
self.subscriber = subscriber
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
await self.ws.close()
|
||||
|
||||
async def receive(self, msg):
|
||||
# Ignore incoming info from websocket
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
|
||||
subs = Subscriber(
|
||||
client = self.pulsar_client, topic = self.queue,
|
||||
consumer_name = self.consumer, subscription = self.subscriber,
|
||||
schema = DocumentEmbeddings
|
||||
)
|
||||
|
||||
await subs.start()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
q = await subs.subscribe_all(id)
|
||||
|
||||
while self.running.get():
|
||||
try:
|
||||
|
||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||
await self.ws.send_json(serialize_document_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
await subs.unsubscribe_all(id)
|
||||
|
||||
await subs.stop()
|
||||
|
||||
await self.ws.close()
|
||||
self.running.stop()
|
||||
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
from ... schema import DocumentEmbeddings, ChunkEmbeddings
|
||||
from ... base import Publisher
|
||||
|
||||
from . serialize import to_subgraph
|
||||
|
||||
class DocumentEmbeddingsImport:
|
||||
|
||||
def __init__(
|
||||
self, ws, running, pulsar_client, queue
|
||||
):
|
||||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
|
||||
self.publisher = Publisher(
|
||||
pulsar_client, topic = queue, schema = DocumentEmbeddings
|
||||
)
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
await self.publisher.stop()
|
||||
|
||||
async def receive(self, msg):
|
||||
|
||||
data = msg.json()
|
||||
|
||||
elt = DocumentEmbeddings(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
chunks=[
|
||||
ChunkEmbeddings(
|
||||
chunk=de["chunk"].encode("utf-8"),
|
||||
vectors=de["vectors"],
|
||||
)
|
||||
for de in data["chunks"]
|
||||
],
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
|
||||
async def run(self):
|
||||
|
||||
while self.running.get():
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
self.ws = None
|
||||
|
||||
|
|
@ -1,19 +1,18 @@
|
|||
|
||||
import base64
|
||||
|
||||
from .. schema import Document, Metadata
|
||||
from .. schema import document_ingest_queue
|
||||
from ... schema import Document, Metadata
|
||||
|
||||
from . sender import ServiceSender
|
||||
from . serialize import to_subgraph
|
||||
|
||||
class DocumentLoadSender(ServiceSender):
|
||||
def __init__(self, pulsar_client):
|
||||
class DocumentLoad(ServiceSender):
|
||||
def __init__(self, pulsar_client, queue):
|
||||
|
||||
super(DocumentLoadSender, self).__init__(
|
||||
super(DocumentLoad, self).__init__(
|
||||
pulsar_client = pulsar_client,
|
||||
request_queue=document_ingest_queue,
|
||||
request_schema=Document,
|
||||
queue = queue,
|
||||
schema = Document,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
@ -1,20 +1,22 @@
|
|||
|
||||
from .. schema import DocumentRagQuery, DocumentRagResponse
|
||||
from .. schema import document_rag_request_queue
|
||||
from .. schema import document_rag_response_queue
|
||||
from ... schema import DocumentRagQuery, DocumentRagResponse
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class DocumentRagRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(DocumentRagRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
request_queue=document_rag_request_queue,
|
||||
response_queue=document_rag_response_queue,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=DocumentRagQuery,
|
||||
response_schema=DocumentRagResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
|
||||
from .. schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class EmbeddingsRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
self, pulsar_client, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
|
|
@ -1,13 +1,12 @@
|
|||
|
||||
from .. schema import FlowRequest, FlowResponse, ConfigKey, ConfigValue
|
||||
from .. schema import flow_request_queue
|
||||
from .. schema import flow_response_queue
|
||||
from ... schema import FlowRequest, FlowResponse
|
||||
from ... schema import flow_request_queue
|
||||
from ... schema import flow_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class FlowRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(self, pulsar_client, timeout=120):
|
||||
|
||||
super(FlowRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import uuid
|
||||
|
||||
from ... schema import GraphEmbeddings
|
||||
from ... base import Subscriber
|
||||
|
||||
from . serialize import serialize_graph_embeddings
|
||||
|
||||
class GraphEmbeddingsExport:
|
||||
|
||||
def __init__(
|
||||
self, ws, running, pulsar_client, queue, consumer, subscriber
|
||||
):
|
||||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
self.pulsar_client = pulsar_client
|
||||
self.queue = queue
|
||||
self.consumer = consumer
|
||||
self.subscriber = subscriber
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
await self.ws.close()
|
||||
|
||||
async def receive(self, msg):
|
||||
# Ignore incoming info from websocket
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
|
||||
subs = Subscriber(
|
||||
client = self.pulsar_client, topic = self.queue,
|
||||
consumer_name = self.consumer, subscription = self.subscriber,
|
||||
schema = GraphEmbeddings
|
||||
)
|
||||
|
||||
await subs.start()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
q = await subs.subscribe_all(id)
|
||||
|
||||
while self.running.get():
|
||||
try:
|
||||
|
||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||
await self.ws.send_json(serialize_graph_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
await subs.unsubscribe_all(id)
|
||||
|
||||
await subs.stop()
|
||||
|
||||
await self.ws.close()
|
||||
self.running.stop()
|
||||
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
from ... schema import GraphEmbeddings, EntityEmbeddings
|
||||
from ... base import Publisher
|
||||
|
||||
from . serialize import to_subgraph, to_value
|
||||
|
||||
class GraphEmbeddingsImport:
|
||||
|
||||
def __init__(
|
||||
self, ws, running, pulsar_client, queue
|
||||
):
|
||||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
|
||||
self.publisher = Publisher(
|
||||
pulsar_client, topic = queue, schema = GraphEmbeddings
|
||||
)
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
await self.publisher.stop()
|
||||
|
||||
async def receive(self, msg):
|
||||
|
||||
data = msg.json()
|
||||
|
||||
elt = GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=to_value(ent["entity"]),
|
||||
vectors=ent["vectors"],
|
||||
)
|
||||
for ent in data["entities"]
|
||||
]
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
|
||||
async def run(self):
|
||||
|
||||
while self.running.get():
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
self.ws = None
|
||||
|
||||
|
|
@ -1,13 +1,12 @@
|
|||
|
||||
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from ... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
from . serialize import serialize_value
|
||||
|
||||
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
self, pulsar_client, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
|
||||
from .. schema import GraphRagQuery, GraphRagResponse
|
||||
from ... schema import GraphRagQuery, GraphRagResponse
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class GraphRagRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
self, pulsar_client, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
|
|
@ -1,15 +1,14 @@
|
|||
|
||||
from .. schema import LibrarianRequest, LibrarianResponse, Triples
|
||||
from .. schema import librarian_request_queue
|
||||
from .. schema import librarian_response_queue
|
||||
from ... schema import LibrarianRequest, LibrarianResponse
|
||||
from ... schema import librarian_request_queue
|
||||
from ... schema import librarian_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
from . serialize import serialize_document_package, serialize_document_info
|
||||
from . serialize import to_document_package, to_document_info, to_criteria
|
||||
|
||||
class LibrarianRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(self, pulsar_client, timeout=120):
|
||||
|
||||
super(LibrarianRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
|
|
@ -22,20 +21,16 @@ class LibrarianRequestor(ServiceRequestor):
|
|||
|
||||
def to_request(self, body):
|
||||
|
||||
print("TRR")
|
||||
if "document" in body:
|
||||
dp = to_document_package(body["document"])
|
||||
else:
|
||||
dp = None
|
||||
|
||||
print("GOT")
|
||||
if "criteria" in body:
|
||||
criteria = to_criteria(body["criteria"])
|
||||
else:
|
||||
criteria = None
|
||||
|
||||
print("ASLDKJ")
|
||||
|
||||
return LibrarianRequest(
|
||||
operation = body.get("operation", None),
|
||||
id = body.get("id", None),
|
||||
229
trustgraph-flow/trustgraph/gateway/dispatch/manager.py
Normal file
229
trustgraph-flow/trustgraph/gateway/dispatch/manager.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
from . config import ConfigRequestor
|
||||
from . flow import FlowRequestor
|
||||
from . librarian import LibrarianRequestor
|
||||
|
||||
from . embeddings import EmbeddingsRequestor
|
||||
from . agent import AgentRequestor
|
||||
from . text_completion import TextCompletionRequestor
|
||||
from . prompt import PromptRequestor
|
||||
from . graph_rag import GraphRagRequestor
|
||||
from . document_rag import DocumentRagRequestor
|
||||
from . triples_query import TriplesQueryRequestor
|
||||
from . embeddings import EmbeddingsRequestor
|
||||
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||
from . prompt import PromptRequestor
|
||||
from . text_load import TextLoad
|
||||
from . document_load import DocumentLoad
|
||||
|
||||
from . triples_export import TriplesExport
|
||||
from . graph_embeddings_export import GraphEmbeddingsExport
|
||||
from . document_embeddings_export import DocumentEmbeddingsExport
|
||||
|
||||
from . triples_import import TriplesImport
|
||||
from . graph_embeddings_import import GraphEmbeddingsImport
|
||||
from . document_embeddings_import import DocumentEmbeddingsImport
|
||||
|
||||
request_response_dispatchers = {
|
||||
"agent": AgentRequestor,
|
||||
"text-completion": TextCompletionRequestor,
|
||||
"prompt": PromptRequestor,
|
||||
"graph-rag": GraphRagRequestor,
|
||||
"document-rag": DocumentRagRequestor,
|
||||
"embeddings": EmbeddingsRequestor,
|
||||
"graph-embeddings": GraphEmbeddingsQueryRequestor,
|
||||
"triples-query": TriplesQueryRequestor,
|
||||
}
|
||||
|
||||
sender_dispatchers = {
|
||||
"text-load": TextLoad,
|
||||
"document-load": DocumentLoad,
|
||||
}
|
||||
|
||||
export_dispatchers = {
|
||||
"triples": TriplesExport,
|
||||
"graph-embeddings": GraphEmbeddingsExport,
|
||||
"document-embeddings": DocumentEmbeddingsExport,
|
||||
}
|
||||
|
||||
import_dispatchers = {
|
||||
"triples": TriplesImport,
|
||||
"graph-embeddings": GraphEmbeddingsImport,
|
||||
"document-embeddings": DocumentEmbeddingsImport,
|
||||
}
|
||||
|
||||
class DispatcherWrapper:
|
||||
def __init__(self, mgr, name, impl):
|
||||
self.mgr = mgr
|
||||
self.name = name
|
||||
self.impl = impl
|
||||
async def process(self, data, responder):
|
||||
return await self.mgr.process_impl(
|
||||
data, responder, self.name, self.impl
|
||||
)
|
||||
|
||||
class DispatcherManager:
|
||||
|
||||
def __init__(self, pulsar_client, config_receiver):
|
||||
self.pulsar_client = pulsar_client
|
||||
self.config_receiver = config_receiver
|
||||
self.config_receiver.add_handler(self)
|
||||
|
||||
self.flows = {}
|
||||
self.dispatchers = {}
|
||||
|
||||
async def start_flow(self, id, flow):
|
||||
print("Start flow", id)
|
||||
self.flows[id] = flow
|
||||
return
|
||||
|
||||
async def stop_flow(self, id, flow):
|
||||
print("Stop flow", id)
|
||||
del self.flows[id]
|
||||
return
|
||||
|
||||
def dispatch_config(self):
|
||||
return DispatcherWrapper(self, "config", ConfigRequestor)
|
||||
|
||||
def dispatch_flow(self):
|
||||
return DispatcherWrapper(self, "flow", FlowRequestor)
|
||||
|
||||
def dispatch_librarian(self):
|
||||
return DispatcherWrapper(self, "librarian", LibrarianRequestor)
|
||||
|
||||
async def process_impl(self, data, responder, name, impl):
|
||||
|
||||
key = (None, name)
|
||||
|
||||
if key in self.dispatchers:
|
||||
return await self.dispatchers[key].process(data, responder)
|
||||
|
||||
dispatcher = impl(
|
||||
pulsar_client = self.pulsar_client
|
||||
)
|
||||
|
||||
await dispatcher.start()
|
||||
|
||||
self.dispatchers[key] = dispatcher
|
||||
|
||||
return await dispatcher.process(data, responder)
|
||||
|
||||
def dispatch_service(self):
|
||||
return self
|
||||
|
||||
def dispatch_import(self):
|
||||
return self.invoke_import
|
||||
|
||||
def dispatch_export(self):
|
||||
return self.invoke_export
|
||||
|
||||
async def invoke_import(self, ws, running, params):
|
||||
|
||||
flow = params.get("flow")
|
||||
kind = params.get("kind")
|
||||
|
||||
if flow not in self.flows:
|
||||
raise RuntimeError("Invalid flow")
|
||||
|
||||
if kind not in import_dispatchers:
|
||||
raise RuntimeError("Invalid kind")
|
||||
|
||||
key = (flow, kind)
|
||||
|
||||
intf_defs = self.flows[flow]["interfaces"]
|
||||
|
||||
if kind not in intf_defs:
|
||||
raise RuntimeError("This kind not supported by flow")
|
||||
|
||||
# FIXME: The -store bit, does it make sense?
|
||||
qconfig = intf_defs[kind + "-store"]
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
dispatcher = import_dispatchers[kind](
|
||||
pulsar_client = self.pulsar_client,
|
||||
ws = ws,
|
||||
running = running,
|
||||
queue = qconfig,
|
||||
)
|
||||
|
||||
return dispatcher
|
||||
|
||||
async def invoke_export(self, ws, running, params):
|
||||
|
||||
flow = params.get("flow")
|
||||
kind = params.get("kind")
|
||||
|
||||
if flow not in self.flows:
|
||||
raise RuntimeError("Invalid flow")
|
||||
|
||||
if kind not in export_dispatchers:
|
||||
raise RuntimeError("Invalid kind")
|
||||
|
||||
key = (flow, kind)
|
||||
|
||||
intf_defs = self.flows[flow]["interfaces"]
|
||||
|
||||
if kind not in intf_defs:
|
||||
raise RuntimeError("This kind not supported by flow")
|
||||
|
||||
# FIXME: The -store bit, does it make sense?
|
||||
qconfig = intf_defs[kind + "-store"]
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
dispatcher = export_dispatchers[kind](
|
||||
pulsar_client = self.pulsar_client,
|
||||
ws = ws,
|
||||
running = running,
|
||||
queue = qconfig,
|
||||
consumer = f"api-gateway-{id}",
|
||||
subscriber = f"api-gateway-{id}",
|
||||
)
|
||||
|
||||
return dispatcher
|
||||
|
||||
async def process(self, data, responder, params):
|
||||
|
||||
flow = params.get("flow")
|
||||
kind = params.get("kind")
|
||||
|
||||
if flow not in self.flows:
|
||||
raise RuntimeError("Invalid flow")
|
||||
|
||||
key = (flow, kind)
|
||||
|
||||
if key in self.dispatchers:
|
||||
return await self.dispatchers[key].process(data, responder)
|
||||
|
||||
intf_defs = self.flows[flow]["interfaces"]
|
||||
|
||||
if kind not in intf_defs:
|
||||
raise RuntimeError("This kind not supported by flow")
|
||||
|
||||
qconfig = intf_defs[kind]
|
||||
|
||||
if kind in request_response_dispatchers:
|
||||
dispatcher = request_response_dispatchers[kind](
|
||||
pulsar_client = self.pulsar_client,
|
||||
request_queue = qconfig["request"],
|
||||
response_queue = qconfig["response"],
|
||||
timeout = 120,
|
||||
consumer = f"api-gateway-{flow}-{kind}-request",
|
||||
subscriber = f"api-gateway-{flow}-{kind}-request",
|
||||
)
|
||||
elif kind in sender_dispatchers:
|
||||
dispatcher = sender_dispatchers[kind](
|
||||
pulsar_client = self.pulsar_client,
|
||||
queue = qconfig,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Invalid kind")
|
||||
|
||||
await dispatcher.start()
|
||||
|
||||
self.dispatchers[key] = dispatcher
|
||||
|
||||
return await dispatcher.process(data, responder)
|
||||
|
||||
|
|
@ -1,14 +1,13 @@
|
|||
|
||||
import json
|
||||
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
from ... schema import PromptRequest, PromptResponse
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class PromptRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
self, pulsar_client, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
|
|
@ -3,8 +3,8 @@ import asyncio
|
|||
import uuid
|
||||
import logging
|
||||
|
||||
from .. base import Publisher
|
||||
from .. base import Subscriber
|
||||
from ... base import Publisher
|
||||
from ... base import Subscriber
|
||||
|
||||
logger = logging.getLogger("requestor")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -33,13 +33,17 @@ class ServiceRequestor:
|
|||
|
||||
self.timeout = timeout
|
||||
|
||||
self.running = True
|
||||
|
||||
async def start(self):
|
||||
await self.pub.start()
|
||||
self.running = True
|
||||
await self.sub.start()
|
||||
await self.pub.start()
|
||||
|
||||
async def stop(self):
|
||||
await self.pub.stop()
|
||||
await self.sub.stop()
|
||||
self.running = False
|
||||
|
||||
def to_request(self, request):
|
||||
raise RuntimeError("Not defined")
|
||||
|
|
@ -57,13 +61,14 @@ class ServiceRequestor:
|
|||
|
||||
await self.pub.send(id, self.to_request(request))
|
||||
|
||||
while True:
|
||||
while self.running:
|
||||
|
||||
try:
|
||||
resp = await asyncio.wait_for(
|
||||
q.get(), timeout=self.timeout
|
||||
)
|
||||
except Exception as e:
|
||||
print("Exception", e)
|
||||
raise RuntimeError("Timeout")
|
||||
|
||||
if resp.error:
|
||||
|
|
@ -5,7 +5,7 @@ import asyncio
|
|||
import uuid
|
||||
import logging
|
||||
|
||||
from .. base import Publisher
|
||||
from ... base import Publisher
|
||||
|
||||
logger = logging.getLogger("sender")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -15,18 +15,20 @@ class ServiceSender:
|
|||
def __init__(
|
||||
self,
|
||||
pulsar_client,
|
||||
request_queue, request_schema,
|
||||
queue, schema,
|
||||
):
|
||||
|
||||
self.pub = Publisher(
|
||||
pulsar_client, request_queue,
|
||||
schema=request_schema,
|
||||
pulsar_client, queue,
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.pub.start()
|
||||
|
||||
async def stop(self):
|
||||
await self.pub.stop()
|
||||
|
||||
def to_request(self, request):
|
||||
raise RuntimeError("Not defined")
|
||||
|
||||
|
|
@ -39,6 +41,8 @@ class ServiceSender:
|
|||
if responder:
|
||||
await responder({}, True)
|
||||
|
||||
return {}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logging.error(f"Exception: {e}")
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
import base64
|
||||
|
||||
from .. schema import Value, Triple, DocumentPackage, DocumentInfo
|
||||
from ... schema import Value, Triple, DocumentPackage, DocumentInfo
|
||||
|
||||
def to_value(x):
|
||||
return Value(value=x["v"], is_uri=x["e"])
|
||||
99
trustgraph-flow/trustgraph/gateway/dispatch/streamer.py
Normal file
99
trustgraph-flow/trustgraph/gateway/dispatch/streamer.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from ... base import Publisher
|
||||
from ... base import Subscriber
|
||||
|
||||
logger = logging.getLogger("requestor")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class ServiceRequestor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_client,
|
||||
queue, schema,
|
||||
handler,
|
||||
subscription="api-gateway", consumer_name="api-gateway",
|
||||
timeout=600,
|
||||
):
|
||||
|
||||
self.sub = Subscriber(
|
||||
pulsar_client, queue,
|
||||
subscription, consumer_name,
|
||||
schema
|
||||
)
|
||||
|
||||
self.timeout = timeout
|
||||
|
||||
self.running = True
|
||||
|
||||
self.receiver = handler
|
||||
|
||||
async def start(self):
|
||||
await self.sub.start()
|
||||
self.streamer = asyncio.create_task(self.stream())
|
||||
sub.start()
|
||||
self.running = True
|
||||
|
||||
async def stop(self):
|
||||
await self.sub.stop()
|
||||
self.running = False
|
||||
|
||||
def from_inbound(self, response):
|
||||
raise RuntimeError("Not defined")
|
||||
|
||||
async def stream(self):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
|
||||
q = await self.sub.subscribe(id)
|
||||
|
||||
while self.running:
|
||||
|
||||
try:
|
||||
resp = await asyncio.wait_for(
|
||||
q.get(), timeout=self.timeout
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Timeout")
|
||||
|
||||
if resp.error:
|
||||
err = { "error": {
|
||||
"type": resp.error.type,
|
||||
"message": resp.error.message,
|
||||
} }
|
||||
|
||||
fin = False
|
||||
|
||||
await self.receiver(err, fin)
|
||||
|
||||
else:
|
||||
|
||||
resp, fin = self.from_inbound(resp)
|
||||
|
||||
print(resp, fin)
|
||||
|
||||
await self.receiver(resp, fin)
|
||||
|
||||
if fin: break
|
||||
|
||||
except Exception as e:
|
||||
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
err = { "error": {
|
||||
"type": "gateway-error",
|
||||
"message": str(e),
|
||||
} }
|
||||
if responder:
|
||||
await responder(err, True)
|
||||
return err
|
||||
|
||||
finally:
|
||||
await self.sub.unsubscribe(id)
|
||||
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
|
||||
from .. schema import TextCompletionRequest, TextCompletionResponse
|
||||
from ... schema import TextCompletionRequest, TextCompletionResponse
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class TextCompletionRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
self, pulsar_client, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
|
|
@ -1,19 +1,18 @@
|
|||
|
||||
import base64
|
||||
|
||||
from .. schema import TextDocument, Metadata
|
||||
from .. schema import text_ingest_queue
|
||||
from ... schema import TextDocument, Metadata
|
||||
|
||||
from . sender import ServiceSender
|
||||
from . serialize import to_subgraph
|
||||
|
||||
class TextLoadSender(ServiceSender):
|
||||
def __init__(self, pulsar_client):
|
||||
class TextLoad(ServiceSender):
|
||||
def __init__(self, pulsar_client, queue):
|
||||
|
||||
super(TextLoadSender, self).__init__(
|
||||
super(TextLoad, self).__init__(
|
||||
pulsar_client = pulsar_client,
|
||||
request_queue=text_ingest_queue,
|
||||
request_schema=TextDocument,
|
||||
queue = queue,
|
||||
schema = TextDocument,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import uuid
|
||||
|
||||
from ... schema import Triples
|
||||
from ... base import Subscriber
|
||||
|
||||
from . serialize import serialize_triples
|
||||
|
||||
class TriplesExport:
|
||||
|
||||
def __init__(
|
||||
self, ws, running, pulsar_client, queue, consumer, subscriber
|
||||
):
|
||||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
self.pulsar_client = pulsar_client
|
||||
self.queue = queue
|
||||
self.consumer = consumer
|
||||
self.subscriber = subscriber
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
await self.ws.close()
|
||||
|
||||
async def receive(self, msg):
|
||||
# Ignore incoming info from websocket
|
||||
pass
|
||||
|
||||
async def run(self):
|
||||
|
||||
subs = Subscriber(
|
||||
client = self.pulsar_client, topic = self.queue,
|
||||
consumer_name = self.consumer, subscription = self.subscriber,
|
||||
schema = Triples
|
||||
)
|
||||
|
||||
await subs.start()
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
q = await subs.subscribe_all(id)
|
||||
|
||||
while self.running.get():
|
||||
try:
|
||||
|
||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||
await self.ws.send_json(serialize_triples(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
await subs.unsubscribe_all(id)
|
||||
|
||||
await subs.stop()
|
||||
|
||||
await self.ws.close()
|
||||
self.running.stop()
|
||||
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
from ... schema import Triples
|
||||
from ... base import Publisher
|
||||
|
||||
from . serialize import to_subgraph
|
||||
|
||||
class TriplesImport:
|
||||
|
||||
def __init__(
|
||||
self, ws, running, pulsar_client, queue
|
||||
):
|
||||
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
|
||||
self.publisher = Publisher(
|
||||
pulsar_client, topic = queue, schema = Triples
|
||||
)
|
||||
|
||||
async def destroy(self):
|
||||
self.running.stop()
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
await self.publisher.stop()
|
||||
|
||||
async def receive(self, msg):
|
||||
|
||||
data = msg.json()
|
||||
|
||||
elt = Triples(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
triples=to_subgraph(data["triples"]),
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
|
||||
async def run(self):
|
||||
|
||||
while self.running.get():
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
self.ws = None
|
||||
|
||||
|
|
@ -1,13 +1,12 @@
|
|||
|
||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Triples
|
||||
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
from . serialize import to_value, serialize_subgraph
|
||||
|
||||
class TriplesQueryRequestor(ServiceRequestor):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
self, pulsar_client, request_queue, response_queue, timeout,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from .. schema import Metadata
|
||||
from .. schema import DocumentEmbeddings, ChunkEmbeddings
|
||||
from .. schema import document_embeddings_store_queue
|
||||
from .. base import Publisher
|
||||
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import to_subgraph
|
||||
|
||||
class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(
|
||||
self, pulsar_client, auth, path="/api/v1/load/document-embeddings",
|
||||
):
|
||||
|
||||
super(DocumentEmbeddingsLoadEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_client=pulsar_client
|
||||
|
||||
self.publisher = Publisher(
|
||||
self.pulsar_client, document_embeddings_store_queue,
|
||||
schema=DocumentEmbeddings
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.publisher.start()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
async for msg in ws:
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
break
|
||||
else:
|
||||
|
||||
data = msg.json()
|
||||
|
||||
elt = DocumentEmbeddings(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
chunks=[
|
||||
ChunkEmbeddings(
|
||||
chunk=de["chunk"].encode("utf-8"),
|
||||
vectors=de["vectors"],
|
||||
)
|
||||
for de in data["chunks"]
|
||||
],
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
|
||||
running.stop()
|
||||
|
|
@ -1,72 +0,0 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import uuid
|
||||
|
||||
from .. schema import DocumentEmbeddings
|
||||
from .. schema import document_embeddings_store_queue
|
||||
from .. base import Subscriber
|
||||
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import serialize_document_embeddings
|
||||
|
||||
class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(
|
||||
self, pulsar_client, auth,
|
||||
path="/api/v1/stream/document-embeddings"
|
||||
):
|
||||
|
||||
super(DocumentEmbeddingsStreamEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_client=pulsar_client
|
||||
|
||||
self.subscriber = Subscriber(
|
||||
self.pulsar_client, document_embeddings_store_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
schema=DocumentEmbeddings,
|
||||
)
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
worker = asyncio.create_task(
|
||||
self.async_thread(ws, running)
|
||||
)
|
||||
|
||||
await super(DocumentEmbeddingsStreamEndpoint, self).listener(
|
||||
ws, running
|
||||
)
|
||||
|
||||
await worker
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.subscriber.start()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
q = await self.subscriber.subscribe_all(id)
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), timeout=0.5)
|
||||
await ws.send_json(serialize_document_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
await self.subscriber.unsubscribe_all(id)
|
||||
|
||||
running.stop()
|
||||
|
||||
|
|
@ -7,19 +7,19 @@ import logging
|
|||
logger = logging.getLogger("endpoint")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class ServiceEndpoint:
|
||||
class ConstantEndpoint:
|
||||
|
||||
def __init__(self, endpoint_path, auth, requestor):
|
||||
def __init__(self, endpoint_path, auth, dispatcher):
|
||||
|
||||
self.path = endpoint_path
|
||||
|
||||
self.auth = auth
|
||||
self.operation = "service"
|
||||
|
||||
self.requestor = requestor
|
||||
self.dispatcher = dispatcher
|
||||
|
||||
async def start(self):
|
||||
await self.requestor.start()
|
||||
pass
|
||||
|
||||
def add_routes(self, app):
|
||||
|
||||
|
|
@ -52,7 +52,7 @@ class ServiceEndpoint:
|
|||
async def responder(x, fin):
|
||||
print(x)
|
||||
|
||||
resp = await self.requestor.process(data, responder)
|
||||
resp = await self.dispatcher.process(data, responder)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
||||
67
trustgraph-flow/trustgraph/gateway/endpoint/manager.py
Normal file
67
trustgraph-flow/trustgraph/gateway/endpoint/manager.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from . constant_endpoint import ConstantEndpoint
|
||||
from . variable_endpoint import VariableEndpoint
|
||||
from . socket import SocketEndpoint
|
||||
from . metrics import MetricsEndpoint
|
||||
|
||||
from .. dispatch.manager import DispatcherManager
|
||||
|
||||
class EndpointManager:
|
||||
|
||||
def __init__(
|
||||
self, dispatcher_manager, auth, prometheus_url, timeout=600
|
||||
):
|
||||
|
||||
self.dispatcher_manager = dispatcher_manager
|
||||
self.timeout = timeout
|
||||
|
||||
self.services = {
|
||||
}
|
||||
|
||||
self.endpoints = [
|
||||
ConstantEndpoint(
|
||||
endpoint_path = "/api/v1/librarian", auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_librarian(),
|
||||
),
|
||||
ConstantEndpoint(
|
||||
endpoint_path = "/api/v1/config", auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_config(),
|
||||
),
|
||||
ConstantEndpoint(
|
||||
endpoint_path = "/api/v1/flow", auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_flow(),
|
||||
),
|
||||
MetricsEndpoint(
|
||||
endpoint_path = "/api/v1/metrics",
|
||||
prometheus_url = prometheus_url,
|
||||
auth = auth,
|
||||
),
|
||||
VariableEndpoint(
|
||||
endpoint_path = "/api/v1/flow/{flow}/service/{kind}",
|
||||
auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_service(),
|
||||
),
|
||||
SocketEndpoint(
|
||||
endpoint_path = "/api/v1/flow/{flow}/import/{kind}",
|
||||
auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_import()
|
||||
),
|
||||
SocketEndpoint(
|
||||
endpoint_path = "/api/v1/flow/{flow}/export/{kind}",
|
||||
auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_export()
|
||||
),
|
||||
]
|
||||
|
||||
def add_routes(self, app):
|
||||
for ep in self.endpoints:
|
||||
ep.add_routes(app)
|
||||
|
||||
async def start(self):
|
||||
for ep in self.endpoints:
|
||||
await ep.start()
|
||||
|
||||
111
trustgraph-flow/trustgraph/gateway/endpoint/socket.py
Normal file
111
trustgraph-flow/trustgraph/gateway/endpoint/socket.py
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
|
||||
import asyncio
|
||||
from aiohttp import web, WSMsgType
|
||||
import logging
|
||||
|
||||
from .. running import Running
|
||||
|
||||
logger = logging.getLogger("socket")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class SocketEndpoint:
|
||||
|
||||
def __init__(
|
||||
self, endpoint_path, auth, dispatcher,
|
||||
):
|
||||
|
||||
self.path = endpoint_path
|
||||
self.auth = auth
|
||||
self.operation = "socket"
|
||||
|
||||
self.dispatcher = dispatcher
|
||||
|
||||
async def worker(self, ws, dispatcher, running):
|
||||
|
||||
await dispatcher.run()
|
||||
|
||||
async def listener(self, ws, dispatcher, running):
|
||||
|
||||
async for msg in ws:
|
||||
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
await dispatcher.receive(msg)
|
||||
continue
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
await dispatcher.receive(msg)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
running.stop()
|
||||
await ws.close()
|
||||
|
||||
async def handle(self, request):
|
||||
|
||||
try:
|
||||
token = request.query['token']
|
||||
except:
|
||||
token = ""
|
||||
|
||||
if not self.auth.permitted(token, self.operation):
|
||||
return web.HTTPUnauthorized()
|
||||
|
||||
# 50MB max message size
|
||||
ws = web.WebSocketResponse(max_msg_size=52428800)
|
||||
|
||||
await ws.prepare(request)
|
||||
|
||||
try:
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
|
||||
running = Running()
|
||||
|
||||
dispatcher = await self.dispatcher(
|
||||
ws, running, request.match_info
|
||||
)
|
||||
|
||||
worker_task = tg.create_task(
|
||||
self.worker(ws, dispatcher, running)
|
||||
)
|
||||
|
||||
lsnr_task = tg.create_task(
|
||||
self.listener(ws, dispatcher, running)
|
||||
)
|
||||
|
||||
print("Created taskgroup, waiting...")
|
||||
|
||||
# Wait for threads to complete
|
||||
|
||||
print("Task group closed")
|
||||
|
||||
# Finally?
|
||||
await dispatcher.destroy()
|
||||
|
||||
except ExceptionGroup as e:
|
||||
|
||||
print("Exception group:", flush=True)
|
||||
|
||||
for se in e.exceptions:
|
||||
print(" Type:", type(se), flush=True)
|
||||
print(f" Exception: {se}", flush=True)
|
||||
except Exception as e:
|
||||
print("Socket exception:", e, flush=True)
|
||||
|
||||
await ws.close()
|
||||
|
||||
return ws
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
async def stop(self):
|
||||
self.running.stop()
|
||||
|
||||
def add_routes(self, app):
|
||||
|
||||
app.add_routes([
|
||||
web.get(self.path, self.handle),
|
||||
])
|
||||
|
||||
|
|
@ -4,26 +4,25 @@ from aiohttp import web
|
|||
import uuid
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("flow-endpoint")
|
||||
logger = logging.getLogger("endpoint")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class FlowEndpoint:
|
||||
class VariableEndpoint:
|
||||
|
||||
def __init__(self, endpoint_path, auth, requestors):
|
||||
def __init__(self, endpoint_path, auth, dispatcher):
|
||||
|
||||
self.path = endpoint_path
|
||||
|
||||
self.auth = auth
|
||||
self.operation = "service"
|
||||
|
||||
self.requestors = requestors
|
||||
self.dispatcher = dispatcher
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
def add_routes(self, app):
|
||||
|
||||
pass
|
||||
app.add_routes([
|
||||
web.post(self.path, self.handle),
|
||||
])
|
||||
|
|
@ -32,15 +31,6 @@ class FlowEndpoint:
|
|||
|
||||
print(request.path, "...")
|
||||
|
||||
flow_id = request.match_info['flow']
|
||||
kind = request.match_info['kind']
|
||||
k = (flow_id, kind)
|
||||
|
||||
if k not in self.requestors:
|
||||
raise web.HTTPBadRequest()
|
||||
|
||||
requestor = self.requestors[k]
|
||||
|
||||
try:
|
||||
ht = request.headers["Authorization"]
|
||||
tokens = ht.split(" ", 2)
|
||||
|
|
@ -62,7 +52,9 @@ class FlowEndpoint:
|
|||
async def responder(x, fin):
|
||||
print(x)
|
||||
|
||||
resp = await requestor.process(data, responder)
|
||||
resp = await self.dispatcher.process(
|
||||
data, responder, request.match_info
|
||||
)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
||||
|
|
@ -1,64 +0,0 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from .. schema import Metadata
|
||||
from .. schema import GraphEmbeddings, EntityEmbeddings
|
||||
from .. schema import graph_embeddings_store_queue
|
||||
from .. base import Publisher
|
||||
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import to_subgraph, to_value
|
||||
|
||||
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(
|
||||
self, pulsar_client, auth, path="/api/v1/load/graph-embeddings",
|
||||
):
|
||||
|
||||
super(GraphEmbeddingsLoadEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_client=pulsar_client
|
||||
|
||||
self.publisher = Publisher(
|
||||
self.pulsar_client, graph_embeddings_store_queue,
|
||||
schema=GraphEmbeddings
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.publisher.start()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
async for msg in ws:
|
||||
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
break
|
||||
else:
|
||||
|
||||
data = msg.json()
|
||||
|
||||
elt = GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
entities=[
|
||||
EntityEmbeddings(
|
||||
entity=to_value(ent["entity"]),
|
||||
vectors=ent["vectors"],
|
||||
)
|
||||
for ent in data["entities"]
|
||||
]
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
|
||||
running.stop()
|
||||
|
|
@ -1,69 +0,0 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import uuid
|
||||
|
||||
from .. schema import GraphEmbeddings
|
||||
from .. schema import graph_embeddings_store_queue
|
||||
from .. base import Subscriber
|
||||
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import serialize_graph_embeddings
|
||||
|
||||
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(
|
||||
self, pulsar_client, auth, path="/api/v1/stream/graph-embeddings"
|
||||
):
|
||||
|
||||
super(GraphEmbeddingsStreamEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_client=pulsar_client
|
||||
|
||||
self.subscriber = Subscriber(
|
||||
self.pulsar_client, graph_embeddings_store_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
schema=GraphEmbeddings
|
||||
)
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
worker = asyncio.create_task(
|
||||
self.async_thread(ws, running)
|
||||
)
|
||||
|
||||
await super(GraphEmbeddingsStreamEndpoint, self).listener(ws, running)
|
||||
|
||||
await worker
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.subscriber.start()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
q = await self.subscriber.subscribe_all(id)
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_graph_embeddings(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
await self.subscriber.unsubscribe_all(id)
|
||||
|
||||
running.stop()
|
||||
|
||||
|
|
@ -3,63 +3,22 @@ API gateway. Offers HTTP services which are translated to interaction on the
|
|||
Pulsar bus.
|
||||
"""
|
||||
|
||||
module = "api-gateway"
|
||||
|
||||
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
|
||||
# are active listeners
|
||||
|
||||
# FIXME: Connection errors in publishers / subscribers cause those threads
|
||||
# to fail and are not failed or retried
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
from aiohttp import web
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
import uuid
|
||||
import json
|
||||
|
||||
import pulsar
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from .. log_level import LogLevel
|
||||
|
||||
from . serialize import to_subgraph
|
||||
from . running import Running
|
||||
|
||||
from .. schema import ConfigPush, config_push_queue
|
||||
|
||||
from . text_completion import TextCompletionRequestor
|
||||
from . prompt import PromptRequestor
|
||||
from . graph_rag import GraphRagRequestor
|
||||
#from . document_rag import DocumentRagRequestor
|
||||
from . triples_query import TriplesQueryRequestor
|
||||
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||
from . embeddings import EmbeddingsRequestor
|
||||
#from . encyclopedia import EncyclopediaRequestor
|
||||
from . agent import AgentRequestor
|
||||
#from . dbpedia import DbpediaRequestor
|
||||
#from . internet_search import InternetSearchRequestor
|
||||
#from . librarian import LibrarianRequestor
|
||||
from . config import ConfigRequestor
|
||||
from . flow import FlowRequestor
|
||||
#from . triples_stream import TriplesStreamEndpoint
|
||||
#from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
|
||||
#from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint
|
||||
#from . triples_load import TriplesLoadEndpoint
|
||||
#from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
|
||||
#from . document_embeddings_load import DocumentEmbeddingsLoadEndpoint
|
||||
from . mux import MuxEndpoint
|
||||
#from . document_load import DocumentLoadSender
|
||||
#from . text_load import TextLoadSender
|
||||
from . metrics import MetricsEndpoint
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . flow_endpoint import FlowEndpoint
|
||||
from . auth import Authenticator
|
||||
from .. base import Subscriber
|
||||
from .. base import Consumer
|
||||
from . config.receiver import ConfigReceiver
|
||||
from . dispatch.manager import DispatcherManager
|
||||
|
||||
from . endpoint.manager import EndpointManager
|
||||
|
||||
import pulsar
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -81,6 +40,7 @@ class Api:
|
|||
self.pulsar_api_key = config.get(
|
||||
"pulsar_api_key", default_pulsar_api_key
|
||||
)
|
||||
|
||||
self.pulsar_listener = config.get("pulsar_listener", None)
|
||||
|
||||
if self.pulsar_api_key:
|
||||
|
|
@ -108,278 +68,24 @@ class Api:
|
|||
else:
|
||||
self.auth = Authenticator(allow_all=True)
|
||||
|
||||
self.services = {
|
||||
# "text-completion": TextCompletionRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "prompt": PromptRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "graph-rag": GraphRagRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "document-rag": DocumentRagRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "triples-query": TriplesQueryRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "graph-embeddings-query": GraphEmbeddingsQueryRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "embeddings": EmbeddingsRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "agent": AgentRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "librarian": LibrarianRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
(None, "config"): ConfigRequestor(
|
||||
pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
self.config_receiver = ConfigReceiver(self.pulsar_client)
|
||||
|
||||
self.dispatcher_manager = DispatcherManager(
|
||||
pulsar_client = self.pulsar_client,
|
||||
config_receiver = self.config_receiver,
|
||||
)
|
||||
|
||||
self.endpoint_manager = EndpointManager(
|
||||
dispatcher_manager = self.dispatcher_manager,
|
||||
auth = self.auth,
|
||||
),
|
||||
(None, "flow"): FlowRequestor(
|
||||
pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
# "encyclopedia": EncyclopediaRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "dbpedia": DbpediaRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "internet-search": InternetSearchRequestor(
|
||||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# "document-load": DocumentLoadSender(
|
||||
# pulsar_client=self.pulsar_client,
|
||||
# ),
|
||||
# "text-load": TextLoadSender(
|
||||
# pulsar_client=self.pulsar_client,
|
||||
# ),
|
||||
}
|
||||
prometheus_url = self.prometheus_url,
|
||||
timeout = self.timeout,
|
||||
|
||||
)
|
||||
|
||||
self.endpoints = [
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/text-completion", auth=self.auth,
|
||||
# requestor = self.services["text-completion"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/prompt", auth=self.auth,
|
||||
# requestor = self.services["prompt"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/graph-rag", auth=self.auth,
|
||||
# requestor = self.services["graph-rag"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/document-rag", auth=self.auth,
|
||||
# requestor = self.services["document-rag"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/triples-query", auth=self.auth,
|
||||
# requestor = self.services["triples-query"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/graph-embeddings-query",
|
||||
# auth=self.auth,
|
||||
# requestor = self.services["graph-embeddings-query"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/embeddings", auth=self.auth,
|
||||
# requestor = self.services["embeddings"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/agent", auth=self.auth,
|
||||
# requestor = self.services["agent"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/librarian", auth=self.auth,
|
||||
# requestor = self.services["librarian"],
|
||||
# ),
|
||||
ServiceEndpoint(
|
||||
endpoint_path = "/api/v1/config", auth=self.auth,
|
||||
requestor = self.services[(None, "config")],
|
||||
),
|
||||
ServiceEndpoint(
|
||||
endpoint_path = "/api/v1/flow", auth=self.auth,
|
||||
requestor = self.services[(None, "flow")],
|
||||
),
|
||||
FlowEndpoint(
|
||||
endpoint_path = "/api/v1/flow/{flow}/{kind}",
|
||||
auth=self.auth,
|
||||
requestors = self.services,
|
||||
),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
|
||||
# requestor = self.services["encyclopedia"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/dbpedia", auth=self.auth,
|
||||
# requestor = self.services["dbpedia"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/internet-search", auth=self.auth,
|
||||
# requestor = self.services["internet-search"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/load/document", auth=self.auth,
|
||||
# requestor = self.services["document-load"],
|
||||
# ),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/load/text", auth=self.auth,
|
||||
# requestor = self.services["text-load"],
|
||||
# ),
|
||||
# TriplesStreamEndpoint(
|
||||
# pulsar_client=self.pulsar_client,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# GraphEmbeddingsStreamEndpoint(
|
||||
# pulsar_client=self.pulsar_client,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# DocumentEmbeddingsStreamEndpoint(
|
||||
# pulsar_client=self.pulsar_client,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# TriplesLoadEndpoint(
|
||||
# pulsar_client=self.pulsar_client,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# GraphEmbeddingsLoadEndpoint(
|
||||
# pulsar_client=self.pulsar_client,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# DocumentEmbeddingsLoadEndpoint(
|
||||
# pulsar_client=self.pulsar_client,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
# MuxEndpoint(
|
||||
# pulsar_client=self.pulsar_client,
|
||||
# auth = self.auth,
|
||||
# services = self.services,
|
||||
# ),
|
||||
MetricsEndpoint(
|
||||
endpoint_path = "/api/v1/metrics",
|
||||
prometheus_url = self.prometheus_url,
|
||||
auth = self.auth,
|
||||
),
|
||||
]
|
||||
|
||||
self.flows = {}
|
||||
|
||||
async def on_config(self, msg, proc, flow):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
print(f"Config version", v.version)
|
||||
|
||||
if "flows" in v.config:
|
||||
|
||||
flows = v.config["flows"]
|
||||
|
||||
wanted = list(flows.keys())
|
||||
current = list(self.flows.keys())
|
||||
|
||||
for k in wanted:
|
||||
if k not in current:
|
||||
self.flows[k] = json.loads(flows[k])
|
||||
await self.start_flow(k, self.flows[k])
|
||||
|
||||
for k in current:
|
||||
if k not in wanted:
|
||||
await self.stop_flow(k, self.flows[k])
|
||||
del self.flows[k]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}", flush=True)
|
||||
|
||||
async def start_flow(self, id, flow):
|
||||
|
||||
print("Start flow", id)
|
||||
intf = flow["interfaces"]
|
||||
|
||||
kinds = {
|
||||
"agent": AgentRequestor,
|
||||
"text-completion": TextCompletionRequestor,
|
||||
"prompt": PromptRequestor,
|
||||
"graph-rag": GraphRagRequestor,
|
||||
"embeddings": EmbeddingsRequestor,
|
||||
"graph-embeddings": GraphEmbeddingsQueryRequestor,
|
||||
"triples-query": TriplesQueryRequestor,
|
||||
}
|
||||
|
||||
for api_kind, requestor in kinds.items():
|
||||
|
||||
if api_kind in intf:
|
||||
k = (id, api_kind)
|
||||
if k in self.services:
|
||||
await self.services[k].stop()
|
||||
del self.services[k]
|
||||
|
||||
self.services[k] = requestor(
|
||||
pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
request_queue = intf[api_kind]["request"],
|
||||
response_queue = intf[api_kind]["response"],
|
||||
consumer = f"api-gateway-{id}-{api_kind}-request",
|
||||
subscriber = f"api-gateway-{id}-{api_kind}-request",
|
||||
auth = self.auth,
|
||||
)
|
||||
await self.services[k].start()
|
||||
|
||||
async def stop_flow(self, id, flow):
|
||||
print("Stop flow", id)
|
||||
intf = flow["interfaces"]
|
||||
|
||||
svc_list = list(self.services.keys())
|
||||
|
||||
for k in svc_list:
|
||||
|
||||
kid, kkind = k
|
||||
|
||||
if id == kid:
|
||||
await self.services[k].stop()
|
||||
del self.services[k]
|
||||
|
||||
async def config_loader(self):
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
self.config_cons = Consumer(
|
||||
taskgroup = tg,
|
||||
flow = None,
|
||||
client = self.pulsar_client,
|
||||
subscriber = f"gateway-{id}",
|
||||
topic = config_push_queue,
|
||||
schema = ConfigPush,
|
||||
handler = self.on_config,
|
||||
start_of_messages = True,
|
||||
)
|
||||
|
||||
await self.config_cons.start()
|
||||
|
||||
print("Waiting...")
|
||||
|
||||
print("Config consumer done. :/")
|
||||
|
||||
async def app_factory(self):
|
||||
|
||||
self.app = web.Application(
|
||||
|
|
@ -387,7 +93,8 @@ class Api:
|
|||
client_max_size=256 * 1024 * 1024
|
||||
)
|
||||
|
||||
asyncio.create_task(self.config_loader())
|
||||
await self.config_receiver.start()
|
||||
|
||||
|
||||
for ep in self.endpoints:
|
||||
ep.add_routes(self.app)
|
||||
|
|
@ -395,6 +102,9 @@ class Api:
|
|||
for ep in self.endpoints:
|
||||
await ep.start()
|
||||
|
||||
self.endpoint_manager.add_routes(self.app)
|
||||
await self.endpoint_manager.start()
|
||||
|
||||
return self.app
|
||||
|
||||
def run(self):
|
||||
|
|
|
|||
|
|
@ -1,72 +0,0 @@
|
|||
|
||||
import asyncio
|
||||
from aiohttp import web, WSMsgType
|
||||
import logging
|
||||
|
||||
from . running import Running
|
||||
|
||||
logger = logging.getLogger("socket")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class SocketEndpoint:
|
||||
|
||||
def __init__(
|
||||
self, endpoint_path, auth,
|
||||
):
|
||||
|
||||
self.path = endpoint_path
|
||||
self.auth = auth
|
||||
self.operation = "socket"
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
async for msg in ws:
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
# Ignore incoming message
|
||||
continue
|
||||
elif msg.type == WSMsgType.BINARY:
|
||||
# Ignore incoming message
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
running.stop()
|
||||
|
||||
async def handle(self, request):
|
||||
|
||||
try:
|
||||
token = request.query['token']
|
||||
except:
|
||||
token = ""
|
||||
|
||||
if not self.auth.permitted(token, self.operation):
|
||||
return web.HTTPUnauthorized()
|
||||
|
||||
running = Running()
|
||||
|
||||
# 50MB max message size
|
||||
ws = web.WebSocketResponse(max_msg_size=52428800)
|
||||
|
||||
await ws.prepare(request)
|
||||
|
||||
try:
|
||||
await self.listener(ws, running)
|
||||
except Exception as e:
|
||||
print("Socket exception:", e, flush=True)
|
||||
|
||||
running.stop()
|
||||
|
||||
await ws.close()
|
||||
|
||||
return ws
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
def add_routes(self, app):
|
||||
|
||||
app.add_routes([
|
||||
web.get(self.path, self.handle),
|
||||
])
|
||||
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
|
||||
import asyncio
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from .. schema import Metadata
|
||||
from .. schema import Triples
|
||||
from .. schema import triples_store_queue
|
||||
from .. base import Publisher
|
||||
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import to_subgraph
|
||||
|
||||
class TriplesLoadEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(self, pulsar_client, auth, path="/api/v1/load/triples"):
|
||||
|
||||
super(TriplesLoadEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_client=pulsar_client
|
||||
|
||||
self.publisher = Publisher(
|
||||
self.pulsar_client, triples_store_queue,
|
||||
schema=Triples
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.publisher.start()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
async for msg in ws:
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
break
|
||||
else:
|
||||
|
||||
data = msg.json()
|
||||
|
||||
elt = Triples(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
triples=to_subgraph(data["triples"]),
|
||||
)
|
||||
|
||||
await self.publisher.send(None, elt)
|
||||
|
||||
|
||||
running.stop()
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import uuid
|
||||
|
||||
from .. schema import Triples
|
||||
from .. schema import triples_store_queue
|
||||
from .. base import Subscriber
|
||||
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import serialize_triples
|
||||
|
||||
class TriplesStreamEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(self, pulsar_client, auth, path="/api/v1/stream/triples"):
|
||||
|
||||
super(TriplesStreamEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_client=pulsar_client
|
||||
|
||||
self.subscriber = Subscriber(
|
||||
self.pulsar_client, triples_store_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
schema=Triples
|
||||
)
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
worker = asyncio.create_task(
|
||||
self.async_thread(ws, running)
|
||||
)
|
||||
|
||||
await super(TriplesStreamEndpoint, self).listener(ws, running)
|
||||
|
||||
await worker
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.subscriber.start()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
q = self.subscriber.subscribe_all(id)
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_triples(resp))
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
self.subscriber.unsubscribe_all(id)
|
||||
|
||||
running.stop()
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue