Flow API - update gateway (#357)

* Altered API to incorporate Flow IDs, refactored for dynamic start/stop of flows
* Gateway: Split endpoint / dispatcher for maintainability
This commit is contained in:
cybermaggedon 2025-05-02 21:11:50 +01:00 committed by GitHub
parent 450f664b1b
commit a70ae9793a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 1206 additions and 907 deletions

33
test-api/test-llm2-api Executable file
View file

@ -0,0 +1,33 @@
#!/usr/bin/env python3
import requests
import json
import sys
url = "http://localhost:8088/api/v1/"
############################################################################
input = {
"system": "",
"prompt": "Add 2 and 3"
}
resp = requests.post(
f"{url}text-completion",
json=input,
)
if resp.status_code != 200:
raise RuntimeError(f"Status code: {resp.status_code}")
resp = resp.json()
if "error" in resp:
print(f"Error: {resp['error']}")
sys.exit(1)
print(resp["response"])
############################################################################

View file

@ -5,7 +5,7 @@ import json
import sys
import base64
url = "http://localhost:8088/api/v1/"
url = "http://localhost:8088/api/v1/flow/0000/document-load"
############################################################################
@ -88,10 +88,7 @@ input = {
}
resp = requests.post(
f"{url}load/document",
json=input,
)
resp = requests.post(url, json=input)
resp = resp.json()

View file

@ -5,7 +5,7 @@ import json
import sys
import base64
url = "http://localhost:8088/api/v1/"
url = "http://localhost:8088/api/v1/flow/0000/service/text-load"
############################################################################
@ -85,10 +85,7 @@ input = {
}
resp = requests.post(
f"{url}load/text",
json=input,
)
resp = requests.post(url, json=input)
resp = resp.json()

View file

@ -9,7 +9,7 @@ from trustgraph.schema import Document, Metadata
client = pulsar.Client("pulsar://localhost:6650", listener_name="localhost")
prod = client.create_producer(
topic="persistent://tg/flow/document-load:0002",
topic="persistent://tg/flow/document-load:0000",
schema=JsonSchema(Document),
chunking_enabled=True,
)

View file

@ -14,7 +14,7 @@ prod = client.create_producer(
chunking_enabled=True,
)
path = "docs/README.cats"
path = "../trustgraph/docs/README.cats"
with open(path, "r") as f:
# blob = base64.b64encode(f.read()).decode("utf-8")

View file

@ -15,16 +15,21 @@ class Publisher:
self.q = asyncio.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled
self.running = True
self.task = None
async def start(self):
self.task = asyncio.create_task(self.run())
async def stop(self):
self.running = False
if self.task:
await self.task
async def join(self):
await self.stop()
if self.task:
await self.task
async def run(self):

View file

@ -19,6 +19,7 @@ class Subscriber:
self.lock = asyncio.Lock()
self.running = True
self.metrics = metrics
self.task = None
def __del__(self):
self.running = False
@ -28,10 +29,14 @@ class Subscriber:
async def stop(self):
self.running = False
if self.task:
await self.task
async def join(self):
await self.stop()
if self.task:
await self.task
async def run(self):
@ -45,6 +50,8 @@ class Subscriber:
try:
# FIXME: Create consumer in start method so we know
# it is definitely running when start completes
consumer = self.client.subscribe(
topic = self.topic,
subscription_name = self.subscription,

View file

@ -1,7 +1,6 @@
from pulsar.schema import JsonSchema
from .. schema import EmbeddingsRequest, EmbeddingsResponse
from .. schema import embeddings_request_queue, embeddings_response_queue
from . base import BaseClient
import _pulsar
@ -23,12 +22,6 @@ class EmbeddingsClient(BaseClient):
pulsar_api_key=None,
):
if input_queue == None:
input_queue=embeddings_request_queue
if output_queue == None:
output_queue=embeddings_response_queue
super(EmbeddingsClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
@ -43,4 +36,3 @@ class EmbeddingsClient(BaseClient):
def request(self, text, timeout=300):
return self.call(text=text, timeout=timeout).vectors

View file

@ -28,7 +28,7 @@ async def load_ge(running, queue, url):
async with aiohttp.ClientSession() as session:
async with session.ws_connect(f"{url}load/graph-embeddings") as ws:
async with session.ws_connect(url) as ws:
while running.get():
@ -73,7 +73,7 @@ async def load_triples(running, queue, url):
async with aiohttp.ClientSession() as session:
async with session.ws_connect(f"{url}load/triples") as ws:
async with session.ws_connect(url) as ws:
while running.get():
@ -200,6 +200,9 @@ async def run(running, **args):
ge_q = asyncio.Queue(maxsize=10)
t_q = asyncio.Queue(maxsize=10)
flow_id = args["flow_id"]
url = args["url"]
load_task = asyncio.create_task(
loader(
running=running,
@ -213,14 +216,16 @@ async def run(running, **args):
ge_task = asyncio.create_task(
load_ge(
running = running,
queue=ge_q, url=args["url"] + "api/v1/"
queue = ge_q,
url = f"{url}api/v1/flow/{flow_id}/import/graph-embeddings"
)
)
triples_task = asyncio.create_task(
load_triples(
running = running,
queue=t_q, url=args["url"] + "api/v1/"
queue = t_q,
url = f"{url}api/v1/flow/{flow_id}/import/triples"
)
)
@ -258,6 +263,12 @@ async def main(running):
help=f'Output file'
)
parser.add_argument(
'-f', '--flow-id',
default="0000",
help=f'Flow ID (default: 0000)'
)
parser.add_argument(
'--format',
default="msgpack",

View file

@ -27,7 +27,7 @@ async def fetch_ge(running, queue, user, collection, url):
async with aiohttp.ClientSession() as session:
async with session.ws_connect(f"{url}stream/graph-embeddings") as ws:
async with session.ws_connect(url) as ws:
while running.get():
@ -74,7 +74,7 @@ async def fetch_triples(running, queue, user, collection, url):
async with aiohttp.ClientSession() as session:
async with session.ws_connect(f"{url}stream/triples") as ws:
async with session.ws_connect(url) as ws:
while running.get():
@ -160,11 +160,14 @@ async def run(running, **args):
q = asyncio.Queue()
flow_id = args["flow_id"]
url = args["url"]
ge_task = asyncio.create_task(
fetch_ge(
running=running,
queue=q, user=args["user"], collection=args["collection"],
url=args["url"] + "api/v1/"
url = f"{url}api/v1/flow/{flow_id}/export/graph-embeddings"
)
)
@ -172,7 +175,7 @@ async def run(running, **args):
fetch_triples(
running=running, queue=q,
user=args["user"], collection=args["collection"],
url=args["url"] + "api/v1/"
url = f"{url}api/v1/flow/{flow_id}/export/triples"
)
)
@ -224,6 +227,12 @@ async def main(running):
help=f'Output format (default: msgpack)',
)
parser.add_argument(
'-f', '--flow-id',
default="0000",
help=f'Flow ID (default: 0000)'
)
parser.add_argument(
'--user',
help=f'User ID to filter on (default: no filter)'

View file

@ -0,0 +1,121 @@
"""
API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus.
"""
module = "api-gateway"
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners
# FIXME: Connection errors in publishers / subscribers cause those threads
# to fail and are not failed or retried
import asyncio
import argparse
from aiohttp import web
import logging
import os
import base64
import uuid
import json
import pulsar
from prometheus_client import start_http_server
from ... schema import ConfigPush, config_push_queue
from ... base import Consumer
logger = logging.getLogger("config.receiver")
logger.setLevel(logging.INFO)
class ConfigReceiver:
def __init__(self, pulsar_client):
self.pulsar_client = pulsar_client
self.flow_handlers = []
self.flows = {}
def add_handler(self, h):
self.flow_handlers.append(h)
async def on_config(self, msg, proc, flow):
try:
v = msg.value()
print(f"Config version", v.version)
if "flows" in v.config:
flows = v.config["flows"]
wanted = list(flows.keys())
current = list(self.flows.keys())
for k in wanted:
if k not in current:
self.flows[k] = json.loads(flows[k])
await self.start_flow(k, self.flows[k])
for k in current:
if k not in wanted:
await self.stop_flow(k, self.flows[k])
del self.flows[k]
except Exception as e:
print(f"Exception: {e}", flush=True)
async def start_flow(self, id, flow):
print("Start flow", id)
for handler in self.flow_handlers:
try:
await handler.start_flow(id, flow)
except Exception as e:
print(f"Exception: {e}", flush=True)
async def stop_flow(self, id, flow):
print("Stop flow", id)
for handler in self.flow_handlers:
try:
await handler.stop_flow(id, flow)
except Exception as e:
print(f"Exception: {e}", flush=True)
async def config_loader(self):
async with asyncio.TaskGroup() as tg:
id = str(uuid.uuid4())
self.config_cons = Consumer(
taskgroup = tg,
flow = None,
client = self.pulsar_client,
subscriber = f"gateway-{id}",
topic = config_push_queue,
schema = ConfigPush,
handler = self.on_config,
start_of_messages = True,
)
await self.config_cons.start()
print("Waiting...")
print("Config consumer done. :/")
async def start(self):
asyncio.create_task(self.config_loader())

View file

@ -1,12 +1,11 @@
from .. schema import AgentRequest, AgentResponse
from ... schema import AgentRequest, AgentResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class AgentRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth,
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):

View file

@ -1,13 +1,12 @@
from .. schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
from .. schema import config_request_queue
from .. schema import config_response_queue
from ... schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
from ... schema import config_request_queue
from ... schema import config_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class ConfigRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout, auth):
def __init__(self, pulsar_client, timeout=120):
super(ConfigRequestor, self).__init__(
pulsar_client=pulsar_client,

View file

@ -0,0 +1,67 @@
import asyncio
import queue
import uuid
from ... schema import DocumentEmbeddings
from ... base import Subscriber
from . serialize import serialize_document_embeddings
class DocumentEmbeddingsExport:
def __init__(
self, ws, running, pulsar_client, queue, consumer, subscriber
):
self.ws = ws
self.running = running
self.pulsar_client = pulsar_client
self.queue = queue
self.consumer = consumer
self.subscriber = subscriber
async def destroy(self):
self.running.stop()
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = DocumentEmbeddings
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_document_embeddings(resp))
except TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()

View file

@ -0,0 +1,64 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import DocumentEmbeddings, ChunkEmbeddings
from ... base import Publisher
from . serialize import to_subgraph
class DocumentEmbeddingsImport:
def __init__(
self, ws, running, pulsar_client, queue
):
self.ws = ws
self.running = running
self.publisher = Publisher(
pulsar_client, topic = queue, schema = DocumentEmbeddings
)
async def destroy(self):
self.running.stop()
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()
elt = DocumentEmbeddings(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
chunks=[
ChunkEmbeddings(
chunk=de["chunk"].encode("utf-8"),
vectors=de["vectors"],
)
for de in data["chunks"]
],
)
await self.publisher.send(None, elt)
async def run(self):
while self.running.get():
await asyncio.sleep(0.5)
if self.ws:
await self.ws.close()
self.ws = None

View file

@ -1,19 +1,18 @@
import base64
from .. schema import Document, Metadata
from .. schema import document_ingest_queue
from ... schema import Document, Metadata
from . sender import ServiceSender
from . serialize import to_subgraph
class DocumentLoadSender(ServiceSender):
def __init__(self, pulsar_client):
class DocumentLoad(ServiceSender):
def __init__(self, pulsar_client, queue):
super(DocumentLoadSender, self).__init__(
super(DocumentLoad, self).__init__(
pulsar_client = pulsar_client,
request_queue=document_ingest_queue,
request_schema=Document,
queue = queue,
schema = Document,
)
def to_request(self, body):

View file

@ -1,20 +1,22 @@
from .. schema import DocumentRagQuery, DocumentRagResponse
from .. schema import document_rag_request_queue
from .. schema import document_rag_response_queue
from ... schema import DocumentRagQuery, DocumentRagResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class DocumentRagRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout, auth):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(DocumentRagRequestor, self).__init__(
pulsar_client=pulsar_client,
request_queue=document_rag_request_queue,
response_queue=document_rag_response_queue,
request_queue=request_queue,
response_queue=response_queue,
request_schema=DocumentRagQuery,
response_schema=DocumentRagResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout,
)

View file

@ -1,12 +1,11 @@
from .. schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class EmbeddingsRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth,
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):

