mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-28 01:46:22 +02:00
Feature/flow enable api gateway (#356)
* Tweak timeouts, reduce stop time for publishers / subscribers * More APIs working as flow endpoint
This commit is contained in:
parent
027b52cd7c
commit
450f664b1b
19 changed files with 303 additions and 76 deletions
|
|
@ -1,20 +1,23 @@
|
|||
|
||||
from .. schema import AgentRequest, AgentResponse
|
||||
from .. schema import agent_request_queue
|
||||
from .. schema import agent_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class AgentRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(AgentRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
request_queue=agent_request_queue,
|
||||
response_queue=agent_response_queue,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=AgentRequest,
|
||||
response_schema=AgentResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,23 @@
|
|||
|
||||
from .. schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from .. schema import embeddings_request_queue
|
||||
from .. schema import embeddings_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class EmbeddingsRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(EmbeddingsRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
request_queue=embeddings_request_queue,
|
||||
response_queue=embeddings_response_queue,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=EmbeddingsRequest,
|
||||
response_schema=EmbeddingsResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
|||
75
trustgraph-flow/trustgraph/gateway/flow_endpoint.py
Normal file
75
trustgraph-flow/trustgraph/gateway/flow_endpoint.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
|
||||
import asyncio
|
||||
from aiohttp import web
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("flow-endpoint")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class FlowEndpoint:
|
||||
|
||||
def __init__(self, endpoint_path, auth, requestors):
|
||||
|
||||
self.path = endpoint_path
|
||||
|
||||
self.auth = auth
|
||||
self.operation = "service"
|
||||
|
||||
self.requestors = requestors
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
def add_routes(self, app):
|
||||
|
||||
pass
|
||||
app.add_routes([
|
||||
web.post(self.path, self.handle),
|
||||
])
|
||||
|
||||
async def handle(self, request):
|
||||
|
||||
print(request.path, "...")
|
||||
|
||||
flow_id = request.match_info['flow']
|
||||
kind = request.match_info['kind']
|
||||
k = (flow_id, kind)
|
||||
|
||||
if k not in self.requestors:
|
||||
raise web.HTTPBadRequest()
|
||||
|
||||
requestor = self.requestors[k]
|
||||
|
||||
try:
|
||||
ht = request.headers["Authorization"]
|
||||
tokens = ht.split(" ", 2)
|
||||
if tokens[0] != "Bearer":
|
||||
return web.HTTPUnauthorized()
|
||||
token = tokens[1]
|
||||
except:
|
||||
token = ""
|
||||
|
||||
if not self.auth.permitted(token, self.operation):
|
||||
return web.HTTPUnauthorized()
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
print(data)
|
||||
|
||||
async def responder(x, fin):
|
||||
print(x)
|
||||
|
||||
resp = await requestor.process(data, responder)
|
||||
|
||||
return web.json_response(resp)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
|
|
@ -1,21 +1,24 @@
|
|||
|
||||
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from .. schema import graph_embeddings_request_queue
|
||||
from .. schema import graph_embeddings_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
from . serialize import serialize_value
|
||||
|
||||
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(GraphEmbeddingsQueryRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
request_queue=graph_embeddings_request_queue,
|
||||
response_queue=graph_embeddings_response_queue,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=GraphEmbeddingsRequest,
|
||||
response_schema=GraphEmbeddingsResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,23 @@
|
|||
|
||||
from .. schema import GraphRagQuery, GraphRagResponse
|
||||
from .. schema import graph_rag_request_queue
|
||||
from .. schema import graph_rag_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class GraphRagRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(GraphRagRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
request_queue=graph_rag_request_queue,
|
||||
response_queue=graph_rag_response_queue,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=GraphRagQuery,
|
||||
response_schema=GraphRagResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,21 +2,24 @@
|
|||
import json
|
||||
|
||||
from .. schema import PromptRequest, PromptResponse
|
||||
from .. schema import prompt_request_queue
|
||||
from .. schema import prompt_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class PromptRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(PromptRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
request_queue=prompt_request_queue,
|
||||
response_queue=prompt_response_queue,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=PromptRequest,
|
||||
response_schema=PromptResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,10 +34,13 @@ class ServiceRequestor:
|
|||
self.timeout = timeout
|
||||
|
||||
async def start(self):
|
||||
|
||||
await self.pub.start()
|
||||
await self.sub.start()
|
||||
|
||||
async def stop(self):
|
||||
await self.pub.stop()
|
||||
await self.sub.stop()
|
||||
|
||||
def to_request(self, request):
|
||||
raise RuntimeError("Not defined")
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ from aiohttp import web
|
|||
import logging
|
||||
import os
|
||||
import base64
|
||||
import uuid
|
||||
import json
|
||||
|
||||
import pulsar
|
||||
from prometheus_client import start_http_server
|
||||
|
|
@ -26,15 +28,17 @@ from .. log_level import LogLevel
|
|||
from . serialize import to_subgraph
|
||||
from . running import Running
|
||||
|
||||
#from . text_completion import TextCompletionRequestor
|
||||
#from . prompt import PromptRequestor
|
||||
#from . graph_rag import GraphRagRequestor
|
||||
from .. schema import ConfigPush, config_push_queue
|
||||
|
||||
from . text_completion import TextCompletionRequestor
|
||||
from . prompt import PromptRequestor
|
||||
from . graph_rag import GraphRagRequestor
|
||||
#from . document_rag import DocumentRagRequestor
|
||||
#from . triples_query import TriplesQueryRequestor
|
||||
#from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||
#from . embeddings import EmbeddingsRequestor
|
||||
from . triples_query import TriplesQueryRequestor
|
||||
from . graph_embeddings_query import GraphEmbeddingsQueryRequestor
|
||||
from . embeddings import EmbeddingsRequestor
|
||||
#from . encyclopedia import EncyclopediaRequestor
|
||||
#from . agent import AgentRequestor
|
||||
from . agent import AgentRequestor
|
||||
#from . dbpedia import DbpediaRequestor
|
||||
#from . internet_search import InternetSearchRequestor
|
||||
#from . librarian import LibrarianRequestor
|
||||
|
|
@ -52,7 +56,10 @@ from . mux import MuxEndpoint
|
|||
from . metrics import MetricsEndpoint
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . flow_endpoint import FlowEndpoint
|
||||
from . auth import Authenticator
|
||||
from .. base import Subscriber
|
||||
from .. base import Consumer
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -68,11 +75,6 @@ class Api:
|
|||
|
||||
def __init__(self, **config):
|
||||
|
||||
self.app = web.Application(
|
||||
middlewares=[],
|
||||
client_max_size=256 * 1024 * 1024
|
||||
)
|
||||
|
||||
self.port = int(config.get("port", default_port))
|
||||
self.timeout = int(config.get("timeout", default_timeout))
|
||||
self.pulsar_host = config.get("pulsar_host", default_pulsar_host)
|
||||
|
|
@ -143,11 +145,11 @@ class Api:
|
|||
# pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
# auth = self.auth,
|
||||
# ),
|
||||
"config": ConfigRequestor(
|
||||
(None, "config"): ConfigRequestor(
|
||||
pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
"flow": FlowRequestor(
|
||||
(None, "flow"): FlowRequestor(
|
||||
pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
|
|
@ -211,11 +213,16 @@ class Api:
|
|||
# ),
|
||||
ServiceEndpoint(
|
||||
endpoint_path = "/api/v1/config", auth=self.auth,
|
||||
requestor = self.services["config"],
|
||||
requestor = self.services[(None, "config")],
|
||||
),
|
||||
ServiceEndpoint(
|
||||
endpoint_path = "/api/v1/flow", auth=self.auth,
|
||||
requestor = self.services["flow"],
|
||||
requestor = self.services[(None, "flow")],
|
||||
),
|
||||
FlowEndpoint(
|
||||
endpoint_path = "/api/v1/flow/{flow}/{kind}",
|
||||
auth=self.auth,
|
||||
requestors = self.services,
|
||||
),
|
||||
# ServiceEndpoint(
|
||||
# endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
|
||||
|
|
@ -273,10 +280,117 @@ class Api:
|
|||
),
|
||||
]
|
||||
|
||||
for ep in self.endpoints:
|
||||
ep.add_routes(self.app)
|
||||
self.flows = {}
|
||||
|
||||
async def on_config(self, msg, proc, flow):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
print(f"Config version", v.version)
|
||||
|
||||
if "flows" in v.config:
|
||||
|
||||
flows = v.config["flows"]
|
||||
|
||||
wanted = list(flows.keys())
|
||||
current = list(self.flows.keys())
|
||||
|
||||
for k in wanted:
|
||||
if k not in current:
|
||||
self.flows[k] = json.loads(flows[k])
|
||||
await self.start_flow(k, self.flows[k])
|
||||
|
||||
for k in current:
|
||||
if k not in wanted:
|
||||
await self.stop_flow(k, self.flows[k])
|
||||
del self.flows[k]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}", flush=True)
|
||||
|
||||
async def start_flow(self, id, flow):
|
||||
|
||||
print("Start flow", id)
|
||||
intf = flow["interfaces"]
|
||||
|
||||
kinds = {
|
||||
"agent": AgentRequestor,
|
||||
"text-completion": TextCompletionRequestor,
|
||||
"prompt": PromptRequestor,
|
||||
"graph-rag": GraphRagRequestor,
|
||||
"embeddings": EmbeddingsRequestor,
|
||||
"graph-embeddings": GraphEmbeddingsQueryRequestor,
|
||||
"triples-query": TriplesQueryRequestor,
|
||||
}
|
||||
|
||||
for api_kind, requestor in kinds.items():
|
||||
|
||||
if api_kind in intf:
|
||||
k = (id, api_kind)
|
||||
if k in self.services:
|
||||
await self.services[k].stop()
|
||||
del self.services[k]
|
||||
|
||||
self.services[k] = requestor(
|
||||
pulsar_client=self.pulsar_client, timeout=self.timeout,
|
||||
request_queue = intf[api_kind]["request"],
|
||||
response_queue = intf[api_kind]["response"],
|
||||
consumer = f"api-gateway-{id}-{api_kind}-request",
|
||||
subscriber = f"api-gateway-{id}-{api_kind}-request",
|
||||
auth = self.auth,
|
||||
)
|
||||
await self.services[k].start()
|
||||
|
||||
async def stop_flow(self, id, flow):
|
||||
print("Stop flow", id)
|
||||
intf = flow["interfaces"]
|
||||
|
||||
svc_list = list(self.services.keys())
|
||||
|
||||
for k in svc_list:
|
||||
|
||||
kid, kkind = k
|
||||
|
||||
if id == kid:
|
||||
await self.services[k].stop()
|
||||
del self.services[k]
|
||||
|
||||
async def config_loader(self):
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
self.config_cons = Consumer(
|
||||
taskgroup = tg,
|
||||
flow = None,
|
||||
client = self.pulsar_client,
|
||||
subscriber = f"gateway-{id}",
|
||||
topic = config_push_queue,
|
||||
schema = ConfigPush,
|
||||
handler = self.on_config,
|
||||
start_of_messages = True,
|
||||
)
|
||||
|
||||
await self.config_cons.start()
|
||||
|
||||
print("Waiting...")
|
||||
|
||||
print("Config consumer done. :/")
|
||||
|
||||
async def app_factory(self):
|
||||
|
||||
self.app = web.Application(
|
||||
middlewares=[],
|
||||
client_max_size=256 * 1024 * 1024
|
||||
)
|
||||
|
||||
asyncio.create_task(self.config_loader())
|
||||
|
||||
for ep in self.endpoints:
|
||||
ep.add_routes(self.app)
|
||||
|
||||
for ep in self.endpoints:
|
||||
await ep.start()
|
||||
|
|
|
|||
|
|
@ -1,20 +1,23 @@
|
|||
|
||||
from .. schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .. schema import text_completion_request_queue
|
||||
from .. schema import text_completion_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
|
||||
class TextCompletionRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(TextCompletionRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
request_queue=text_completion_request_queue,
|
||||
response_queue=text_completion_response_queue,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=TextCompletionRequest,
|
||||
response_schema=TextCompletionResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,21 +1,24 @@
|
|||
|
||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Triples
|
||||
from .. schema import triples_request_queue
|
||||
from .. schema import triples_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
from . serialize import to_value, serialize_subgraph
|
||||
|
||||
class TriplesQueryRequestor(ServiceRequestor):
|
||||
def __init__(self, pulsar_client, timeout, auth):
|
||||
def __init__(
|
||||
self, pulsar_client, request_queue, response_queue, timeout, auth,
|
||||
consumer, subscriber,
|
||||
):
|
||||
|
||||
super(TriplesQueryRequestor, self).__init__(
|
||||
pulsar_client=pulsar_client,
|
||||
request_queue=triples_request_queue,
|
||||
response_queue=triples_response_queue,
|
||||
request_queue=request_queue,
|
||||
response_queue=response_queue,
|
||||
request_schema=TriplesQueryRequest,
|
||||
response_schema=TriplesQueryResponse,
|
||||
subscription = subscriber,
|
||||
consumer_name = consumer,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue