Flow API - update gateway (#357)

* Altered API to incorporate Flow IDs, refactored for dynamic start/stop of flows
* Gateway: Split endpoint / dispatcher for maintainability
This commit is contained in:
cybermaggedon 2025-05-02 21:11:50 +01:00 committed by GitHub
parent 450f664b1b
commit a70ae9793a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
52 changed files with 1206 additions and 907 deletions

33
test-api/test-llm2-api Executable file
View file

@ -0,0 +1,33 @@
#!/usr/bin/env python3

"""
Smoke-test for the gateway text-completion endpoint.

Posts a trivial arithmetic prompt to a locally-running API gateway and
prints the response; exits non-zero if the service reports an error.
"""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Renamed from "input", which shadowed the builtin of the same name.
request_body = {
    "system": "",
    "prompt": "Add 2 and 3"
}

resp = requests.post(
    f"{url}text-completion",
    json=request_body,
)

# Transport-level failure: anything other than HTTP 200
if resp.status_code != 200:
    raise RuntimeError(f"Status code: {resp.status_code}")

body = resp.json()

# Service-level failure: the gateway reports errors in an "error" field
if "error" in body:
    print(f"Error: {body['error']}")
    sys.exit(1)

print(body["response"])

############################################################################

View file

@ -5,7 +5,7 @@ import json
import sys import sys
import base64 import base64
url = "http://localhost:8088/api/v1/" url = "http://localhost:8088/api/v1/flow/0000/document-load"
############################################################################ ############################################################################
@ -88,10 +88,7 @@ input = {
} }
resp = requests.post( resp = requests.post(url, json=input)
f"{url}load/document",
json=input,
)
resp = resp.json() resp = resp.json()

View file

@ -5,7 +5,7 @@ import json
import sys import sys
import base64 import base64
url = "http://localhost:8088/api/v1/" url = "http://localhost:8088/api/v1/flow/0000/service/text-load"
############################################################################ ############################################################################
@ -85,10 +85,7 @@ input = {
} }
resp = requests.post( resp = requests.post(url, json=input)
f"{url}load/text",
json=input,
)
resp = resp.json() resp = resp.json()

View file

@ -9,7 +9,7 @@ from trustgraph.schema import Document, Metadata
client = pulsar.Client("pulsar://localhost:6650", listener_name="localhost") client = pulsar.Client("pulsar://localhost:6650", listener_name="localhost")
prod = client.create_producer( prod = client.create_producer(
topic="persistent://tg/flow/document-load:0002", topic="persistent://tg/flow/document-load:0000",
schema=JsonSchema(Document), schema=JsonSchema(Document),
chunking_enabled=True, chunking_enabled=True,
) )

View file

@ -14,7 +14,7 @@ prod = client.create_producer(
chunking_enabled=True, chunking_enabled=True,
) )
path = "docs/README.cats" path = "../trustgraph/docs/README.cats"
with open(path, "r") as f: with open(path, "r") as f:
# blob = base64.b64encode(f.read()).decode("utf-8") # blob = base64.b64encode(f.read()).decode("utf-8")

View file

@ -15,17 +15,22 @@ class Publisher:
self.q = asyncio.Queue(maxsize=max_size) self.q = asyncio.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled self.chunking_enabled = chunking_enabled
self.running = True self.running = True
self.task = None
async def start(self): async def start(self):
self.task = asyncio.create_task(self.run()) self.task = asyncio.create_task(self.run())
async def stop(self): async def stop(self):
self.running = False self.running = False
await self.task
if self.task:
await self.task
async def join(self): async def join(self):
await self.stop() await self.stop()
await self.task
if self.task:
await self.task
async def run(self): async def run(self):

View file