View file

@ -1,13 +1,12 @@
from .. schema import FlowRequest, FlowResponse, ConfigKey, ConfigValue
from .. schema import flow_request_queue
from .. schema import flow_response_queue
from ... schema import FlowRequest, FlowResponse
from ... schema import flow_request_queue
from ... schema import flow_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class FlowRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout, auth):
def __init__(self, pulsar_client, timeout=120):
super(FlowRequestor, self).__init__(
pulsar_client=pulsar_client,

View file

@ -0,0 +1,67 @@
import asyncio
import queue
import uuid
from ... schema import GraphEmbeddings
from ... base import Subscriber
from . serialize import serialize_graph_embeddings
class GraphEmbeddingsExport:
def __init__(
self, ws, running, pulsar_client, queue, consumer, subscriber
):
self.ws = ws
self.running = running
self.pulsar_client = pulsar_client
self.queue = queue
self.consumer = consumer
self.subscriber = subscriber
async def destroy(self):
self.running.stop()
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = GraphEmbeddings
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()

View file

@ -0,0 +1,64 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import GraphEmbeddings, EntityEmbeddings
from ... base import Publisher
from . serialize import to_subgraph, to_value
class GraphEmbeddingsImport:
def __init__(
self, ws, running, pulsar_client, queue
):
self.ws = ws
self.running = running
self.publisher = Publisher(
pulsar_client, topic = queue, schema = GraphEmbeddings
)
async def destroy(self):
self.running.stop()
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()
elt = GraphEmbeddings(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
entities=[
EntityEmbeddings(
entity=to_value(ent["entity"]),
vectors=ent["vectors"],
)
for ent in data["entities"]
]
)
await self.publisher.send(None, elt)
async def run(self):
while self.running.get():
await asyncio.sleep(0.5)
if self.ws:
await self.ws.close()
self.ws = None

View file

@ -1,13 +1,12 @@
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from ... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
from . serialize import serialize_value
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth,
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):

