Feature/flow api 3 (#358)

* Working mux socket

* Change API to incorporate flow

* Add Flow ID to all relevant CLIs, not completely implemented

* Change tg-processor-state to use API gateway

* Updated all CLIs

* New tg-show-flow-state command

* tg-show-flow-state shows classes too
This commit is contained in:
cybermaggedon 2025-05-03 10:39:53 +01:00 committed by GitHub
parent a70ae9793a
commit 3b8b9ea866
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 800 additions and 986 deletions

View file

@ -6,10 +6,12 @@ from ... schema import config_response_queue
from . requestor import ServiceRequestor
class ConfigRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout=120):
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
super(ConfigRequestor, self).__init__(
pulsar_client=pulsar_client,
consumer_name = consumer,
subscription = subscriber,
request_queue=config_request_queue,
response_queue=config_response_queue,
request_schema=ConfigRequest,

View file

@ -6,10 +6,12 @@ from ... schema import flow_response_queue
from . requestor import ServiceRequestor
class FlowRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout=120):
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
super(FlowRequestor, self).__init__(
pulsar_client=pulsar_client,
consumer_name = consumer,
subscription = subscriber,
request_queue=flow_request_queue,
response_queue=flow_response_queue,
request_schema=FlowRequest,

View file

@ -8,10 +8,12 @@ from . serialize import serialize_document_package, serialize_document_info
from . serialize import to_document_package, to_document_info, to_criteria
class LibrarianRequestor(ServiceRequestor):
def __init__(self, pulsar_client, timeout=120):
def __init__(self, pulsar_client, consumer, subscriber, timeout=120):
super(LibrarianRequestor, self).__init__(
pulsar_client=pulsar_client,
consumer_name = consumer,
subscription = subscriber,
request_queue=librarian_request_queue,
response_queue=librarian_response_queue,
request_schema=LibrarianRequest,

View file

@ -27,6 +27,8 @@ from . triples_import import TriplesImport
from . graph_embeddings_import import GraphEmbeddingsImport
from . document_embeddings_import import DocumentEmbeddingsImport
from . mux import Mux
request_response_dispatchers = {
"agent": AgentRequestor,
"text-completion": TextCompletionRequestor,
@ -35,7 +37,13 @@ request_response_dispatchers = {
"document-rag": DocumentRagRequestor,
"embeddings": EmbeddingsRequestor,
"graph-embeddings": GraphEmbeddingsQueryRequestor,
"triples-query": TriplesQueryRequestor,
"triples": TriplesQueryRequestor,
}
global_dispatchers = {
"config": ConfigRequestor,
"flow": FlowRequestor,
"librarian": LibrarianRequestor,
}
sender_dispatchers = {
@ -56,14 +64,10 @@ import_dispatchers = {
}
class DispatcherWrapper:
def __init__(self, mgr, name, impl):
self.mgr = mgr
self.name = name
self.impl = impl
async def process(self, data, responder):
return await self.mgr.process_impl(
data, responder, self.name, self.impl
)
def __init__(self, handler):
self.handler = handler
async def process(self, *args):
return await self.handler(*args)
class DispatcherManager:
@ -85,24 +89,26 @@ class DispatcherManager:
del self.flows[id]
return
def dispatch_config(self):
return DispatcherWrapper(self, "config", ConfigRequestor)
def dispatch_global_service(self):
return DispatcherWrapper(self.process_global_service)
def dispatch_flow(self):
return DispatcherWrapper(self, "flow", FlowRequestor)
async def process_global_service(self, data, responder, params):
def dispatch_librarian(self):
return DispatcherWrapper(self, "librarian", LibrarianRequestor)
kind = params.get("kind")
return await self.invoke_global_service(data, responder, kind)
async def process_impl(self, data, responder, name, impl):
async def invoke_global_service(self, data, responder, kind):
key = (None, name)
key = (None, kind)
if key in self.dispatchers:
return await self.dispatchers[key].process(data, responder)
dispatcher = impl(
pulsar_client = self.pulsar_client
dispatcher = global_dispatchers[kind](
pulsar_client = self.pulsar_client,
timeout = 120,
consumer = f"api-gateway-{kind}-request",
subscriber = f"api-gateway-{kind}-request",
)
await dispatcher.start()
@ -111,16 +117,19 @@ class DispatcherManager:
return await dispatcher.process(data, responder)
def dispatch_service(self):
return self
def dispatch_flow_import(self):
return self.process_flow_import
def dispatch_import(self):
return self.invoke_import
def dispatch_flow_export(self):
return self.process_flow_export
def dispatch_export(self):
return self.invoke_export
def dispatch_socket(self):
return self.process_socket
async def invoke_import(self, ws, running, params):
def dispatch_flow_service(self):
return DispatcherWrapper(self.process_flow_service)
async def process_flow_import(self, ws, running, params):
flow = params.get("flow")
kind = params.get("kind")
@ -151,7 +160,7 @@ class DispatcherManager:
return dispatcher
async def invoke_export(self, ws, running, params):
async def process_flow_export(self, ws, running, params):
flow = params.get("flow")
kind = params.get("kind")
@ -184,11 +193,21 @@ class DispatcherManager:
return dispatcher
async def process(self, data, responder, params):
async def process_socket(self, ws, running, params):
dispatcher = Mux(self, ws, running)
return dispatcher
async def process_flow_service(self, data, responder, params):
flow = params.get("flow")
kind = params.get("kind")
return await self.invoke_flow_service(data, responder, flow, kind)
async def invoke_flow_service(self, data, responder, flow, kind):
if flow not in self.flows:
raise RuntimeError("Invalid flow")

View file

@ -0,0 +1,167 @@
import asyncio
import queue
import uuid
MAX_OUTSTANDING_REQUESTS = 15
WORKER_CLOSE_WAIT = 0.01
START_REQUEST_WAIT = 0.1
# This buffers requests until task start, so short-lived
MAX_QUEUE_SIZE = 10
class Mux:
def __init__(self, dispatcher_manager, ws, running):
self.dispatcher_manager = dispatcher_manager
self.ws = ws
self.running = running
self.q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
async def destroy(self):
self.running.stop()
if self.ws:
await self.ws.close()
async def receive(self, msg):
try:
data = msg.json()
if "request" not in data:
raise RuntimeError("Bad message")
if "id" not in data:
raise RuntimeError("Bad message")
await self.q.put((
data["id"], data.get("flow"),
data["service"],
data["request"]
))
except Exception as e:
print("receive exception:", str(e), flush=True)
await self.ws.send_json({"error": str(e)})
async def maybe_tidy_workers(self, workers):
while True:
try:
await asyncio.wait_for(
asyncio.shield(workers[0]),
WORKER_CLOSE_WAIT
)
# worker[0] now stopped
# FIXME: Delete reference???
workers.pop(0)
if len(workers) == 0:
break
# Loop iterates to try the next worker
except TimeoutError:
# worker[0] still running, move on
break
async def start_request_task(self, ws, id, flow, svc, request, workers):
# Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS
while len(workers) > MAX_OUTSTANDING_REQUESTS:
# Fixes deadlock
# FIXME: Put it in its own loop
await asyncio.sleep(START_REQUEST_WAIT)
await self.maybe_tidy_workers(workers)
async def responder(resp, fin):
await self.ws.send_json({
"id": id,
"response": resp,
"complete": fin,
})
worker = asyncio.create_task(
self.request_task(request, responder, flow, svc)
)
workers.append(worker)
async def request_task(self, request, responder, flow, svc):
try:
if flow:
await self.dispatcher_manager.invoke_flow_service(
request, responder, flow, svc
)
else:
await self.dispatcher_manager.invoke_global_service(
request, responder, svc
)
except Exception as e:
await self.ws.send_json({"error": str(e)})
async def run(self):
# Worker threads, servicing
workers = []
while self.running.get():
try:
if len(workers) > 0:
await self.maybe_tidy_workers(workers)
# Get next request on queue
item = await asyncio.wait_for(self.q.get(), 1)
id, flow, svc, request = item
except TimeoutError:
continue
except Exception as e:
# This is an internal working error, may not be recoverable
print("run prepare exception:", e)
await self.ws.send_json({"id": id, "error": str(e)})
self.running.stop()
if self.ws:
self.ws.close()
self.ws = None
break
try:
print(id, svc, request)
await self.start_request_task(
self.ws, id, flow, svc, request, workers
)
except Exception as e:
print("Exception2:", e)
await self.ws.send_json({"error": str(e)})
self.running.stop()
if self.ws:
self.ws.close()
self.ws = None

View file

@ -23,37 +23,34 @@ class EndpointManager:
}
self.endpoints = [
ConstantEndpoint(
endpoint_path = "/api/v1/librarian", auth = auth,
dispatcher = dispatcher_manager.dispatch_librarian(),
),
ConstantEndpoint(
endpoint_path = "/api/v1/config", auth = auth,
dispatcher = dispatcher_manager.dispatch_config(),
),
ConstantEndpoint(
endpoint_path = "/api/v1/flow", auth = auth,
dispatcher = dispatcher_manager.dispatch_flow(),
),
MetricsEndpoint(
endpoint_path = "/api/v1/metrics",
endpoint_path = "/api/metrics",
prometheus_url = prometheus_url,
auth = auth,
),
VariableEndpoint(
endpoint_path = "/api/v1/{kind}", auth = auth,
dispatcher = dispatcher_manager.dispatch_global_service(),
),
SocketEndpoint(
endpoint_path = "/api/v1/socket",
auth = auth,
dispatcher = dispatcher_manager.dispatch_socket()
),
VariableEndpoint(
endpoint_path = "/api/v1/flow/{flow}/service/{kind}",
auth = auth,
dispatcher = dispatcher_manager.dispatch_service(),
dispatcher = dispatcher_manager.dispatch_flow_service(),
),
SocketEndpoint(
endpoint_path = "/api/v1/flow/{flow}/import/{kind}",
auth = auth,
dispatcher = dispatcher_manager.dispatch_import()
dispatcher = dispatcher_manager.dispatch_flow_import()
),
SocketEndpoint(
endpoint_path = "/api/v1/flow/{flow}/export/{kind}",
auth = auth,
dispatcher = dispatcher_manager.dispatch_export()
dispatcher = dispatcher_manager.dispatch_flow_export()
),
]

View file

@ -1,167 +0,0 @@
import asyncio
import queue
import uuid
from aiohttp import web, WSMsgType
from . socket import SocketEndpoint
MAX_OUTSTANDING_REQUESTS = 15
WORKER_CLOSE_WAIT = 0.01
START_REQUEST_WAIT = 0.1
# This buffers requests until task start, so short-lived
MAX_QUEUE_SIZE = 10
class MuxEndpoint(SocketEndpoint):
def __init__(
self, pulsar_client, auth,
services,
path="/api/v1/socket",
):
super(MuxEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.services = services
async def start(self):
pass
async def maybe_tidy_workers(self, workers):
while True:
try:
await asyncio.wait_for(
asyncio.shield(workers[0]),
WORKER_CLOSE_WAIT
)
# worker[0] now stopped
# FIXME: Delete reference???
workers.pop(0)
if len(workers) == 0:
break
# Loop iterates to try the next worker
except TimeoutError:
# worker[0] still running, move on
break
async def start_request_task(self, ws, id, svc, request, workers):
if svc not in self.services:
await ws.send_json({"id": id, "error": "Service not recognised"})
return
requestor = self.services[svc]
async def responder(resp, fin):
await ws.send_json({
"id": id,
"response": resp,
"complete": fin,
})
# Wait for outstanding requests to go below MAX_OUTSTANDING_REQUESTS
while len(workers) > MAX_OUTSTANDING_REQUESTS:
# Fixes deadlock
# FIXME: Put it in its own loop
await asyncio.sleep(START_REQUEST_WAIT)
await self.maybe_tidy_workers(workers)
worker = asyncio.create_task(
requestor.process(request, responder)
)
workers.append(worker)
async def async_thread(self, ws, running, q):
# Worker threads, servicing
workers = []
while running.get():
try:
if len(workers) > 0:
await self.maybe_tidy_workers(workers)
# Get next request on queue
id, svc, request = await asyncio.wait_for(q.get(), 1)
except TimeoutError:
continue
except Exception as e:
# This is an internal working error, may not be recoverable
print("Exception:", e)
await ws.send_json({"id": id, "error": str(e)})
break
try:
print(id, svc, request)
await self.start_request_task(ws, id, svc, request, workers)
except Exception as e:
print("Exception2:", e)
await ws.send_json({"error": str(e)})
running.stop()
async def listener(self, ws, running):
# The outstanding request queue, max size is MAX_QUEUE_SIZE
q = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
async_task = asyncio.create_task(self.async_thread(
ws, running, q
))
async for msg in ws:
# On error, finish
if msg.type == WSMsgType.TEXT:
try:
data = msg.json()
if data["service"] not in self.services:
raise RuntimeError("Bad service")
if "request" not in data:
raise RuntimeError("Bad message")
if "id" not in data:
raise RuntimeError("Bad message")
await q.put(
(data["id"], data["service"], data["request"])
)
except Exception as e:
await ws.send_json({"error": str(e)})
continue
elif msg.type == WSMsgType.ERROR:
break
elif msg.type == WSMsgType.CLOSE:
break
else:
break
running.stop()
await async_task