@ -19,6 +19,7 @@ class Subscriber:
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.running = True self.running = True
self.metrics = metrics self.metrics = metrics
self.task = None
def __del__(self): def __del__(self):
self.running = False self.running = False
@ -28,11 +29,15 @@ class Subscriber:
async def stop(self): async def stop(self):
self.running = False self.running = False
await self.task
if self.task:
await self.task
async def join(self): async def join(self):
await self.stop() await self.stop()
await self.task
if self.task:
await self.task
async def run(self): async def run(self):
@ -45,6 +50,8 @@ class Subscriber:
try: try:
# FIXME: Create consumer in start method so we know
# it is definitely running when start completes
consumer = self.client.subscribe( consumer = self.client.subscribe(
topic = self.topic, topic = self.topic,
subscription_name = self.subscription, subscription_name = self.subscription,

View file

@ -1,7 +1,6 @@
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
from .. schema import EmbeddingsRequest, EmbeddingsResponse from .. schema import EmbeddingsRequest, EmbeddingsResponse
from .. schema import embeddings_request_queue, embeddings_response_queue
from . base import BaseClient from . base import BaseClient
import _pulsar import _pulsar
@ -23,12 +22,6 @@ class EmbeddingsClient(BaseClient):
pulsar_api_key=None, pulsar_api_key=None,
): ):
if input_queue == None:
input_queue=embeddings_request_queue
if output_queue == None:
output_queue=embeddings_response_queue
super(EmbeddingsClient, self).__init__( super(EmbeddingsClient, self).__init__(
log_level=log_level, log_level=log_level,
subscriber=subscriber, subscriber=subscriber,
@ -43,4 +36,3 @@ class EmbeddingsClient(BaseClient):
def request(self, text, timeout=300): def request(self, text, timeout=300):
return self.call(text=text, timeout=timeout).vectors return self.call(text=text, timeout=timeout).vectors

View file

@ -28,7 +28,7 @@ async def load_ge(running, queue, url):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.ws_connect(f"{url}load/graph-embeddings") as ws: async with session.ws_connect(url) as ws:
while running.get(): while running.get():
@ -73,7 +73,7 @@ async def load_triples(running, queue, url):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.ws_connect(f"{url}load/triples") as ws: async with session.ws_connect(url) as ws:
while running.get(): while running.get():
@ -200,6 +200,9 @@ async def run(running, **args):
ge_q = asyncio.Queue(maxsize=10) ge_q = asyncio.Queue(maxsize=10)
t_q = asyncio.Queue(maxsize=10) t_q = asyncio.Queue(maxsize=10)
flow_id = args["flow_id"]
url = args["url"]
load_task = asyncio.create_task( load_task = asyncio.create_task(
loader( loader(
running=running, running=running,
@ -212,15 +215,17 @@ async def run(running, **args):
ge_task = asyncio.create_task( ge_task = asyncio.create_task(
load_ge( load_ge(
running=running, running = running,
queue=ge_q, url=args["url"] + "api/v1/" queue = ge_q,
url = f"{url}api/v1/flow/{flow_id}/import/graph-embeddings"
) )
) )
triples_task = asyncio.create_task( triples_task = asyncio.create_task(
load_triples( load_triples(
running=running, running = running,
queue=t_q, url=args["url"] + "api/v1/" queue = t_q,
url = f"{url}api/v1/flow/{flow_id}/import/triples"
) )
) )
@ -258,6 +263,12 @@ async def main(running):
help=f'Output file' help=f'Output file'
) )
parser.add_argument(
'-f', '--flow-id',
default="0000",
help=f'Flow ID (default: 0000)'
)
parser.add_argument( parser.add_argument(
'--format', '--format',
default="msgpack", default="msgpack",

View file

@ -27,7 +27,7 @@ async def fetch_ge(running, queue, user, collection, url):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.ws_connect(f"{url}stream/graph-embeddings") as ws: async with session.ws_connect(url) as ws:
while running.get(): while running.get():
@ -74,7 +74,7 @@ async def fetch_triples(running, queue, user, collection, url):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.ws_connect(f"{url}stream/triples") as ws: async with session.ws_connect(url) as ws:
while running.get(): while running.get():
@ -160,11 +160,14 @@ async def run(running, **args):
q = asyncio.Queue() q = asyncio.Queue()
flow_id = args["flow_id"]
url = args["url"]
ge_task = asyncio.create_task( ge_task = asyncio.create_task(
fetch_ge( fetch_ge(
running=running, running=running,
queue=q, user=args["user"], collection=args["collection"], queue=q, user=args["user"], collection=args["collection"],
url=args["url"] + "api/v1/" url = f"{url}api/v1/flow/{flow_id}/export/graph-embeddings"
) )
) )
@ -172,7 +175,7 @@ async def run(running, **args):
fetch_triples( fetch_triples(
running=running, queue=q, running=running, queue=q,
user=args["user"], collection=args["collection"], user=args["user"], collection=args["collection"],
url=args["url"] + "api/v1/" url = f"{url}api/v1/flow/{flow_id}/export/triples"
) )
) )
@ -224,6 +227,12 @@ async def main(running):
help=f'Output format (default: msgpack)', help=f'Output format (default: msgpack)',
) )
parser.add_argument(
'-f', '--flow-id',
default="0000",
help=f'Flow ID (default: 0000)'
)
parser.add_argument( parser.add_argument(
'--user', '--user',
help=f'User ID to filter on (default: no filter)' help=f'User ID to filter on (default: no filter)'

View file

@ -0,0 +1,121 @@
"""
API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus.
"""
module = "api-gateway"
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners
# FIXME: Connection errors in publishers / subscribers cause those threads
# to fail and are not failed or retried
import asyncio
import argparse
from aiohttp import web
import logging
import os
import base64
import uuid
import json
import pulsar
from prometheus_client import start_http_server
from ... schema import ConfigPush, config_push_queue
from ... base import Consumer
logger = logging.getLogger("config.receiver")
logger.setLevel(logging.INFO)
class ConfigReceiver:

    """
    Consumes the Pulsar config-push topic and tracks the set of active
    flows.  Registered handlers are notified whenever a flow appears in
    or disappears from the pushed configuration.
    """

    def __init__(self, pulsar_client):
        self.pulsar_client = pulsar_client
        self.flow_handlers = []   # objects with start_flow/stop_flow coroutines
        self.flows = {}           # flow id -> parsed flow definition

    def add_handler(self, h):
        # h gets start_flow(id, flow) / stop_flow(id, flow) notifications
        self.flow_handlers.append(h)

    async def on_config(self, msg, proc, flow):
        """Reconcile our flow table against a newly-pushed config message."""

        try:

            v = msg.value()
            print(f"Config version", v.version)

            if "flows" not in v.config:
                return

            flows = v.config["flows"]

            wanted = list(flows.keys())
            current = list(self.flows.keys())

            # Start flows which appeared in the config...
            for k in wanted:
                if k in current: continue
                self.flows[k] = json.loads(flows[k])
                await self.start_flow(k, self.flows[k])

            # ...and stop flows which vanished from it
            for k in current:
                if k in wanted: continue
                await self.stop_flow(k, self.flows[k])
                del self.flows[k]

        except Exception as e:
            print(f"Exception: {e}", flush=True)

    async def start_flow(self, id, flow):
        print("Start flow", id)
        # One failing handler must not prevent the others being notified
        for handler in self.flow_handlers:
            try:
                await handler.start_flow(id, flow)
            except Exception as e:
                print(f"Exception: {e}", flush=True)

    async def stop_flow(self, id, flow):
        print("Stop flow", id)
        for handler in self.flow_handlers:
            try:
                await handler.stop_flow(id, flow)
            except Exception as e:
                print(f"Exception: {e}", flush=True)

    async def config_loader(self):
        """Long-running task: subscribe to config pushes and dispatch them."""

        async with asyncio.TaskGroup() as tg:

            id = str(uuid.uuid4())

            # Read from the start of the topic so a fresh gateway sees the
            # current flow state, not just subsequent changes
            self.config_cons = Consumer(
                taskgroup = tg,
                flow = None,
                client = self.pulsar_client,
                subscriber = f"gateway-{id}",
                topic = config_push_queue,
                schema = ConfigPush,
                handler = self.on_config,
                start_of_messages = True,
            )

            await self.config_cons.start()

            print("Waiting...")

            print("Config consumer done. :/")

    async def start(self):
        # BUG FIX: the task reference was previously discarded.  asyncio
        # keeps only weak references to tasks, so an unreferenced task can
        # be garbage-collected while still running; retain it.
        self.config_task = asyncio.create_task(self.config_loader())

View file

@ -1,12 +1,11 @@
from .. schema import AgentRequest, AgentResponse from ... schema import AgentRequest, AgentResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
class AgentRequestor(ServiceRequestor): class AgentRequestor(ServiceRequestor):
def __init__( def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth, self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber, consumer, subscriber,
): ):

View file

@ -1,13 +1,12 @@
from .. schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue from ... schema import ConfigRequest, ConfigResponse, ConfigKey, ConfigValue
from .. schema import config_request_queue from ... schema import config_request_queue
from .. schema import config_response_queue from ... schema import config_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
class ConfigRequestor(ServiceRequestor): class ConfigRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout, auth): def __init__(self, pulsar_client, timeout=120):
super(ConfigRequestor, self).__init__( super(ConfigRequestor, self).__init__(
pulsar_client=pulsar_client, pulsar_client=pulsar_client,

View file

@ -0,0 +1,67 @@
import asyncio
import queue
import uuid
from ... schema import DocumentEmbeddings
from ... base import Subscriber
from . serialize import serialize_document_embeddings
class DocumentEmbeddingsExport:

    """
    Streams document-embeddings messages from a Pulsar topic out over a
    websocket until the running flag is cleared.
    """

    def __init__(
            self, ws, running, pulsar_client, queue, consumer, subscriber
    ):
        self.ws = ws
        self.running = running
        self.pulsar_client = pulsar_client
        self.queue = queue
        self.consumer = consumer
        self.subscriber = subscriber

    async def destroy(self):
        # Stop the pump loop and close the websocket
        self.running.stop()
        await self.ws.close()

    async def receive(self, msg):
        # Ignore incoming info from websocket
        pass

    async def run(self):
        """Subscribe, forward messages to the websocket, then clean up."""

        subs = Subscriber(
            client = self.pulsar_client, topic = self.queue,
            consumer_name = self.consumer, subscription = self.subscriber,
            schema = DocumentEmbeddings
        )

        await subs.start()

        id = str(uuid.uuid4())
        q = await subs.subscribe_all(id)

        while self.running.get():
            try:
                resp = await asyncio.wait_for(q.get(), timeout=0.5)
                await self.ws.send_json(serialize_document_embeddings(resp))
            # BUG FIX: asyncio.wait_for raises asyncio.TimeoutError, which
            # before Python 3.11 is NOT the builtin TimeoutError the
            # original caught — idle timeouts fell into the broad except
            # below and broke the loop.  (The old queue.Empty branch was
            # dead: asyncio.Queue never raises it.)
            except asyncio.TimeoutError:
                # Periodic wake-up so we notice the running flag changing
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break

        await subs.unsubscribe_all(id)
        await subs.stop()

        await self.ws.close()

        self.running.stop()

View file

@ -0,0 +1,64 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import DocumentEmbeddings, ChunkEmbeddings
from ... base import Publisher
from . serialize import to_subgraph
class DocumentEmbeddingsImport:

    """
    Accepts document-embeddings messages arriving on a websocket and
    republishes them onto a Pulsar topic.
    """

    def __init__(
            self, ws, running, pulsar_client, queue
    ):
        self.ws = ws
        self.running = running
        self.publisher = Publisher(
            pulsar_client, topic = queue, schema = DocumentEmbeddings
        )

    async def destroy(self):
        self.running.stop()
        if self.ws:
            await self.ws.close()
        await self.publisher.stop()

    async def receive(self, msg):
        """Convert one JSON websocket message and publish it."""

        payload = msg.json()
        meta = payload["metadata"]

        # Chunk text travels as str on the wire; the schema wants bytes
        chunks = [
            ChunkEmbeddings(
                chunk=item["chunk"].encode("utf-8"),
                vectors=item["vectors"],
            )
            for item in payload["chunks"]
        ]

        record = DocumentEmbeddings(
            metadata=Metadata(
                id=meta["id"],
                metadata=to_subgraph(meta["metadata"]),
                user=meta["user"],
                collection=meta["collection"],
            ),
            chunks=chunks,
        )

        await self.publisher.send(None, record)

    async def run(self):

        # Idle until told to stop; receive() does the real work
        while self.running.get():
            await asyncio.sleep(0.5)

        if self.ws:
            await self.ws.close()

        self.ws = None

View file

@ -1,19 +1,18 @@
import base64 import base64
from .. schema import Document, Metadata from ... schema import Document, Metadata
from .. schema import document_ingest_queue
from . sender import ServiceSender from . sender import ServiceSender
from . serialize import to_subgraph from . serialize import to_subgraph
class DocumentLoadSender(ServiceSender): class DocumentLoad(ServiceSender):
def __init__(self, pulsar_client): def __init__(self, pulsar_client, queue):
super(DocumentLoadSender, self).__init__( super(DocumentLoad, self).__init__(
pulsar_client=pulsar_client, pulsar_client = pulsar_client,
request_queue=document_ingest_queue, queue = queue,
request_schema=Document, schema = Document,
) )
def to_request(self, body): def to_request(self, body):

View file

@ -1,20 +1,22 @@
from .. schema import DocumentRagQuery, DocumentRagResponse from ... schema import DocumentRagQuery, DocumentRagResponse
from .. schema import document_rag_request_queue
from .. schema import document_rag_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
class DocumentRagRequestor(ServiceRequestor): class DocumentRagRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout, auth): def __init__(
self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber,
):
super(DocumentRagRequestor, self).__init__( super(DocumentRagRequestor, self).__init__(
pulsar_client=pulsar_client, pulsar_client=pulsar_client,
request_queue=document_rag_request_queue, request_queue=request_queue,
response_queue=document_rag_response_queue, response_queue=response_queue,
request_schema=DocumentRagQuery, request_schema=DocumentRagQuery,
response_schema=DocumentRagResponse, response_schema=DocumentRagResponse,
subscription = subscriber,
consumer_name = consumer,
timeout=timeout, timeout=timeout,
) )

View file

@ -1,12 +1,11 @@
from .. schema import EmbeddingsRequest, EmbeddingsResponse from ... schema import EmbeddingsRequest, EmbeddingsResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
class EmbeddingsRequestor(ServiceRequestor): class EmbeddingsRequestor(ServiceRequestor):
def __init__( def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth, self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber, consumer, subscriber,
): ):

View file

@ -1,13 +1,12 @@
from .. schema import FlowRequest, FlowResponse, ConfigKey, ConfigValue from ... schema import FlowRequest, FlowResponse
from .. schema import flow_request_queue from ... schema import flow_request_queue
from .. schema import flow_response_queue from ... schema import flow_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
class FlowRequestor(ServiceRequestor): class FlowRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout, auth): def __init__(self, pulsar_client, timeout=120):
super(FlowRequestor, self).__init__( super(FlowRequestor, self).__init__(
pulsar_client=pulsar_client, pulsar_client=pulsar_client,

View file

@ -0,0 +1,67 @@
import asyncio
import queue
import uuid
from ... schema import GraphEmbeddings
from ... base import Subscriber
from . serialize import serialize_graph_embeddings
class GraphEmbeddingsExport:

    """
    Streams graph-embeddings messages from a Pulsar topic out over a
    websocket until the running flag is cleared.
    """

    def __init__(
            self, ws, running, pulsar_client, queue, consumer, subscriber
    ):
        self.ws = ws
        self.running = running
        self.pulsar_client = pulsar_client
        self.queue = queue
        self.consumer = consumer
        self.subscriber = subscriber

    async def destroy(self):
        # Stop the pump loop and close the websocket
        self.running.stop()
        await self.ws.close()

    async def receive(self, msg):
        # Ignore incoming info from websocket
        pass

    async def run(self):
        """Subscribe, forward messages to the websocket, then clean up."""

        subs = Subscriber(
            client = self.pulsar_client, topic = self.queue,
            consumer_name = self.consumer, subscription = self.subscriber,
            schema = GraphEmbeddings
        )

        await subs.start()

        id = str(uuid.uuid4())
        q = await subs.subscribe_all(id)

        while self.running.get():
            try:
                resp = await asyncio.wait_for(q.get(), timeout=0.5)
                await self.ws.send_json(serialize_graph_embeddings(resp))
            # BUG FIX: asyncio.wait_for raises asyncio.TimeoutError, which
            # before Python 3.11 is NOT the builtin TimeoutError the
            # original caught — idle timeouts fell into the broad except
            # below and broke the loop.  (The old queue.Empty branch was
            # dead: asyncio.Queue never raises it.)
            except asyncio.TimeoutError:
                # Periodic wake-up so we notice the running flag changing
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break

        await subs.unsubscribe_all(id)
        await subs.stop()

        await self.ws.close()

        self.running.stop()

View file

@ -0,0 +1,64 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import GraphEmbeddings, EntityEmbeddings
from ... base import Publisher
from . serialize import to_subgraph, to_value
class GraphEmbeddingsImport:

    """
    Accepts graph-embeddings messages arriving on a websocket and
    republishes them onto a Pulsar topic.
    """

    def __init__(
            self, ws, running, pulsar_client, queue
    ):
        self.ws = ws
        self.running = running
        self.publisher = Publisher(
            pulsar_client, topic = queue, schema = GraphEmbeddings
        )

    async def destroy(self):
        self.running.stop()
        if self.ws:
            await self.ws.close()
        await self.publisher.stop()

    async def receive(self, msg):
        """Convert one JSON websocket message and publish it."""

        payload = msg.json()
        meta = payload["metadata"]

        # Each entity entry carries a value reference plus its vectors
        entities = [
            EntityEmbeddings(
                entity=to_value(item["entity"]),
                vectors=item["vectors"],
            )
            for item in payload["entities"]
        ]

        record = GraphEmbeddings(
            metadata=Metadata(
                id=meta["id"],
                metadata=to_subgraph(meta["metadata"]),
                user=meta["user"],
                collection=meta["collection"],
            ),
            entities=entities,
        )

        await self.publisher.send(None, record)

    async def run(self):

        # Idle until told to stop; receive() does the real work
        while self.running.get():
            await asyncio.sleep(0.5)

        if self.ws:
            await self.ws.close()

        self.ws = None

View file

@ -1,13 +1,12 @@
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse from ... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
from . serialize import serialize_value from . serialize import serialize_value
class GraphEmbeddingsQueryRequestor(ServiceRequestor): class GraphEmbeddingsQueryRequestor(ServiceRequestor):
def __init__( def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth, self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber, consumer, subscriber,
): ):

View file

@ -1,12 +1,11 @@
from .. schema import GraphRagQuery, GraphRagResponse from ... schema import GraphRagQuery, GraphRagResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
class GraphRagRequestor(ServiceRequestor): class GraphRagRequestor(ServiceRequestor):
def __init__( def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth, self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber, consumer, subscriber,
): ):

View file

@ -1,15 +1,14 @@
from .. schema import LibrarianRequest, LibrarianResponse, Triples from ... schema import LibrarianRequest, LibrarianResponse
from .. schema import librarian_request_queue from ... schema import librarian_request_queue
from .. schema import librarian_response_queue from ... schema import librarian_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
from . serialize import serialize_document_package, serialize_document_info from . serialize import serialize_document_package, serialize_document_info
from . serialize import to_document_package, to_document_info, to_criteria from . serialize import to_document_package, to_document_info, to_criteria
class LibrarianRequestor(ServiceRequestor): class LibrarianRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout, auth): def __init__(self, pulsar_client, timeout=120):
super(LibrarianRequestor, self).__init__( super(LibrarianRequestor, self).__init__(
pulsar_client=pulsar_client, pulsar_client=pulsar_client,
@ -22,20 +21,16 @@ class LibrarianRequestor(ServiceRequestor):
def to_request(self, body): def to_request(self, body):
print("TRR")
if "document" in body: if "document" in body:
dp = to_document_package(body["document"]) dp = to_document_package(body["document"])
else: else:
dp = None dp = None
print("GOT")
if "criteria" in body: if "criteria" in body:
criteria = to_criteria(body["criteria"]) criteria = to_criteria(body["criteria"])
else: else:
criteria = None criteria = None
print("ASLDKJ")
return LibrarianRequest( return LibrarianRequest(
operation = body.get("operation", None), operation = body.get("operation", None),
id = body.get("id", None), id = body.get("id", None),

View file

@ -0,0 +1,229 @@
import asyncio
import uuid
from . config import ConfigRequestor
from . flow import FlowRequestor
from . librarian import LibrarianRequestor
from . embeddings import EmbeddingsRequestor
from . agent import AgentRequestor
from . text_completion import TextCompletionRequestor
from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor
from . embeddings import EmbeddingsRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . prompt import PromptRequestor
from . text_load import TextLoad
from . document_load import DocumentLoad
from . triples_export import TriplesExport
from . graph_embeddings_export import GraphEmbeddingsExport
from . document_embeddings_export import DocumentEmbeddingsExport
from . triples_import import TriplesImport
from . graph_embeddings_import import GraphEmbeddingsImport
from . document_embeddings_import import DocumentEmbeddingsImport
# Request/response services: a JSON request goes onto the flow's request
# queue and a JSON response comes back on its response queue.
request_response_dispatchers = {
    "agent": AgentRequestor,
    "text-completion": TextCompletionRequestor,
    "prompt": PromptRequestor,
    "graph-rag": GraphRagRequestor,
    "document-rag": DocumentRagRequestor,
    "embeddings": EmbeddingsRequestor,
    "graph-embeddings": GraphEmbeddingsQueryRequestor,
    "triples-query": TriplesQueryRequestor,
}

# One-way loaders: the payload is published onto the flow's queue with no
# response expected.
sender_dispatchers = {
    "text-load": TextLoad,
    "document-load": DocumentLoad,
}

# Websocket export streams: subscribe to a flow store queue and push
# messages out to the connected client.
export_dispatchers = {
    "triples": TriplesExport,
    "graph-embeddings": GraphEmbeddingsExport,
    "document-embeddings": DocumentEmbeddingsExport,
}

# Websocket import streams: receive messages from the connected client and
# publish them onto a flow store queue.
import_dispatchers = {
    "triples": TriplesImport,
    "graph-embeddings": GraphEmbeddingsImport,
    "document-embeddings": DocumentEmbeddingsImport,
}
class DispatcherWrapper:

    """
    Binds a dispatcher implementation and a service name to a manager so
    callers can invoke process() without knowing about flows or caching.
    """

    def __init__(self, mgr, name, impl):
        self.mgr = mgr
        self.name = name
        self.impl = impl

    async def process(self, data, responder):
        # Delegate to the manager, which owns dispatcher construction
        # and caching for non-flow services.
        manager = self.mgr
        return await manager.process_impl(data, responder, self.name, self.impl)
class DispatcherManager:

    """
    Creates and caches per-(flow, service) dispatcher objects, wiring web
    requests and websocket import/export streams onto the Pulsar queues
    named in each flow's interface definitions.
    """

    def __init__(self, pulsar_client, config_receiver):
        self.pulsar_client = pulsar_client
        self.config_receiver = config_receiver
        # Register for flow start/stop notifications
        self.config_receiver.add_handler(self)
        self.flows = {}          # flow id -> flow definition
        self.dispatchers = {}    # (flow id or None, kind) -> dispatcher

    async def start_flow(self, id, flow):
        print("Start flow", id)
        self.flows[id] = flow

    async def stop_flow(self, id, flow):
        print("Stop flow", id)
        del self.flows[id]

    def dispatch_config(self):
        return DispatcherWrapper(self, "config", ConfigRequestor)

    def dispatch_flow(self):
        return DispatcherWrapper(self, "flow", FlowRequestor)

    def dispatch_librarian(self):
        return DispatcherWrapper(self, "librarian", LibrarianRequestor)

    async def process_impl(self, data, responder, name, impl):
        """Handle a non-flow (global) service request, caching under
        flow=None."""
        key = (None, name)
        if key in self.dispatchers:
            return await self.dispatchers[key].process(data, responder)
        dispatcher = impl(
            pulsar_client = self.pulsar_client
        )
        await dispatcher.start()
        self.dispatchers[key] = dispatcher
        return await dispatcher.process(data, responder)

    def dispatch_service(self):
        return self

    def dispatch_import(self):
        return self.invoke_import

    def dispatch_export(self):
        return self.invoke_export

    def _stream_queue(self, params, table):
        """Validate the flow/kind in params against a dispatcher table and
        return (kind, store-queue config).  Raises RuntimeError on invalid
        input.  Shared by invoke_import / invoke_export, which previously
        duplicated this logic."""
        flow = params.get("flow")
        kind = params.get("kind")
        if flow not in self.flows:
            raise RuntimeError("Invalid flow")
        if kind not in table:
            raise RuntimeError("Invalid kind")
        intf_defs = self.flows[flow]["interfaces"]
        if kind not in intf_defs:
            raise RuntimeError("This kind not supported by flow")
        # FIXME: The -store bit, does it make sense?
        return kind, intf_defs[kind + "-store"]

    async def invoke_import(self, ws, running, params):
        kind, qconfig = self._stream_queue(params, import_dispatchers)
        dispatcher = import_dispatchers[kind](
            pulsar_client = self.pulsar_client,
            ws = ws,
            running = running,
            queue = qconfig,
        )
        return dispatcher

    async def invoke_export(self, ws, running, params):
        kind, qconfig = self._stream_queue(params, export_dispatchers)
        # Unique names so concurrent exports don't share a subscription
        id = str(uuid.uuid4())
        dispatcher = export_dispatchers[kind](
            pulsar_client = self.pulsar_client,
            ws = ws,
            running = running,
            queue = qconfig,
            consumer = f"api-gateway-{id}",
            subscriber = f"api-gateway-{id}",
        )
        return dispatcher

    async def process(self, data, responder, params):
        """Handle a per-flow service request, creating and caching the
        dispatcher on first use."""
        flow = params.get("flow")
        kind = params.get("kind")
        if flow not in self.flows:
            raise RuntimeError("Invalid flow")
        key = (flow, kind)
        if key in self.dispatchers:
            return await self.dispatchers[key].process(data, responder)
        intf_defs = self.flows[flow]["interfaces"]
        if kind not in intf_defs:
            raise RuntimeError("This kind not supported by flow")
        qconfig = intf_defs[kind]
        if kind in request_response_dispatchers:
            dispatcher = request_response_dispatchers[kind](
                pulsar_client = self.pulsar_client,
                request_queue = qconfig["request"],
                response_queue = qconfig["response"],
                timeout = 120,
                consumer = f"api-gateway-{flow}-{kind}-request",
                subscriber = f"api-gateway-{flow}-{kind}-request",
            )
        elif kind in sender_dispatchers:
            dispatcher = sender_dispatchers[kind](
                pulsar_client = self.pulsar_client,
                queue = qconfig,
            )
        else:
            raise RuntimeError("Invalid kind")
        await dispatcher.start()
        self.dispatchers[key] = dispatcher
        return await dispatcher.process(data, responder)

View file

@ -1,14 +1,13 @@
import json import json
from .. schema import PromptRequest, PromptResponse from ... schema import PromptRequest, PromptResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
class PromptRequestor(ServiceRequestor): class PromptRequestor(ServiceRequestor):
def __init__( def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth, self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber, consumer, subscriber,
): ):

View file

@ -3,8 +3,8 @@ import asyncio
import uuid import uuid
import logging import logging
from .. base import Publisher from ... base import Publisher
from .. base import Subscriber from ... base import Subscriber
logger = logging.getLogger("requestor") logger = logging.getLogger("requestor")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -33,13 +33,17 @@ class ServiceRequestor:
self.timeout = timeout self.timeout = timeout
self.running = True
async def start(self): async def start(self):
await self.pub.start() self.running = True
await self.sub.start() await self.sub.start()
await self.pub.start()
async def stop(self): async def stop(self):
await self.pub.stop() await self.pub.stop()
await self.sub.stop() await self.sub.stop()
self.running = False
def to_request(self, request): def to_request(self, request):
raise RuntimeError("Not defined") raise RuntimeError("Not defined")
@ -57,13 +61,14 @@ class ServiceRequestor:
await self.pub.send(id, self.to_request(request)) await self.pub.send(id, self.to_request(request))
while True: while self.running:
try: try:
resp = await asyncio.wait_for( resp = await asyncio.wait_for(
q.get(), timeout=self.timeout q.get(), timeout=self.timeout
) )
except Exception as e: except Exception as e:
print("Exception", e)
raise RuntimeError("Timeout") raise RuntimeError("Timeout")
if resp.error: if resp.error:

View file

@ -5,7 +5,7 @@ import asyncio
import uuid import uuid
import logging import logging
from .. base import Publisher from ... base import Publisher
logger = logging.getLogger("sender") logger = logging.getLogger("sender")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -15,18 +15,20 @@ class ServiceSender:
def __init__( def __init__(
self, self,
pulsar_client, pulsar_client,
request_queue, request_schema, queue, schema,
): ):
self.pub = Publisher( self.pub = Publisher(
pulsar_client, request_queue, pulsar_client, queue,
schema=request_schema, schema=schema,
) )
async def start(self): async def start(self):
await self.pub.start() await self.pub.start()
async def stop(self):
await self.pub.stop()
def to_request(self, request): def to_request(self, request):
raise RuntimeError("Not defined") raise RuntimeError("Not defined")
@ -39,6 +41,8 @@ class ServiceSender:
if responder: if responder:
await responder({}, True) await responder({}, True)
return {}
except Exception as e: except Exception as e:
logging.error(f"Exception: {e}") logging.error(f"Exception: {e}")

View file

@ -1,7 +1,7 @@
import base64 import base64
from .. schema import Value, Triple, DocumentPackage, DocumentInfo from ... schema import Value, Triple, DocumentPackage, DocumentInfo
def to_value(x): def to_value(x):
return Value(value=x["v"], is_uri=x["e"]) return Value(value=x["v"], is_uri=x["e"])

View file

@ -0,0 +1,99 @@
import asyncio
import uuid
import logging
from ... base import Publisher
from ... base import Subscriber
logger = logging.getLogger("requestor")
logger.setLevel(logging.INFO)
class ServiceRequestor:

    """
    Bridges a Pulsar subscription to an async handler: subscribes to
    `queue`, decodes each message via `from_inbound`, and forwards the
    result to `handler` until stopped or the stream signals completion.
    """

    def __init__(
            self,
            pulsar_client,
            queue, schema,
            handler,
            subscription="api-gateway", consumer_name="api-gateway",
            timeout=600,
    ):

        # Subscriber delivering messages of `schema` from `queue`
        self.sub = Subscriber(
            pulsar_client, queue,
            subscription, consumer_name,
            schema
        )

        self.timeout = timeout      # seconds to wait for each message
        self.running = True
        self.receiver = handler     # async callable(response, final)

    async def start(self):
        """Start the subscriber and launch the background stream task."""
        # Set the flag before the task starts so stream() enters its loop.
        self.running = True
        await self.sub.start()
        # FIX: removed stray `sub.start()` call, which raised NameError
        # (`sub` is not defined in this scope).
        self.streamer = asyncio.create_task(self.stream())

    async def stop(self):
        """Signal the stream loop to finish and stop the subscriber."""
        # Clear the flag first so stream() can exit its loop promptly.
        self.running = False
        await self.sub.stop()

    def from_inbound(self, response):
        """Translate a raw queue message; subclasses must override."""
        raise RuntimeError("Not defined")

    async def stream(self):
        """Consume messages and forward them to the receiver until done."""

        id = str(uuid.uuid4())

        try:

            q = await self.sub.subscribe(id)

            while self.running:

                try:
                    resp = await asyncio.wait_for(
                        q.get(), timeout=self.timeout
                    )
                except Exception as e:
                    raise RuntimeError("Timeout")

                if resp.error:
                    err = { "error": {
                        "type": resp.error.type,
                        "message": resp.error.message,
                    } }
                    # Errors are reported to the receiver but do not
                    # terminate the stream.
                    fin = False
                    await self.receiver(err, fin)
                else:
                    resp, fin = self.from_inbound(resp)
                    await self.receiver(resp, fin)
                    if fin: break

        except Exception as e:

            logger.error(f"Exception: {e}")

            err = { "error": {
                "type": "gateway-error",
                "message": str(e),
            } }

            # FIX: previously called `responder`, a name not defined in
            # this scope; the handler is held in self.receiver.
            await self.receiver(err, True)

            return err

        finally:
            await self.sub.unsubscribe(id)

View file

@ -1,12 +1,11 @@
from .. schema import TextCompletionRequest, TextCompletionResponse from ... schema import TextCompletionRequest, TextCompletionResponse
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
class TextCompletionRequestor(ServiceRequestor): class TextCompletionRequestor(ServiceRequestor):
def __init__( def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth, self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber, consumer, subscriber,
): ):

View file

@ -1,19 +1,18 @@
import base64 import base64
from .. schema import TextDocument, Metadata from ... schema import TextDocument, Metadata
from .. schema import text_ingest_queue
from . sender import ServiceSender from . sender import ServiceSender
from . serialize import to_subgraph from . serialize import to_subgraph
class TextLoadSender(ServiceSender): class TextLoad(ServiceSender):
def __init__(self, pulsar_client): def __init__(self, pulsar_client, queue):
super(TextLoadSender, self).__init__( super(TextLoad, self).__init__(
pulsar_client=pulsar_client, pulsar_client = pulsar_client,
request_queue=text_ingest_queue, queue = queue,
request_schema=TextDocument, schema = TextDocument,
) )
def to_request(self, body): def to_request(self, body):

View file

@ -0,0 +1,67 @@
import asyncio
import queue
import uuid
from ... schema import Triples
from ... base import Subscriber
from . serialize import serialize_triples
class TriplesExport:

    """
    Streams triples arriving on a Pulsar queue out over a websocket.
    Inbound websocket traffic is ignored.
    """

    def __init__(
            self, ws, running, pulsar_client, queue, consumer, subscriber
    ):
        # Hold connection state and Pulsar wiring for run()
        self.ws = ws
        self.running = running
        self.pulsar_client = pulsar_client
        self.queue = queue
        self.consumer = consumer
        self.subscriber = subscriber

    async def destroy(self):
        """Stop the export loop and close the websocket."""
        self.running.stop()
        await self.ws.close()

    async def receive(self, msg):
        """Export is one-way; discard anything the client sends."""
        pass

    async def run(self):
        """Relay triples to the websocket until stopped or on error."""

        sub = Subscriber(
            client = self.pulsar_client, topic = self.queue,
            consumer_name = self.consumer, subscription = self.subscriber,
            schema = Triples
        )

        await sub.start()

        # Fresh subscription id per connection
        sub_id = str(uuid.uuid4())
        inbound = await sub.subscribe_all(sub_id)

        while self.running.get():

            try:
                message = await asyncio.wait_for(inbound.get(), timeout=0.5)
                await self.ws.send_json(serialize_triples(message))
            except TimeoutError:
                # Poll timeout: loop back and re-check the running flag
                continue
            except queue.Empty:
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break

        # Tear down subscription and socket
        await sub.unsubscribe_all(sub_id)
        await sub.stop()
        await self.ws.close()
        self.running.stop()

View file

@ -0,0 +1,58 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import Triples
from ... base import Publisher
from . serialize import to_subgraph
class TriplesImport:

    """
    Accepts triples over a websocket and publishes them to a Pulsar queue.

    NOTE(review): self.publisher.start() is never called in this class —
    presumably the Publisher connects lazily or is started elsewhere;
    confirm against the Publisher implementation.
    """

    def __init__(
            self, ws, running, pulsar_client, queue
    ):
        self.ws = ws            # websocket the client sends triples on
        self.running = running  # shared run-state flag (stop()/get())

        self.publisher = Publisher(
            pulsar_client, topic = queue, schema = Triples
        )

    async def destroy(self):
        """Stop the session: flag shutdown, close socket, stop publisher."""
        self.running.stop()
        if self.ws:
            await self.ws.close()
        await self.publisher.stop()

    async def receive(self, msg):
        """Decode one websocket JSON message into a Triples and publish it."""
        data = msg.json()
        elt = Triples(
            metadata=Metadata(
                id=data["metadata"]["id"],
                metadata=to_subgraph(data["metadata"]["metadata"]),
                user=data["metadata"]["user"],
                collection=data["metadata"]["collection"],
            ),
            triples=to_subgraph(data["triples"]),
        )
        # No key is used for the message (first argument None)
        await self.publisher.send(None, elt)

    async def run(self):
        """Idle until the running flag clears, then close the socket."""
        while self.running.get():
            await asyncio.sleep(0.5)
        if self.ws:
            await self.ws.close()
            self.ws = None

View file

@ -1,13 +1,12 @@
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Triples from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor from . requestor import ServiceRequestor
from . serialize import to_value, serialize_subgraph from . serialize import to_value, serialize_subgraph
class TriplesQueryRequestor(ServiceRequestor): class TriplesQueryRequestor(ServiceRequestor):
def __init__( def __init__(
self, pulsar_client, request_queue, response_queue, timeout, auth, self, pulsar_client, request_queue, response_queue, timeout,
consumer, subscriber, consumer, subscriber,
): ):

View file

@ -1,63 +0,0 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import DocumentEmbeddings, ChunkEmbeddings
from .. schema import document_embeddings_store_queue
from .. base import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph
class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):

    """
    Websocket endpoint accepting document-embedding payloads and
    publishing them onto the document-embeddings store queue.
    """

    def __init__(
            self, pulsar_client, auth, path="/api/v1/load/document-embeddings",
    ):

        super(DocumentEmbeddingsLoadEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )

        self.pulsar_client=pulsar_client

        self.publisher = Publisher(
            self.pulsar_client, document_embeddings_store_queue,
            schema=DocumentEmbeddings
        )

    async def start(self):
        """Bring up the Pulsar publisher."""
        await self.publisher.start()

    async def listener(self, ws, running):
        """Consume websocket messages and publish each as embeddings."""

        async for msg in ws:

            # A websocket error ends the session
            if msg.type == WSMsgType.ERROR:
                break

            payload = msg.json()

            embeddings = DocumentEmbeddings(
                metadata=Metadata(
                    id=payload["metadata"]["id"],
                    metadata=to_subgraph(payload["metadata"]["metadata"]),
                    user=payload["metadata"]["user"],
                    collection=payload["metadata"]["collection"],
                ),
                chunks=[
                    ChunkEmbeddings(
                        chunk=item["chunk"].encode("utf-8"),
                        vectors=item["vectors"],
                    )
                    for item in payload["chunks"]
                ],
            )

            await self.publisher.send(None, embeddings)

        running.stop()

View file

@ -1,72 +0,0 @@
import asyncio
import queue
import uuid
from .. schema import DocumentEmbeddings
from .. schema import document_embeddings_store_queue
from .. base import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_document_embeddings
class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):

    """
    Websocket endpoint streaming document embeddings from the Pulsar
    store queue out to connected clients.
    """

    def __init__(
            self, pulsar_client, auth,
            path="/api/v1/stream/document-embeddings"
    ):
        super(DocumentEmbeddingsStreamEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )
        self.pulsar_client=pulsar_client
        self.subscriber = Subscriber(
            self.pulsar_client, document_embeddings_store_queue,
            "api-gateway", "api-gateway",
            schema=DocumentEmbeddings,
        )

    async def listener(self, ws, running):
        """Pump messages in a background task while the base-class
        listener watches the websocket for disconnect."""
        worker = asyncio.create_task(
            self.async_thread(ws, running)
        )
        await super(DocumentEmbeddingsStreamEndpoint, self).listener(
            ws, running
        )
        await worker

    async def start(self):
        """Bring up the Pulsar subscriber."""
        await self.subscriber.start()

    async def async_thread(self, ws, running):
        """Forward queue messages to the websocket until stopped."""
        # Each connection gets its own subscription id
        id = str(uuid.uuid4())
        q = await self.subscriber.subscribe_all(id)
        while running.get():
            try:
                # Short timeout so the running flag is re-checked often
                resp = await asyncio.wait_for(q.get(), timeout=0.5)
                await ws.send_json(serialize_document_embeddings(resp))
            except TimeoutError:
                continue
            except queue.Empty:
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break
        await self.subscriber.unsubscribe_all(id)
        running.stop()