View file

@ -1,12 +1,11 @@
from .. schema import GraphRagQuery, GraphRagResponse
from ... schema import GraphRagQuery, GraphRagResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class GraphRagRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth,
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):

View file

@ -1,15 +1,14 @@
from .. schema import LibrarianRequest, LibrarianResponse, Triples
from .. schema import librarian_request_queue
from .. schema import librarian_response_queue
from ... schema import LibrarianRequest, LibrarianResponse
from ... schema import librarian_request_queue
from ... schema import librarian_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
from . serialize import serialize_document_package, serialize_document_info
from . serialize import to_document_package, to_document_info, to_criteria
class LibrarianRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout, auth):
def __init__(self, pulsar_client, timeout=120):
super(LibrarianRequestor, self).__init__(
pulsar_client=pulsar_client,
@ -22,20 +21,16 @@ class LibrarianRequestor(ServiceRequestor):
def to_request(self, body):
print("TRR")
if "document" in body:
dp = to_document_package(body["document"])
else:
dp = None
print("GOT")
if "criteria" in body:
criteria = to_criteria(body["criteria"])
else:
criteria = None
print("ASLDKJ")
return LibrarianRequest(
operation = body.get("operation", None),
id = body.get("id", None),

View file

@ -0,0 +1,229 @@
import asyncio
import uuid
from . config import ConfigRequestor
from . flow import FlowRequestor
from . librarian import LibrarianRequestor
from . embeddings import EmbeddingsRequestor
from . agent import AgentRequestor
from . text_completion import TextCompletionRequestor
from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor
from . embeddings import EmbeddingsRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . prompt import PromptRequestor
from . text_load import TextLoad
from . document_load import DocumentLoad
from . triples_export import TriplesExport
from . graph_embeddings_export import GraphEmbeddingsExport
from . document_embeddings_export import DocumentEmbeddingsExport
from . triples_import import TriplesImport
from . graph_embeddings_import import GraphEmbeddingsImport
from . document_embeddings_import import DocumentEmbeddingsImport
request_response_dispatchers = {
"agent": AgentRequestor,
"text-completion": TextCompletionRequestor,
"prompt": PromptRequestor,
"graph-rag": GraphRagRequestor,
"document-rag": DocumentRagRequestor,
"embeddings": EmbeddingsRequestor,
"graph-embeddings": GraphEmbeddingsQueryRequestor,
"triples-query": TriplesQueryRequestor,
}
sender_dispatchers = {
"text-load": TextLoad,
"document-load": DocumentLoad,
}
export_dispatchers = {
"triples": TriplesExport,
"graph-embeddings": GraphEmbeddingsExport,
"document-embeddings": DocumentEmbeddingsExport,
}
import_dispatchers = {
"triples": TriplesImport,
"graph-embeddings": GraphEmbeddingsImport,
"document-embeddings": DocumentEmbeddingsImport,
}
class DispatcherWrapper:
def __init__(self, mgr, name, impl):
self.mgr = mgr
self.name = name
self.impl = impl
async def process(self, data, responder):
return await self.mgr.process_impl(
data, responder, self.name, self.impl
)
class DispatcherManager:
def __init__(self, pulsar_client, config_receiver):
self.pulsar_client = pulsar_client
self.config_receiver = config_receiver
self.config_receiver.add_handler(self)
self.flows = {}
self.dispatchers = {}
async def start_flow(self, id, flow):
print("Start flow", id)
self.flows[id] = flow
return
async def stop_flow(self, id, flow):
print("Stop flow", id)
del self.flows[id]
return
def dispatch_config(self):
return DispatcherWrapper(self, "config", ConfigRequestor)
def dispatch_flow(self):
return DispatcherWrapper(self, "flow", FlowRequestor)
def dispatch_librarian(self):
return DispatcherWrapper(self, "librarian", LibrarianRequestor)
async def process_impl(self, data, responder, name, impl):
key = (None, name)
if key in self.dispatchers:
return await self.dispatchers[key].process(data, responder)
dispatcher = impl(
pulsar_client = self.pulsar_client
)
await dispatcher.start()
self.dispatchers[key] = dispatcher
return await dispatcher.process(data, responder)
def dispatch_service(self):
return self
def dispatch_import(self):
return self.invoke_import
def dispatch_export(self):
return self.invoke_export
async def invoke_import(self, ws, running, params):
flow = params.get("flow")
kind = params.get("kind")
if flow not in self.flows:
raise RuntimeError("Invalid flow")
if kind not in import_dispatchers:
raise RuntimeError("Invalid kind")
key = (flow, kind)
intf_defs = self.flows[flow]["interfaces"]
if kind not in intf_defs:
raise RuntimeError("This kind not supported by flow")
# FIXME: The -store bit, does it make sense?
qconfig = intf_defs[kind + "-store"]
id = str(uuid.uuid4())
dispatcher = import_dispatchers[kind](
pulsar_client = self.pulsar_client,
ws = ws,
running = running,
queue = qconfig,
)
return dispatcher
async def invoke_export(self, ws, running, params):
flow = params.get("flow")
kind = params.get("kind")
if flow not in self.flows:
raise RuntimeError("Invalid flow")
if kind not in export_dispatchers:
raise RuntimeError("Invalid kind")
key = (flow, kind)
intf_defs = self.flows[flow]["interfaces"]
if kind not in intf_defs:
raise RuntimeError("This kind not supported by flow")
# FIXME: The -store bit, does it make sense?
qconfig = intf_defs[kind + "-store"]
id = str(uuid.uuid4())
dispatcher = export_dispatchers[kind](
pulsar_client = self.pulsar_client,
ws = ws,
running = running,
queue = qconfig,
consumer = f"api-gateway-{id}",
subscriber = f"api-gateway-{id}",
)
return dispatcher
async def process(self, data, responder, params):
flow = params.get("flow")
kind = params.get("kind")
if flow not in self.flows:
raise RuntimeError("Invalid flow")
key = (flow, kind)
if key in self.dispatchers:
return await self.dispatchers[key].process(data, responder)
intf_defs = self.flows[flow]["interfaces"]
if kind not in intf_defs:
raise RuntimeError("This kind not supported by flow")
qconfig = intf_defs[kind]
if kind in request_response_dispatchers:
dispatcher = request_response_dispatchers[kind](
pulsar_client = self.pulsar_client,
request_queue = qconfig["request"],
response_queue = qconfig["response"],
timeout = 120,
consumer = f"api-gateway-{flow}-{kind}-request",
subscriber = f"api-gateway-{flow}-{kind}-request",
)
elif kind in sender_dispatchers:
dispatcher = sender_dispatchers[kind](
pulsar_client = self.pulsar_client,
queue = qconfig,
)
else:
raise RuntimeError("Invalid kind")
await dispatcher.start()
self.dispatchers[key] = dispatcher
return await dispatcher.process(data, responder)

View file

@ -1,14 +1,13 @@
import json
from .. schema import PromptRequest, PromptResponse
from ... schema import PromptRequest, PromptResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class PromptRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth,
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):

View file

@ -3,8 +3,8 @@ import asyncio
import uuid
import logging
from .. base import Publisher
from .. base import Subscriber
from ... base import Publisher
from ... base import Subscriber
logger = logging.getLogger("requestor")
logger.setLevel(logging.INFO)
@ -33,13 +33,17 @@ class ServiceRequestor:
self.timeout = timeout
self.running = True
async def start(self):
await self.pub.start()
self.running = True
await self.sub.start()
await self.pub.start()
async def stop(self):
await self.pub.stop()
await self.sub.stop()
self.running = False
def to_request(self, request):
raise RuntimeError("Not defined")
@ -57,13 +61,14 @@ class ServiceRequestor:
await self.pub.send(id, self.to_request(request))
while True:
while self.running:
try:
resp = await asyncio.wait_for(
q.get(), timeout=self.timeout
)
except Exception as e:
print("Exception", e)
raise RuntimeError("Timeout")
if resp.error:

View file

@ -5,7 +5,7 @@ import asyncio
import uuid
import logging
from .. base import Publisher
from ... base import Publisher
logger = logging.getLogger("sender")
logger.setLevel(logging.INFO)
@ -15,18 +15,20 @@ class ServiceSender:
def __init__(
self,
pulsar_client,
request_queue, request_schema,
queue, schema,
):
self.pub = Publisher(
pulsar_client, request_queue,
schema=request_schema,
pulsar_client, queue,
schema=schema,
)
async def start(self):
await self.pub.start()
async def stop(self):
await self.pub.stop()
def to_request(self, request):
raise RuntimeError("Not defined")
@ -39,6 +41,8 @@ class ServiceSender:
if responder:
await responder({}, True)
return {}
except Exception as e:
logging.error(f"Exception: {e}")

View file

@ -1,7 +1,7 @@
import base64
from .. schema import Value, Triple, DocumentPackage, DocumentInfo
from ... schema import Value, Triple, DocumentPackage, DocumentInfo
def to_value(x):
return Value(value=x["v"], is_uri=x["e"])

View file

@ -0,0 +1,99 @@
import asyncio
import uuid
import logging
from ... base import Publisher
from ... base import Subscriber
logger = logging.getLogger("requestor")
logger.setLevel(logging.INFO)
class ServiceRequestor:
def __init__(
self,
pulsar_client,
queue, schema,
handler,
subscription="api-gateway", consumer_name="api-gateway",
timeout=600,
):
self.sub = Subscriber(
pulsar_client, queue,
subscription, consumer_name,
schema
)
self.timeout = timeout
self.running = True
self.receiver = handler
async def start(self):
await self.sub.start()
self.streamer = asyncio.create_task(self.stream())
sub.start()
self.running = True
async def stop(self):
await self.sub.stop()
self.running = False
def from_inbound(self, response):
raise RuntimeError("Not defined")
async def stream(self):
id = str(uuid.uuid4())
try:
q = await self.sub.subscribe(id)
while self.running:
try:
resp = await asyncio.wait_for(
q.get(), timeout=self.timeout
)
except Exception as e:
raise RuntimeError("Timeout")
if resp.error:
err = { "error": {
"type": resp.error.type,
"message": resp.error.message,
} }
fin = False
await self.receiver(err, fin)
else:
resp, fin = self.from_inbound(resp)
print(resp, fin)
await self.receiver(resp, fin)
if fin: break
except Exception as e:
logging.error(f"Exception: {e}")
err = { "error": {
"type": "gateway-error",
"message": str(e),
} }
if responder:
await responder(err, True)
return err
finally:
await self.sub.unsubscribe(id)

View file

@ -1,12 +1,11 @@
from .. schema import TextCompletionRequest, TextCompletionResponse
from ... schema import TextCompletionRequest, TextCompletionResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class TextCompletionRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth,
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):

