mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-29 17:25:15 +02:00
Feature/flow api 3 (#358)
* Working mux socket * Change API to incorporate flow * Add Flow ID to all relevant CLIs, not completely implemented * Change tg-processor-state to use API gateway * Updated all CLIs * New tg-show-flow-state command * tg-show-flow-state shows classes too
This commit is contained in:
parent
a70ae9793a
commit
3b8b9ea866
23 changed files with 800 additions and 986 deletions
|
|
@ -6,10 +6,12 @@ from ... schema import config_response_queue
|
|||
from . requestor import ServiceRequestor
|
||||
|
||||
class ConfigRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout=120):
|
||||
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
|
||||
|
||||
super(ConfigRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
consumer_name = consumer,
|
||||
subscription = subscriber,
|
||||
request_queue=config_request_queue,
|
||||
response_queue=config_response_queue,
|
||||
request_schema=ConfigRequest,
|
||||
|
|
|
|||
|
|
@ -6,10 +6,12 @@ from ... schema import flow_response_queue
|
|||
from . requestor import ServiceRequestor
|
||||
|
||||
class FlowRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout=120):
|
||||
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
|
||||
|
||||
super(FlowRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
consumer_name = consumer,
|
||||
subscription = subscriber,
|
||||
request_queue=flow_request_queue,
|
||||
response_queue=flow_response_queue,
|
||||
request_schema=FlowRequest,
|
||||
|
|
|
|||
|
|
@ -8,10 +8,12 @@ from . serialize import serialize_document_package, serialize_document_info
|
|||
from . serialize import to_document_package, to_document_info, to_criteria
|
||||
|
||||
class LibrarianRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout=120):
|
||||
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
|
||||
|
||||
super(LibrarianRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
consumer_name = consumer,
|
||||
subscription = subscriber,
|
||||
request_queue=librarian_request_queue,
|
||||
response_queue=librarian_response_queue,
|
||||
request_schema=LibrarianRequest,
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ from . triples_import import TriplesImport
|
|||
from . graph_embeddings_import import GraphEmbeddingsImport
|
||||
from . document_embeddings_import import DocumentEmbeddingsImport
|
||||
|
||||
from . mux import Mux
|
||||
|
||||
request_response_dispatchers = {
|
||||
"agent": AgentRequestor,
|
||||
"text-completion": TextCompletionRequestor,
|
||||
|
|
@ -35,7 +37,13 @@ request_response_dispatchers = {
|
|||
"document-rag": DocumentRagRequestor,
|
||||
"embeddings": EmbeddingsRequestor,
|
||||
"graph-embeddings": GraphEmbeddingsQueryRequestor,
|
||||
"triples-query": TriplesQueryRequestor,
|
||||
"triples": TriplesQueryRequestor,
|
||||
}
|
||||
|
||||
global_dispatchers = {
|
||||
"config": ConfigRequestor,
|
||||
"flow": FlowRequestor,
|
||||
"librarian": LibrarianRequestor,
|
||||
}
|
||||
|
||||
sender_dispatchers = {
|
||||
|
|
@ -56,14 +64,10 @@ import_dispatchers = {
|
|||
}
|
||||
|
||||
class DispatcherWrapper:
|
||||
def __init__(self, mgr, name, impl):
|
||||
self.mgr = mgr
|
||||
self.name = name
|
||||
self.impl = impl
|
||||
async def process(self, data, responder):
|
||||
return await self.mgr.process_impl(
|
||||
data, responder, self.name, self.impl
|
||||
)
|
||||
def __init__(self, handler):
|
||||
self.handler = handler
|
||||
async def process(self, *args):
|
||||
return await self.handler(*args)
|
||||
|
||||
class DispatcherManager:
|
||||
|
||||
|
|
@ -85,24 +89,26 @@ class DispatcherManager:
|
|||
del self.flows[id]
|
||||
return
|
||||
|
||||
def dispatch_config(self):
|
||||
return DispatcherWrapper(self, "config", ConfigRequestor)
|
||||
def dispatch_global_service(self):
|
||||
return DispatcherWrapper(self.process_global_service)
|
||||
|
||||
def dispatch_flow(self):
|
||||
return DispatcherWrapper(self, "flow", FlowRequestor)
|
||||
async def process_global_service(self, data, responder, params):
|
||||
|
||||
def dispatch_librarian(self):
|
||||
return DispatcherWrapper(self, "librarian", LibrarianRequestor)
|
||||
kind = params.get("kind")
|
||||
return await self.invoke_global_service(data, responder, kind)
|
||||
|
||||
async def process_impl(self, data, responder, name, impl):
|
||||
async def invoke_global_service(self, data, responder, kind):
|
||||
|
||||
key = (None, name)
|
||||
key = (None, kind)
|
||||
|
||||
if key in self.dispatchers:
|
||||
return await self.dispatchers[key].process(data, responder)
|
||||
|
||||
dispatcher = impl(
|
||||
pulsar_client = self.pulsar_client
|
||||
dispatcher = global_dispatchers[kind](
|
||||
pulsar_client = self.pulsar_client,
|
||||
timeout = 120,
|
||||
consumer = f"api-gateway-{kind}-request",
|
||||
subscriber = f"api-gateway-{kind}-request",
|
||||
)
|
||||
|
||||
await dispatcher.start()
|
||||
|
|
@ -111,16 +117,19 @@ class DispatcherManager:
|
|||
|
||||
return await dispatcher.process(data, responder)
|
||||
|
||||
def dispatch_service(self):
|
||||
return self
|
||||
def dispatch_flow_import(self):
|
||||
return self.process_flow_import
|
||||
|
||||
def dispatch_import(self):
|
||||
return self.invoke_import
|
||||
def dispatch_flow_export(self):
|
||||
return self.process_flow_export
|
||||
|
||||
def dispatch_export(self):
|
||||
return self.invoke_export
|
||||
def dispatch_socket(self):
|
||||
return self.process_socket
|
||||
|
||||
async def invoke_import(self, ws, running, params):
|
||||
def dispatch_flow_service(self):
|
||||
return DispatcherWrapper(self.process_flow_service)
|
||||
|
||||
async def process_flow_import(self, ws, running, params):
|
||||
|
||||
flow = params.get("flow")
|
||||
kind = params.get("kind")
|
||||
|
|
@ -151,7 +160,7 @@ class DispatcherManager:
|
|||
|
||||
return dispatcher
|
||||
|
||||
async def invoke_export(self, ws, running, params):
|
||||
async def process_flow_export(self, ws, running, params):
|
||||
|
||||
flow = params.get("flow")
|
||||
kind = params.get("kind")
|
||||
|
|
@ -184,11 +193,21 @@ class DispatcherManager:
|
|||
|
||||
return dispatcher
|
||||
|
||||
async def process(self, data, responder, params):
|
||||
async def process_socket(self, ws, running, params):
|
||||
|
||||
dispatcher = Mux(self, ws, running)
|
||||
|
||||
return dispatcher
|
||||
|
||||
async def process_flow_service(self, data, responder, params):
|
||||
|
||||
flow = params.get("flow")
|
||||
kind = params.get("kind")
|
||||
|
||||
return await self.invoke_flow_service(data, responder, flow, kind)
|
||||
|
||||
async def invoke_flow_service(self, data, responder, flow, kind):
|
||||
|
||||
if flow not in self.flows:
|
||||
raise RuntimeError("Invalid flow")
|
||||
|
||||
|
|
|
|||
167
trustgraph-flow/trustgraph/gateway/dispatch/mux.py
Normal file
167
trustgraph-flow/trustgraph/gateway/dispatch/mux.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import uuid
|
||||
|
||||
MAX_OUTSTANDING_REQUESTS = 15
|
||||
WORKER_CLOSE_WAIT = 0.01
|
||||
START_REQUEST_WAIT = 0.1
|
||||
|
||||
# This buffers requests until task start, so short-lived
|
||||
MAX_QUEUE_SIZE = 10
|
||||
|
||||
class Mux:
|
||||
|
||||
def __init__(self, dispatcher_manager, ws, running):
|
||||
|
||||
self.dispatcher_manager = dispatcher_manager
|
||||
self.ws = ws
|
||||
self.running = running
|
||||
|
||||
self.q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
|
||||
|
||||
async def destroy(self):
|
||||
|
||||
self.running.stop()
|
||||
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
async def receive(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
data = msg.json()
|
||||
|
||||
if "request" not in data:
|
||||
raise RuntimeError("Bad message")
|
||||
|
||||
if "id" not in data:
|
||||
raise RuntimeError("Bad message")
|
||||
|
||||
await self.q.put((
|
||||
data["id"], data.get("flow"),
|
||||
data["service"],
|
||||
data["request"]
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
print("receive exception:", str(e), flush=True)
|
||||
await self.ws.send_json({"error": str(e)})
|
||||
|
||||
async def maybe_tidy_workers(self, workers):
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(workers[0]),
|
||||
WORKER_CLOSE_WAIT
|
||||
)
|
||||
|
||||
# worker[0] now stopped
|
||||
# FIXME: Delete reference???
|
||||
|
||||
workers.pop(0)
|
||||
|
||||
if len(workers) == 0:
|
||||
break
|
||||
|
||||
# Loop iterates to try the next worker
|
||||
|
||||
except TimeoutError:
|
||||
# worker[0] still running, move on
|
||||
break
|
||||
|
||||
async def start_request_task(self, ws, id, flow, svc, request, workers):
|
||||
|
||||
# Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS
|
||||
while len(workers) > MAX_OUTSTANDING_REQUESTS:
|
||||
|
||||
# Fixes deadlock
|
||||
# FIXME: Put it in its own loop
|
||||
await asyncio.sleep(START_REQUEST_WAIT)
|
||||
|
||||
await self.maybe_tidy_workers(workers)
|
||||
|
||||
async def responder(resp, fin):
|
||||
await self.ws.send_json({
|
||||
"id": id,
|
||||
"response": resp,
|
||||
"complete": fin,
|
||||
})
|
||||
|
||||
worker = asyncio.create_task(
|
||||
self.request_task(request, responder, flow, svc)
|
||||
)
|
||||
|
||||
workers.append(worker)
|
||||
|
||||
async def request_task(self, request, responder, flow, svc):
|
||||
|
||||
try:
|
||||
|
||||
if flow:
|
||||
|
||||
await self.dispatcher_manager.invoke_flow_service(
|
||||
request, responder, flow, svc
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
await self.dispatcher_manager.invoke_global_service(
|
||||
request, responder, svc
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await self.ws.send_json({"error": str(e)})
|
||||
|
||||
async def run(self):
|
||||
|
||||
# Worker threads, servicing
|
||||
workers = []
|
||||
|
||||
while self.running.get():
|
||||
|
||||
try:
|
||||
|
||||
if len(workers) > 0:
|
||||
await self.maybe_tidy_workers(workers)
|
||||
|
||||
# Get next request on queue
|
||||
item = await asyncio.wait_for(self.q.get(), 1)
|
||||
id, flow, svc, request = item
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
# This is an internal working error, may not be recoverable
|
||||
print("run prepare exception:", e)
|
||||
await self.ws.send_json({"id": id, "error": str(e)})
|
||||
self.running.stop()
|
||||
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
self.ws = None
|
||||
|
||||
break
|
||||
|
||||
try:
|
||||
print(id, svc, request)
|
||||
|
||||
await self.start_request_task(
|
||||
self.ws, id, flow, svc, request, workers
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception2:", e)
|
||||
await self.ws.send_json({"error": str(e)})
|
||||
|
||||
self.running.stop()
|
||||
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
self.ws = None
|
||||
|
||||
|
|
@ -23,37 +23,34 @@ class EndpointManager:
|
|||
}
|
||||
|
||||
self.endpoints = [
|
||||
ConstantEndpoint(
|
||||
endpoint_path = "/api/v1/librarian", auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_librarian(),
|
||||
),
|
||||
ConstantEndpoint(
|
||||
endpoint_path = "/api/v1/config", auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_config(),
|
||||
),
|
||||
ConstantEndpoint(
|
||||
endpoint_path = "/api/v1/flow", auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_flow(),
|
||||
),
|
||||
MetricsEndpoint(
|
||||
endpoint_path = "/api/v1/metrics",
|
||||
endpoint_path = "/api/metrics",
|
||||
prometheus_url = prometheus_url,
|
||||
auth = auth,
|
||||
),
|
||||
VariableEndpoint(
|
||||
endpoint_path = "/api/v1/{kind}", auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_global_service(),
|
||||
),
|
||||
SocketEndpoint(
|
||||
endpoint_path = "/api/v1/socket",
|
||||
auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_socket()
|
||||
),
|
||||
VariableEndpoint(
|
||||
endpoint_path = "/api/v1/flow/{flow}/service/{kind}",
|
||||
auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_service(),
|
||||
dispatcher = dispatcher_manager.dispatch_flow_service(),
|
||||
),
|
||||
SocketEndpoint(
|
||||
endpoint_path = "/api/v1/flow/{flow}/import/{kind}",
|
||||
auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_import()
|
||||
dispatcher = dispatcher_manager.dispatch_flow_import()
|
||||
),
|
||||
SocketEndpoint(
|
||||
endpoint_path = "/api/v1/flow/{flow}/export/{kind}",
|
||||
auth = auth,
|
||||
dispatcher = dispatcher_manager.dispatch_export()
|
||||
dispatcher = dispatcher_manager.dispatch_flow_export()
|
||||
),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,167 +0,0 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
import uuid
|
||||
from aiohttp import web, WSMsgType
|
||||
|
||||
from . socket import SocketEndpoint
|
||||
|
||||
MAX_OUTSTANDING_REQUESTS = 15
|
||||
WORKER_CLOSE_WAIT = 0.01
|
||||
START_REQUEST_WAIT = 0.1
|
||||
|
||||
# This buffers requests until task start, so short-lived
|
||||
MAX_QUEUE_SIZE = 10
|
||||
|
||||
class MuxEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(
|
||||
self, pulsar_client, auth,
|
||||
services,
|
||||
path="/api/v1/socket",
|
||||
):
|
||||
|
||||
super(MuxEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.services = services
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
async def maybe_tidy_workers(self, workers):
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(workers[0]),
|
||||
WORKER_CLOSE_WAIT
|
||||
)
|
||||
|
||||
# worker[0] now stopped
|
||||
# FIXME: Delete reference???
|
||||
|
||||
workers.pop(0)
|
||||
|
||||
if len(workers) == 0:
|
||||
break
|
||||
|
||||
# Loop iterates to try the next worker
|
||||
|
||||
except TimeoutError:
|
||||
# worker[0] still running, move on
|
||||
break
|
||||
|
||||
async def start_request_task(self, ws, id, svc, request, workers):
|
||||
|
||||
if svc not in self.services:
|
||||
await ws.send_json({"id": id, "error": "Service not recognised"})
|
||||
return
|
||||
|
||||
requestor = self.services[svc]
|
||||
|
||||
async def responder(resp, fin):
|
||||
await ws.send_json({
|
||||
"id": id,
|
||||
"response": resp,
|
||||
"complete": fin,
|
||||
})
|
||||
|
||||
# Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS
|
||||
while len(workers) > MAX_OUTSTANDING_REQUESTS:
|
||||
|
||||
# Fixes deadlock
|
||||
# FIXME: Put it in its own loop
|
||||
await asyncio.sleep(START_REQUEST_WAIT)
|
||||
|
||||
await self.maybe_tidy_workers(workers)
|
||||
|
||||
worker = asyncio.create_task(
|
||||
requestor.process(request, responder)
|
||||
)
|
||||
|
||||
workers.append(worker)
|
||||
|
||||
async def async_thread(self, ws, running, q):
|
||||
|
||||
# Worker threads, servicing
|
||||
workers = []
|
||||
|
||||
while running.get():
|
||||
|
||||
try:
|
||||
|
||||
if len(workers) > 0:
|
||||
await self.maybe_tidy_workers(workers)
|
||||
|
||||
# Get next request on queue
|
||||
id, svc, request = await asyncio.wait_for(q.get(), 1)
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
# This is an internal working error, may not be recoverable
|
||||
print("Exception:", e)
|
||||
await ws.send_json({"id": id, "error": str(e)})
|
||||
break
|
||||
|
||||
try:
|
||||
print(id, svc, request)
|
||||
await self.start_request_task(ws, id, svc, request, workers)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception2:", e)
|
||||
await ws.send_json({"error": str(e)})
|
||||
|
||||
running.stop()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
# The outstanding request queue, max size is MAX_QUEUE_SIZE
|
||||
q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
|
||||
|
||||
async_task = asyncio.create_task(self.async_thread(
|
||||
ws, running, q
|
||||
))
|
||||
|
||||
async for msg in ws:
|
||||
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
|
||||
try:
|
||||
|
||||
data = msg.json()
|
||||
|
||||
if data["service"] not in self.services:
|
||||
raise RuntimeError("Bad service")
|
||||
|
||||
if "request" not in data:
|
||||
raise RuntimeError("Bad message")
|
||||
|
||||
if "id" not in data:
|
||||
raise RuntimeError("Bad message")
|
||||
|
||||
await q.put(
|
||||
(data["id"], data["service"], data["request"])
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
await ws.send_json({"error": str(e)})
|
||||
continue
|
||||
|
||||
elif msg.type == WSMsgType.ERROR:
|
||||
break
|
||||
elif msg.type == WSMsgType.CLOSE:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
running.stop()
|
||||
|
||||
await async_task
|
||||
Loading…
Add table
Add a link
Reference in a new issue