View file

@ -7,19 +7,19 @@ import logging
logger = logging.getLogger("endpoint") logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class ServiceEndpoint: class ConstantEndpoint:
def __init__(self, endpoint_path, auth, requestor): def __init__(self, endpoint_path, auth, dispatcher):
self.path = endpoint_path self.path = endpoint_path
self.auth = auth self.auth = auth
self.operation = "service" self.operation = "service"
self.requestor = requestor self.dispatcher = dispatcher
async def start(self): async def start(self):
await self.requestor.start() pass
def add_routes(self, app): def add_routes(self, app):
@ -52,7 +52,7 @@ class ServiceEndpoint:
async def responder(x, fin): async def responder(x, fin):
print(x) print(x)
resp = await self.requestor.process(data, responder) resp = await self.dispatcher.process(data, responder)
return web.json_response(resp) return web.json_response(resp)

View file

@ -0,0 +1,67 @@
import asyncio
from aiohttp import web
from . constant_endpoint import ConstantEndpoint
from . variable_endpoint import VariableEndpoint
from . socket import SocketEndpoint
from . metrics import MetricsEndpoint
from .. dispatch.manager import DispatcherManager
class EndpointManager:

    """
    Owns the gateway's HTTP and websocket endpoints, wiring each path
    to the appropriate dispatcher from the dispatcher manager.
    """

    def __init__(
            self, dispatcher_manager, auth, prometheus_url, timeout=600
    ):

        self.dispatcher_manager = dispatcher_manager
        self.timeout = timeout

        self.services = {
        }

        dm = dispatcher_manager

        # Fixed-path JSON endpoints
        self.endpoints = [
            ConstantEndpoint(
                endpoint_path = path, auth = auth, dispatcher = dispatcher,
            )
            for path, dispatcher in (
                ("/api/v1/librarian", dm.dispatch_librarian()),
                ("/api/v1/config", dm.dispatch_config()),
                ("/api/v1/flow", dm.dispatch_flow()),
            )
        ]

        # Prometheus proxy
        self.endpoints.append(MetricsEndpoint(
            endpoint_path = "/api/v1/metrics",
            prometheus_url = prometheus_url,
            auth = auth,
        ))

        # Per-flow service invocation
        self.endpoints.append(VariableEndpoint(
            endpoint_path = "/api/v1/flow/{flow}/service/{kind}",
            auth = auth,
            dispatcher = dm.dispatch_service(),
        ))

        # Per-flow websocket import/export streams
        self.endpoints.append(SocketEndpoint(
            endpoint_path = "/api/v1/flow/{flow}/import/{kind}",
            auth = auth,
            dispatcher = dm.dispatch_import()
        ))

        self.endpoints.append(SocketEndpoint(
            endpoint_path = "/api/v1/flow/{flow}/export/{kind}",
            auth = auth,
            dispatcher = dm.dispatch_export()
        ))

    def add_routes(self, app):
        """Register every endpoint's routes on the aiohttp app."""
        for ep in self.endpoints:
            ep.add_routes(app)

    async def start(self):
        """Start each endpoint in turn."""
        for ep in self.endpoints:
            await ep.start()