View file

@ -1,19 +1,18 @@
import base64
from .. schema import TextDocument, Metadata
from .. schema import text_ingest_queue
from ... schema import TextDocument, Metadata
from . sender import ServiceSender
from . serialize import to_subgraph
class TextLoadSender(ServiceSender):
def __init__(self, pulsar_client):
class TextLoad(ServiceSender):
def __init__(self, pulsar_client, queue):
super(TextLoadSender, self).__init__(
super(TextLoad, self).__init__(
pulsar_client = pulsar_client,
request_queue=text_ingest_queue,
request_schema=TextDocument,
queue = queue,
schema = TextDocument,
)
def to_request(self, body):

View file

@ -0,0 +1,67 @@
import asyncio
import queue
import uuid
from ... schema import Triples
from ... base import Subscriber
from . serialize import serialize_triples
class TriplesExport:
def __init__(
self, ws, running, pulsar_client, queue, consumer, subscriber
):
self.ws = ws
self.running = running
self.pulsar_client = pulsar_client
self.queue = queue
self.consumer = consumer
self.subscriber = subscriber
async def destroy(self):
self.running.stop()
await self.ws.close()
async def receive(self, msg):
# Ignore incoming info from websocket
pass
async def run(self):
subs = Subscriber(
client = self.pulsar_client, topic = self.queue,
consumer_name = self.consumer, subscription = self.subscriber,
schema = Triples
)
await subs.start()
id = str(uuid.uuid4())
q = await subs.subscribe_all(id)
while self.running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await self.ws.send_json(serialize_triples(resp))
except TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
break
await subs.unsubscribe_all(id)
await subs.stop()
await self.ws.close()
self.running.stop()

View file

@ -0,0 +1,58 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import Triples
from ... base import Publisher
from . serialize import to_subgraph
class TriplesImport:
def __init__(
self, ws, running, pulsar_client, queue
):
self.ws = ws
self.running = running
self.publisher = Publisher(
pulsar_client, topic = queue, schema = Triples
)
async def destroy(self):
self.running.stop()
if self.ws:
await self.ws.close()
await self.publisher.stop()
async def receive(self, msg):
data = msg.json()
elt = Triples(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
triples=to_subgraph(data["triples"]),
)
await self.publisher.send(None, elt)
async def run(self):
while self.running.get():
await asyncio.sleep(0.5)
if self.ws:
await self.ws.close()
self.ws = None

View file

@ -1,13 +1,12 @@
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Triples
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
from . serialize import to_value, serialize_subgraph
class TriplesQueryRequestor(ServiceRequestor):
def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth,
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):