View file

@ -0,0 +1,111 @@
import asyncio
from aiohttp import web, WSMsgType
import logging
from .. running import Running
logger = logging.getLogger("socket")
logger.setLevel(logging.INFO)
class SocketEndpoint:

    """
    Generic websocket endpoint. Each connection gets a dispatcher
    instance (created via the dispatcher factory) whose run() produces
    outbound traffic and whose receive() handles inbound messages.
    """

    def __init__(
            self, endpoint_path, auth, dispatcher,
    ):
        self.path = endpoint_path
        self.auth = auth
        self.operation = "socket"     # permission class checked by auth
        self.dispatcher = dispatcher  # async factory: (ws, running, match_info)

    async def worker(self, ws, dispatcher, running):
        """Run the dispatcher's outbound loop for this connection."""
        await dispatcher.run()

    async def listener(self, ws, dispatcher, running):
        """Feed inbound messages to the dispatcher until the socket ends."""

        async for msg in ws:
            # TEXT and BINARY both go to the dispatcher; anything else
            # (close, error, ...) ends the session.
            if msg.type in (WSMsgType.TEXT, WSMsgType.BINARY):
                await dispatcher.receive(msg)
                continue
            else:
                break

        running.stop()
        await ws.close()

    async def handle(self, request):
        """Authenticate, upgrade to websocket, run worker + listener."""

        # FIX: narrow bare `except:` to the KeyError a missing query
        # parameter raises.
        try:
            token = request.query['token']
        except KeyError:
            token = ""

        if not self.auth.permitted(token, self.operation):
            return web.HTTPUnauthorized()

        # 50MB max message size
        ws = web.WebSocketResponse(max_msg_size=52428800)

        await ws.prepare(request)

        try:

            async with asyncio.TaskGroup() as tg:

                # Per-connection run state shared by both tasks
                running = Running()

                dispatcher = await self.dispatcher(
                    ws, running, request.match_info
                )

                worker_task = tg.create_task(
                    self.worker(ws, dispatcher, running)
                )

                lsnr_task = tg.create_task(
                    self.listener(ws, dispatcher, running)
                )

                logger.debug("Created taskgroup, waiting...")

            # Wait for threads to complete
            logger.debug("Task group closed")

            # NOTE(review): destroy() only runs on the success path; a
            # failing task group skips it — confirm dispatchers tolerate
            # that, or move cleanup to a finally block.
            await dispatcher.destroy()

        except ExceptionGroup as e:
            print("Exception group:", flush=True)
            for se in e.exceptions:
                print(" Type:", type(se), flush=True)
                print(f" Exception: {se}", flush=True)
        except Exception as e:
            print("Socket exception:", e, flush=True)

        await ws.close()

        return ws

    async def start(self):
        pass

    async def stop(self):
        # FIX: previously called self.running.stop(), but `running` is a
        # per-connection local in handle() and is never set on the
        # endpoint, so stop() always raised AttributeError.
        pass

    def add_routes(self, app):
        app.add_routes([
            web.get(self.path, self.handle),
        ])

View file

@ -4,26 +4,25 @@ from aiohttp import web
import uuid import uuid
import logging import logging
logger = logging.getLogger("flow-endpoint") logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
class FlowEndpoint: class VariableEndpoint:
def __init__(self, endpoint_path, auth, requestors): def __init__(self, endpoint_path, auth, dispatcher):
self.path = endpoint_path self.path = endpoint_path
self.auth = auth self.auth = auth
self.operation = "service" self.operation = "service"
self.requestors = requestors self.dispatcher = dispatcher
async def start(self): async def start(self):
pass pass
def add_routes(self, app): def add_routes(self, app):
pass
app.add_routes([ app.add_routes([
web.post(self.path, self.handle), web.post(self.path, self.handle),
]) ])
@ -32,15 +31,6 @@ class FlowEndpoint:
print(request.path, "...") print(request.path, "...")
flow_id = request.match_info['flow']
kind = request.match_info['kind']
k = (flow_id, kind)
if k not in self.requestors:
raise web.HTTPBadRequest()
requestor = self.requestors[k]
try: try:
ht = request.headers["Authorization"] ht = request.headers["Authorization"]
tokens = ht.split(" ", 2) tokens = ht.split(" ", 2)
@ -62,7 +52,9 @@ class FlowEndpoint:
async def responder(x, fin): async def responder(x, fin):
print(x) print(x)
resp = await requestor.process(data, responder) resp = await self.dispatcher.process(
data, responder, request.match_info
)
return web.json_response(resp) return web.json_response(resp)