View file

@ -1,63 +0,0 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import DocumentEmbeddings, ChunkEmbeddings
from .. schema import document_embeddings_store_queue
from .. base import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph
class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
def __init__(
self, pulsar_client, auth, path="/api/v1/load/document-embeddings",
):
super(DocumentEmbeddingsLoadEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_client=pulsar_client
self.publisher = Publisher(
self.pulsar_client, document_embeddings_store_queue,
schema=DocumentEmbeddings
)
async def start(self):
await self.publisher.start()
async def listener(self, ws, running):
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.ERROR:
break
else:
data = msg.json()
elt = DocumentEmbeddings(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
chunks=[
ChunkEmbeddings(
chunk=de["chunk"].encode("utf-8"),
vectors=de["vectors"],
)
for de in data["chunks"]
],
)
await self.publisher.send(None, elt)
running.stop()

View file

@ -1,72 +0,0 @@
import asyncio
import queue
import uuid
from .. schema import DocumentEmbeddings
from .. schema import document_embeddings_store_queue
from .. base import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_document_embeddings
class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
def __init__(
self, pulsar_client, auth,
path="/api/v1/stream/document-embeddings"
):
super(DocumentEmbeddingsStreamEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_client=pulsar_client
self.subscriber = Subscriber(
self.pulsar_client, document_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=DocumentEmbeddings,
)
async def listener(self, ws, running):
worker = asyncio.create_task(
self.async_thread(ws, running)
)
await super(DocumentEmbeddingsStreamEndpoint, self).listener(
ws, running
)
await worker
async def start(self):
await self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
q = await self.subscriber.subscribe_all(id)
while running.get():
try:
resp = await asyncio.wait_for(q.get(), timeout=0.5)
await ws.send_json(serialize_document_embeddings(resp))
except TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
break
await self.subscriber.unsubscribe_all(id)
running.stop()

View file

@ -7,19 +7,19 @@ import logging
logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO)
class ServiceEndpoint:
class ConstantEndpoint:
def __init__(self, endpoint_path, auth, requestor):
def __init__(self, endpoint_path, auth, dispatcher):
self.path = endpoint_path
self.auth = auth
self.operation = "service"
self.requestor = requestor
self.dispatcher = dispatcher
async def start(self):
await self.requestor.start()
pass
def add_routes(self, app):
@ -52,7 +52,7 @@ class ServiceEndpoint:
async def responder(x, fin):
print(x)
resp = await self.requestor.process(data, responder)
resp = await self.dispatcher.process(data, responder)
return web.json_response(resp)

View file

@ -0,0 +1,67 @@
import asyncio
from aiohttp import web
from . constant_endpoint import ConstantEndpoint
from . variable_endpoint import VariableEndpoint
from . socket import SocketEndpoint
from . metrics import MetricsEndpoint
from .. dispatch.manager import DispatcherManager
class EndpointManager:
def __init__(
self, dispatcher_manager, auth, prometheus_url, timeout=600
):
self.dispatcher_manager = dispatcher_manager
self.timeout = timeout
self.services = {
}
self.endpoints = [
ConstantEndpoint(
endpoint_path = "/api/v1/librarian", auth = auth,
dispatcher = dispatcher_manager.dispatch_librarian(),
),
ConstantEndpoint(
endpoint_path = "/api/v1/config", auth = auth,
dispatcher = dispatcher_manager.dispatch_config(),
),
ConstantEndpoint(
endpoint_path = "/api/v1/flow", auth = auth,
dispatcher = dispatcher_manager.dispatch_flow(),
),
MetricsEndpoint(
endpoint_path = "/api/v1/metrics",
prometheus_url = prometheus_url,
auth = auth,
),
VariableEndpoint(
endpoint_path = "/api/v1/flow/{flow}/service/{kind}",
auth = auth,
dispatcher = dispatcher_manager.dispatch_service(),
),
SocketEndpoint(
endpoint_path = "/api/v1/flow/{flow}/import/{kind}",
auth = auth,
dispatcher = dispatcher_manager.dispatch_import()
),
SocketEndpoint(
endpoint_path = "/api/v1/flow/{flow}/export/{kind}",
auth = auth,
dispatcher = dispatcher_manager.dispatch_export()
),
]
def add_routes(self, app):
for ep in self.endpoints:
ep.add_routes(app)
async def start(self):
for ep in self.endpoints:
await ep.start()

View file

@ -0,0 +1,111 @@
import asyncio
from aiohttp import web, WSMsgType
import logging
from .. running import Running
logger = logging.getLogger("socket")
logger.setLevel(logging.INFO)
class SocketEndpoint:
def __init__(
self, endpoint_path, auth, dispatcher,
):
self.path = endpoint_path
self.auth = auth
self.operation = "socket"
self.dispatcher = dispatcher
async def worker(self, ws, dispatcher, running):
await dispatcher.run()
async def listener(self, ws, dispatcher, running):
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.TEXT:
await dispatcher.receive(msg)
continue
elif msg.type == WSMsgType.BINARY:
await dispatcher.receive(msg)
continue
else:
break
running.stop()
await ws.close()
async def handle(self, request):
try:
token = request.query['token']
except:
token = ""
if not self.auth.permitted(token, self.operation):
return web.HTTPUnauthorized()
# 50MB max message size
ws = web.WebSocketResponse(max_msg_size=52428800)
await ws.prepare(request)
try:
async with asyncio.TaskGroup() as tg:
running = Running()
dispatcher = await self.dispatcher(
ws, running, request.match_info
)
worker_task = tg.create_task(
self.worker(ws, dispatcher, running)
)
lsnr_task = tg.create_task(
self.listener(ws, dispatcher, running)
)
print("Created taskgroup, waiting...")
# Wait for threads to complete
print("Task group closed")
# Finally?
await dispatcher.destroy()
except ExceptionGroup as e:
print("Exception group:", flush=True)
for se in e.exceptions:
print(" Type:", type(se), flush=True)
print(f" Exception: {se}", flush=True)
except Exception as e:
print("Socket exception:", e, flush=True)
await ws.close()
return ws
async def start(self):
pass
async def stop(self):
self.running.stop()
def add_routes(self, app):
app.add_routes([
web.get(self.path, self.handle),
])

View file

@ -4,26 +4,25 @@ from aiohttp import web
import uuid
import logging
logger = logging.getLogger("flow-endpoint")
logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO)
class FlowEndpoint:
class VariableEndpoint:
def __init__(self, endpoint_path, auth, requestors):
def __init__(self, endpoint_path, auth, dispatcher):
self.path = endpoint_path
self.auth = auth
self.operation = "service"
self.requestors = requestors
self.dispatcher = dispatcher
async def start(self):
pass
def add_routes(self, app):
pass
app.add_routes([
web.post(self.path, self.handle),
])
@ -32,15 +31,6 @@ class FlowEndpoint:
print(request.path, "...")
flow_id = request.match_info['flow']
kind = request.match_info['kind']
k = (flow_id, kind)
if k not in self.requestors:
raise web.HTTPBadRequest()
requestor = self.requestors[k]
try:
ht = request.headers["Authorization"]
tokens = ht.split(" ", 2)
@ -62,7 +52,9 @@ class FlowEndpoint:
async def responder(x, fin):
print(x)
resp = await requestor.process(data, responder)
resp = await self.dispatcher.process(
data, responder, request.match_info
)
return web.json_response(resp)

View file

@ -1,64 +0,0 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import GraphEmbeddings, EntityEmbeddings
from .. schema import graph_embeddings_store_queue
from .. base import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph, to_value
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
def __init__(
self, pulsar_client, auth, path="/api/v1/load/graph-embeddings",
):
super(GraphEmbeddingsLoadEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_client=pulsar_client
self.publisher = Publisher(
self.pulsar_client, graph_embeddings_store_queue,
schema=GraphEmbeddings
)
async def start(self):
await self.publisher.start()
async def listener(self, ws, running):
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.ERROR:
break
else:
data = msg.json()
elt = GraphEmbeddings(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
entities=[
EntityEmbeddings(
entity=to_value(ent["entity"]),
vectors=ent["vectors"],
)
for ent in data["entities"]
]
)
await self.publisher.send(None, elt)
running.stop()

View file

@ -1,69 +0,0 @@
import asyncio
import queue
import uuid
from .. schema import GraphEmbeddings
from .. schema import graph_embeddings_store_queue
from .. base import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_graph_embeddings
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
def __init__(
self, pulsar_client, auth, path="/api/v1/stream/graph-embeddings"
):
super(GraphEmbeddingsStreamEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_client=pulsar_client
self.subscriber = Subscriber(
self.pulsar_client, graph_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=GraphEmbeddings
)
async def listener(self, ws, running):
worker = asyncio.create_task(
self.async_thread(ws, running)
)
await super(GraphEmbeddingsStreamEndpoint, self).listener(ws, running)
await worker
async def start(self):
await self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
q = await self.subscriber.subscribe_all(id)
while running.get():
try:
resp = await asyncio.wait_for(q.get, timeout=0.5)
await ws.send_json(serialize_graph_embeddings(resp))
except TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
break
await self.subscriber.unsubscribe_all(id)
running.stop()

View file

@ -3,63 +3,22 @@ API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus.
"""
module = "api-gateway"
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners
# FIXME: Connection errors in publishers / subscribers cause those threads
# to fail and are not failed or retried
import asyncio
import argparse
from aiohttp import web
import logging
import os
import base64
import uuid
import json
import pulsar
from prometheus_client import start_http_server
from .. log_level import LogLevel
from . serialize import to_subgraph
from . running import Running
from .. schema import ConfigPush, config_push_queue
from . text_completion import TextCompletionRequestor
from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
#from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . embeddings import EmbeddingsRequestor
#from . encyclopedia import EncyclopediaRequestor
from . agent import AgentRequestor
#from . dbpedia import DbpediaRequestor
#from . internet_search import InternetSearchRequestor
#from . librarian import LibrarianRequestor
from . config import ConfigRequestor
from . flow import FlowRequestor
#from . triples_stream import TriplesStreamEndpoint
#from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
#from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint
#from . triples_load import TriplesLoadEndpoint
#from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
#from . document_embeddings_load import DocumentEmbeddingsLoadEndpoint
from . mux import MuxEndpoint
#from . document_load import DocumentLoadSender
#from . text_load import TextLoadSender
from . metrics import MetricsEndpoint
from . endpoint import ServiceEndpoint
from . flow_endpoint import FlowEndpoint
from . auth import Authenticator
from .. base import Subscriber
from .. base import Consumer
from . config.receiver import ConfigReceiver
from . dispatch.manager import DispatcherManager
from . endpoint.manager import EndpointManager
import pulsar
from prometheus_client import start_http_server
logger = logging.getLogger("api")
logger.setLevel(logging.INFO)
@ -81,6 +40,7 @@ class Api:
self.pulsar_api_key = config.get(
"pulsar_api_key", default_pulsar_api_key
)
self.pulsar_listener = config.get("pulsar_listener", None)
if self.pulsar_api_key:
@ -108,278 +68,24 @@ class Api:
else:
self.auth = Authenticator(allow_all=True)
self.services = {
# "text-completion": TextCompletionRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "prompt": PromptRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "graph-rag": GraphRagRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "document-rag": DocumentRagRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "triples-query": TriplesQueryRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "graph-embeddings-query": GraphEmbeddingsQueryRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "embeddings": EmbeddingsRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "agent": AgentRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "librarian": LibrarianRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
(None, "config"): ConfigRequestor(
pulsar_client=self.pulsar_client, timeout=self.timeout,
self.config_receiver = ConfigReceiver(self.pulsar_client)
self.dispatcher_manager = DispatcherManager(
pulsar_client = self.pulsar_client,
config_receiver = self.config_receiver,
)
self.endpoint_manager = EndpointManager(
dispatcher_manager = self.dispatcher_manager,
auth = self.auth,
),
(None, "flow"): FlowRequestor(
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
# "encyclopedia": EncyclopediaRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "dbpedia": DbpediaRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "internet-search": InternetSearchRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "document-load": DocumentLoadSender(
# pulsar_client=self.pulsar_client,
# ),
# "text-load": TextLoadSender(
# pulsar_client=self.pulsar_client,
# ),
}
prometheus_url = self.prometheus_url,
timeout = self.timeout,
)
self.endpoints = [
# ServiceEndpoint(
# endpoint_path = "/api/v1/text-completion", auth=self.auth,
# requestor = self.services["text-completion"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/prompt", auth=self.auth,
# requestor = self.services["prompt"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/graph-rag", auth=self.auth,
# requestor = self.services["graph-rag"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/document-rag", auth=self.auth,
# requestor = self.services["document-rag"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/triples-query", auth=self.auth,
# requestor = self.services["triples-query"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/graph-embeddings-query",
# auth=self.auth,
# requestor = self.services["graph-embeddings-query"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/embeddings", auth=self.auth,
# requestor = self.services["embeddings"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/agent", auth=self.auth,
# requestor = self.services["agent"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/librarian", auth=self.auth,
# requestor = self.services["librarian"],
# ),
ServiceEndpoint(
endpoint_path = "/api/v1/config", auth=self.auth,
requestor = self.services[(None, "config")],
),
ServiceEndpoint(
endpoint_path = "/api/v1/flow", auth=self.auth,
requestor = self.services[(None, "flow")],
),
FlowEndpoint(
endpoint_path = "/api/v1/flow/{flow}/{kind}",
auth=self.auth,
requestors = self.services,
),
# ServiceEndpoint(
# endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
# requestor = self.services["encyclopedia"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/dbpedia", auth=self.auth,
# requestor = self.services["dbpedia"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/internet-search", auth=self.auth,
# requestor = self.services["internet-search"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/load/document", auth=self.auth,
# requestor = self.services["document-load"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/load/text", auth=self.auth,
# requestor = self.services["text-load"],
# ),
# TriplesStreamEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# GraphEmbeddingsStreamEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# DocumentEmbeddingsStreamEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# TriplesLoadEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# GraphEmbeddingsLoadEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# DocumentEmbeddingsLoadEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# MuxEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# services = self.services,
# ),
MetricsEndpoint(
endpoint_path = "/api/v1/metrics",
prometheus_url = self.prometheus_url,
auth = self.auth,
),
]
self.flows = {}
async def on_config(self, msg, proc, flow):
try:
v = msg.value()
print(f"Config version", v.version)
if "flows" in v.config:
flows = v.config["flows"]
wanted = list(flows.keys())
current = list(self.flows.keys())
for k in wanted:
if k not in current:
self.flows[k] = json.loads(flows[k])
await self.start_flow(k, self.flows[k])
for k in current:
if k not in wanted:
await self.stop_flow(k, self.flows[k])
del self.flows[k]
except Exception as e:
print(f"Exception: {e}", flush=True)
async def start_flow(self, id, flow):
print("Start flow", id)
intf = flow["interfaces"]
kinds = {
"agent": AgentRequestor,
"text-completion": TextCompletionRequestor,
"prompt": PromptRequestor,
"graph-rag": GraphRagRequestor,
"embeddings": EmbeddingsRequestor,
"graph-embeddings": GraphEmbeddingsQueryRequestor,
"triples-query": TriplesQueryRequestor,
}
for api_kind, requestor in kinds.items():
if api_kind in intf:
k = (id, api_kind)
if k in self.services:
await self.services[k].stop()
del self.services[k]
self.services[k] = requestor(
pulsar_client=self.pulsar_client, timeout=self.timeout,
request_queue = intf[api_kind]["request"],
response_queue = intf[api_kind]["response"],
consumer = f"api-gateway-{id}-{api_kind}-request",
subscriber = f"api-gateway-{id}-{api_kind}-request",
auth = self.auth,
)
await self.services[k].start()
async def stop_flow(self, id, flow):
print("Stop flow", id)
intf = flow["interfaces"]
svc_list = list(self.services.keys())
for k in svc_list:
kid, kkind = k
if id == kid:
await self.services[k].stop()
del self.services[k]
async def config_loader(self):
async with asyncio.TaskGroup() as tg:
id = str(uuid.uuid4())
self.config_cons = Consumer(
taskgroup = tg,
flow = None,
client = self.pulsar_client,
subscriber = f"gateway-{id}",
topic = config_push_queue,
schema = ConfigPush,
handler = self.on_config,
start_of_messages = True,
)
await self.config_cons.start()
print("Waiting...")
print("Config consumer done. :/")
async def app_factory(self):
self.app = web.Application(
@ -387,7 +93,8 @@ class Api:
client_max_size=256 * 1024 * 1024
)
asyncio.create_task(self.config_loader())
await self.config_receiver.start()
for ep in self.endpoints:
ep.add_routes(self.app)
@ -395,6 +102,9 @@ class Api:
for ep in self.endpoints:
await ep.start()
self.endpoint_manager.add_routes(self.app)
await self.endpoint_manager.start()
return self.app
def run(self):

View file

@ -1,72 +0,0 @@
import asyncio
from aiohttp import web, WSMsgType
import logging
from . running import Running
logger = logging.getLogger("socket")
logger.setLevel(logging.INFO)
class SocketEndpoint:
def __init__(
self, endpoint_path, auth,
):
self.path = endpoint_path
self.auth = auth
self.operation = "socket"
async def listener(self, ws, running):
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.TEXT:
# Ignore incoming message
continue
elif msg.type == WSMsgType.BINARY:
# Ignore incoming message
continue
else:
break
running.stop()
async def handle(self, request):
try:
token = request.query['token']
except:
token = ""
if not self.auth.permitted(token, self.operation):
return web.HTTPUnauthorized()
running = Running()
# 50MB max message size
ws = web.WebSocketResponse(max_msg_size=52428800)
await ws.prepare(request)
try:
await self.listener(ws, running)
except Exception as e:
print("Socket exception:", e, flush=True)
running.stop()
await ws.close()
return ws
async def start(self):
pass
def add_routes(self, app):
app.add_routes([
web.get(self.path, self.handle),
])

View file

@ -1,56 +0,0 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import Triples
from .. schema import triples_store_queue
from .. base import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph
class TriplesLoadEndpoint(SocketEndpoint):
def __init__(self, pulsar_client, auth, path="/api/v1/load/triples"):
super(TriplesLoadEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_client=pulsar_client
self.publisher = Publisher(
self.pulsar_client, triples_store_queue,
schema=Triples
)
async def start(self):
await self.publisher.start()
async def listener(self, ws, running):
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.ERROR:
break
else:
data = msg.json()
elt = Triples(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
triples=to_subgraph(data["triples"]),
)
await self.publisher.send(None, elt)
running.stop()

View file

@ -1,67 +0,0 @@
import asyncio
import queue
import uuid
from .. schema import Triples
from .. schema import triples_store_queue
from .. base import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_triples
class TriplesStreamEndpoint(SocketEndpoint):
def __init__(self, pulsar_client, auth, path="/api/v1/stream/triples"):
super(TriplesStreamEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_client=pulsar_client
self.subscriber = Subscriber(
self.pulsar_client, triples_store_queue,
"api-gateway", "api-gateway",
schema=Triples
)
async def listener(self, ws, running):
worker = asyncio.create_task(
self.async_thread(ws, running)
)
await super(TriplesStreamEndpoint, self).listener(ws, running)
await worker
async def start(self):
await self.subscriber.start()
async def async_thread(self, ws, running):
id = str(uuid.uuid4())
q = self.subscriber.subscribe_all(id)
while running.get():
try:
resp = await asyncio.to_thread(q.get, timeout=0.5)
await ws.send_json(serialize_triples(resp))
except TimeoutError:
continue
except queue.Empty:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
break
self.subscriber.unsubscribe_all(id)
running.stop()