View file

@ -1,64 +0,0 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import GraphEmbeddings, EntityEmbeddings
from .. schema import graph_embeddings_store_queue
from .. base import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph, to_value
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):

    """
    Websocket endpoint accepting graph-embedding payloads and
    publishing them onto the graph-embeddings store queue.
    """

    def __init__(
            self, pulsar_client, auth, path="/api/v1/load/graph-embeddings",
    ):
        super(GraphEmbeddingsLoadEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )
        self.pulsar_client=pulsar_client
        self.publisher = Publisher(
            self.pulsar_client, graph_embeddings_store_queue,
            schema=GraphEmbeddings
        )

    async def start(self):
        """Bring up the Pulsar publisher."""
        await self.publisher.start()

    async def listener(self, ws, running):
        """Consume websocket messages and publish each as embeddings."""
        async for msg in ws:
            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break
            else:
                data = msg.json()
                elt = GraphEmbeddings(
                    metadata=Metadata(
                        id=data["metadata"]["id"],
                        metadata=to_subgraph(data["metadata"]["metadata"]),
                        user=data["metadata"]["user"],
                        collection=data["metadata"]["collection"],
                    ),
                    entities=[
                        EntityEmbeddings(
                            entity=to_value(ent["entity"]),
                            vectors=ent["vectors"],
                        )
                        for ent in data["entities"]
                    ]
                )
                # No key is used for the message (first argument None)
                await self.publisher.send(None, elt)
        running.stop()

View file

@ -1,69 +0,0 @@
import asyncio
import queue
import uuid
from .. schema import GraphEmbeddings
from .. schema import graph_embeddings_store_queue
from .. base import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_graph_embeddings
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):

    """
    Websocket endpoint streaming graph embeddings from the Pulsar
    store queue out to connected clients.
    """

    def __init__(
            self, pulsar_client, auth, path="/api/v1/stream/graph-embeddings"
    ):
        super(GraphEmbeddingsStreamEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )
        self.pulsar_client=pulsar_client
        self.subscriber = Subscriber(
            self.pulsar_client, graph_embeddings_store_queue,
            "api-gateway", "api-gateway",
            schema=GraphEmbeddings
        )

    async def listener(self, ws, running):
        """Pump messages in a background task while the base-class
        listener watches the websocket for disconnect."""
        worker = asyncio.create_task(
            self.async_thread(ws, running)
        )
        await super(GraphEmbeddingsStreamEndpoint, self).listener(ws, running)
        await worker

    async def start(self):
        """Bring up the Pulsar subscriber."""
        await self.subscriber.start()

    async def async_thread(self, ws, running):
        """Forward queue messages to the websocket until stopped."""
        id = str(uuid.uuid4())
        q = await self.subscriber.subscribe_all(id)
        while running.get():
            try:
                # FIX: was `asyncio.wait_for(q.get, ...)` — passing the
                # bound method instead of calling it, which raises
                # TypeError (not a coroutine). Sibling stream endpoints
                # use q.get().
                resp = await asyncio.wait_for(q.get(), timeout=0.5)
                await ws.send_json(serialize_graph_embeddings(resp))
            except TimeoutError:
                continue
            except queue.Empty:
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break
        await self.subscriber.unsubscribe_all(id)
        running.stop()

View file

@ -3,63 +3,22 @@ API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus. Pulsar bus.
""" """
module = "api-gateway"
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners
# FIXME: Connection errors in publishers / subscribers cause those threads
# to fail and are not failed or retried
import asyncio import asyncio
import argparse import argparse
from aiohttp import web from aiohttp import web
import logging import logging
import os import os
import base64
import uuid
import json
import pulsar
from prometheus_client import start_http_server
from .. log_level import LogLevel from .. log_level import LogLevel
from . serialize import to_subgraph
from . running import Running
from .. schema import ConfigPush, config_push_queue
from . text_completion import TextCompletionRequestor
from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
#from . document_rag import DocumentRagRequestor
from . triples_query import TriplesQueryRequestor
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
from . embeddings import EmbeddingsRequestor
#from . encyclopedia import EncyclopediaRequestor
from . agent import AgentRequestor
#from . dbpedia import DbpediaRequestor
#from . internet_search import InternetSearchRequestor
#from . librarian import LibrarianRequestor
from . config import ConfigRequestor
from . flow import FlowRequestor
#from . triples_stream import TriplesStreamEndpoint
#from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
#from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint
#from . triples_load import TriplesLoadEndpoint
#from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
#from . document_embeddings_load import DocumentEmbeddingsLoadEndpoint
from . mux import MuxEndpoint
#from . document_load import DocumentLoadSender
#from . text_load import TextLoadSender
from . metrics import MetricsEndpoint
from . endpoint import ServiceEndpoint
from . flow_endpoint import FlowEndpoint
from . auth import Authenticator from . auth import Authenticator
from .. base import Subscriber from . config.receiver import ConfigReceiver
from .. base import Consumer from . dispatch.manager import DispatcherManager
from . endpoint.manager import EndpointManager
import pulsar
from prometheus_client import start_http_server
logger = logging.getLogger("api") logger = logging.getLogger("api")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -81,6 +40,7 @@ class Api:
self.pulsar_api_key = config.get( self.pulsar_api_key = config.get(
"pulsar_api_key", default_pulsar_api_key "pulsar_api_key", default_pulsar_api_key
) )
self.pulsar_listener = config.get("pulsar_listener", None) self.pulsar_listener = config.get("pulsar_listener", None)
if self.pulsar_api_key: if self.pulsar_api_key:
@ -108,278 +68,24 @@ class Api:
else: else:
self.auth = Authenticator(allow_all=True) self.auth = Authenticator(allow_all=True)
self.services = { self.config_receiver = ConfigReceiver(self.pulsar_client)
# "text-completion": TextCompletionRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout, self.dispatcher_manager = DispatcherManager(
# auth = self.auth, pulsar_client = self.pulsar_client,
# ), config_receiver = self.config_receiver,
# "prompt": PromptRequestor( )
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth, self.endpoint_manager = EndpointManager(
# ), dispatcher_manager = self.dispatcher_manager,
# "graph-rag": GraphRagRequestor( auth = self.auth,
# pulsar_client=self.pulsar_client, timeout=self.timeout, prometheus_url = self.prometheus_url,
# auth = self.auth, timeout = self.timeout,
# ),
# "document-rag": DocumentRagRequestor( )
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "triples-query": TriplesQueryRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "graph-embeddings-query": GraphEmbeddingsQueryRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "embeddings": EmbeddingsRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "agent": AgentRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "librarian": LibrarianRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
(None, "config"): ConfigRequestor(
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
(None, "flow"): FlowRequestor(
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
# "encyclopedia": EncyclopediaRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "dbpedia": DbpediaRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "internet-search": InternetSearchRequestor(
# pulsar_client=self.pulsar_client, timeout=self.timeout,
# auth = self.auth,
# ),
# "document-load": DocumentLoadSender(
# pulsar_client=self.pulsar_client,
# ),
# "text-load": TextLoadSender(
# pulsar_client=self.pulsar_client,
# ),
}
self.endpoints = [ self.endpoints = [
# ServiceEndpoint(
# endpoint_path = "/api/v1/text-completion", auth=self.auth,
# requestor = self.services["text-completion"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/prompt", auth=self.auth,
# requestor = self.services["prompt"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/graph-rag", auth=self.auth,
# requestor = self.services["graph-rag"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/document-rag", auth=self.auth,
# requestor = self.services["document-rag"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/triples-query", auth=self.auth,
# requestor = self.services["triples-query"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/graph-embeddings-query",
# auth=self.auth,
# requestor = self.services["graph-embeddings-query"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/embeddings", auth=self.auth,
# requestor = self.services["embeddings"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/agent", auth=self.auth,
# requestor = self.services["agent"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/librarian", auth=self.auth,
# requestor = self.services["librarian"],
# ),
ServiceEndpoint(
endpoint_path = "/api/v1/config", auth=self.auth,
requestor = self.services[(None, "config")],
),
ServiceEndpoint(
endpoint_path = "/api/v1/flow", auth=self.auth,
requestor = self.services[(None, "flow")],
),
FlowEndpoint(
endpoint_path = "/api/v1/flow/{flow}/{kind}",
auth=self.auth,
requestors = self.services,
),
# ServiceEndpoint(
# endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
# requestor = self.services["encyclopedia"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/dbpedia", auth=self.auth,
# requestor = self.services["dbpedia"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/internet-search", auth=self.auth,
# requestor = self.services["internet-search"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/load/document", auth=self.auth,
# requestor = self.services["document-load"],
# ),
# ServiceEndpoint(
# endpoint_path = "/api/v1/load/text", auth=self.auth,
# requestor = self.services["text-load"],
# ),
# TriplesStreamEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# GraphEmbeddingsStreamEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# DocumentEmbeddingsStreamEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# TriplesLoadEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# GraphEmbeddingsLoadEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# DocumentEmbeddingsLoadEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# ),
# MuxEndpoint(
# pulsar_client=self.pulsar_client,
# auth = self.auth,
# services = self.services,
# ),
MetricsEndpoint(
endpoint_path = "/api/v1/metrics",
prometheus_url = self.prometheus_url,
auth = self.auth,
),
] ]
self.flows = {}
async def on_config(self, msg, proc, flow):
try:
v = msg.value()
print(f"Config version", v.version)
if "flows" in v.config:
flows = v.config["flows"]
wanted = list(flows.keys())
current = list(self.flows.keys())
for k in wanted:
if k not in current:
self.flows[k] = json.loads(flows[k])
await self.start_flow(k, self.flows[k])
for k in current:
if k not in wanted:
await self.stop_flow(k, self.flows[k])
del self.flows[k]
except Exception as e:
print(f"Exception: {e}", flush=True)
async def start_flow(self, id, flow):
print("Start flow", id)
intf = flow["interfaces"]
kinds = {
"agent": AgentRequestor,
"text-completion": TextCompletionRequestor,
"prompt": PromptRequestor,
"graph-rag": GraphRagRequestor,
"embeddings": EmbeddingsRequestor,
"graph-embeddings": GraphEmbeddingsQueryRequestor,
"triples-query": TriplesQueryRequestor,
}
for api_kind, requestor in kinds.items():
if api_kind in intf:
k = (id, api_kind)
if k in self.services:
await self.services[k].stop()
del self.services[k]
self.services[k] = requestor(
pulsar_client=self.pulsar_client, timeout=self.timeout,
request_queue = intf[api_kind]["request"],
response_queue = intf[api_kind]["response"],
consumer = f"api-gateway-{id}-{api_kind}-request",
subscriber = f"api-gateway-{id}-{api_kind}-request",
auth = self.auth,
)
await self.services[k].start()
async def stop_flow(self, id, flow):
print("Stop flow", id)
intf = flow["interfaces"]
svc_list = list(self.services.keys())
for k in svc_list:
kid, kkind = k
if id == kid:
await self.services[k].stop()
del self.services[k]
async def config_loader(self):
async with asyncio.TaskGroup() as tg:
id = str(uuid.uuid4())
self.config_cons = Consumer(
taskgroup = tg,
flow = None,
client = self.pulsar_client,
subscriber = f"gateway-{id}",
topic = config_push_queue,
schema = ConfigPush,
handler = self.on_config,
start_of_messages = True,
)
await self.config_cons.start()
print("Waiting...")
print("Config consumer done. :/")
async def app_factory(self): async def app_factory(self):
self.app = web.Application( self.app = web.Application(
@ -387,7 +93,8 @@ class Api:
client_max_size=256 * 1024 * 1024 client_max_size=256 * 1024 * 1024
) )
asyncio.create_task(self.config_loader()) await self.config_receiver.start()
for ep in self.endpoints: for ep in self.endpoints:
ep.add_routes(self.app) ep.add_routes(self.app)
@ -395,6 +102,9 @@ class Api:
for ep in self.endpoints: for ep in self.endpoints:
await ep.start() await ep.start()
self.endpoint_manager.add_routes(self.app)
await self.endpoint_manager.start()
return self.app return self.app
def run(self): def run(self):

View file

@ -1,72 +0,0 @@
import asyncio
from aiohttp import web, WSMsgType
import logging
from . running import Running
logger = logging.getLogger("socket")
logger.setLevel(logging.INFO)
class SocketEndpoint:
def __init__(
self, endpoint_path, auth,
):
self.path = endpoint_path
self.auth = auth
self.operation = "socket"
async def listener(self, ws, running):
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.TEXT:
# Ignore incoming message
continue
elif msg.type == WSMsgType.BINARY:
# Ignore incoming message
continue
else:
break
running.stop()
async def handle(self, request):
try:
token = request.query['token']
except:
token = ""
if not self.auth.permitted(token, self.operation):
return web.HTTPUnauthorized()
running = Running()
# 50MB max message size
ws = web.WebSocketResponse(max_msg_size=52428800)
await ws.prepare(request)
try:
await self.listener(ws, running)
except Exception as e:
print("Socket exception:", e, flush=True)
running.stop()
await ws.close()
return ws
async def start(self):
pass
def add_routes(self, app):
app.add_routes([
web.get(self.path, self.handle),
])

View file

@ -1,56 +0,0 @@
import asyncio
import uuid
from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import Triples
from .. schema import triples_store_queue
from .. base import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph
class TriplesLoadEndpoint(SocketEndpoint):

    """
    Websocket endpoint accepting triples payloads and publishing them
    onto the triples store queue.
    """

    def __init__(self, pulsar_client, auth, path="/api/v1/load/triples"):

        super(TriplesLoadEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )

        self.pulsar_client=pulsar_client

        self.publisher = Publisher(
            self.pulsar_client, triples_store_queue,
            schema=Triples
        )

    async def start(self):
        """Bring up the Pulsar publisher."""
        await self.publisher.start()

    async def listener(self, ws, running):
        """Consume websocket messages and publish each as triples."""

        async for msg in ws:

            # A websocket error ends the session
            if msg.type == WSMsgType.ERROR:
                break

            payload = msg.json()

            message = Triples(
                metadata=Metadata(
                    id=payload["metadata"]["id"],
                    metadata=to_subgraph(payload["metadata"]["metadata"]),
                    user=payload["metadata"]["user"],
                    collection=payload["metadata"]["collection"],
                ),
                triples=to_subgraph(payload["triples"]),
            )

            await self.publisher.send(None, message)

        running.stop()

View file

@ -1,67 +0,0 @@
import asyncio
import queue
import uuid
from .. schema import Triples
from .. schema import triples_store_queue
from .. base import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_triples
class TriplesStreamEndpoint(SocketEndpoint):

    """
    Websocket endpoint streaming triples from the Pulsar store queue
    out to connected clients.
    """

    def __init__(self, pulsar_client, auth, path="/api/v1/stream/triples"):
        super(TriplesStreamEndpoint, self).__init__(
            endpoint_path=path, auth=auth,
        )
        self.pulsar_client=pulsar_client
        self.subscriber = Subscriber(
            self.pulsar_client, triples_store_queue,
            "api-gateway", "api-gateway",
            schema=Triples
        )

    async def listener(self, ws, running):
        """Pump messages in a background task while the base-class
        listener watches the websocket for disconnect."""
        worker = asyncio.create_task(
            self.async_thread(ws, running)
        )
        await super(TriplesStreamEndpoint, self).listener(ws, running)
        await worker

    async def start(self):
        """Bring up the Pulsar subscriber."""
        await self.subscriber.start()

    async def async_thread(self, ws, running):
        """Forward queue messages to the websocket until stopped."""
        id = str(uuid.uuid4())
        # FIX: subscribe_all is a coroutine in the sibling stream
        # endpoints (`q = await self.subscriber.subscribe_all(id)`);
        # the missing awaits here left a coroutine object in `q` and
        # never unsubscribed.
        q = await self.subscriber.subscribe_all(id)
        while running.get():
            try:
                # FIX: was `asyncio.to_thread(q.get, timeout=0.5)` on an
                # async queue; use wait_for(q.get()) as the sibling
                # endpoints do.
                resp = await asyncio.wait_for(q.get(), timeout=0.5)
                await ws.send_json(serialize_triples(resp))
            except TimeoutError:
                continue
            except queue.Empty:
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break
        await self.subscriber.unsubscribe_all(id)
        running.stop()