API gateway in a proper module, restarting publishers & subscribers as appropriate (#166)

This commit is contained in:
cybermaggedon 2024-11-20 23:17:55 +00:00 committed by GitHub
parent ba6d6c13af
commit a1e0edd96f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 643 additions and 549 deletions

View file

@ -1,553 +1,6 @@
#!/usr/bin/env python3
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners
from trustgraph.api.gateway import run
# FIXME: Connection errors in publishers / subscribers cause those threads
# to fail and are not failed or retried
import asyncio
from aiohttp import web
import json
import logging
import uuid
import os
import pulsar
from pulsar.asyncio import Client
from pulsar.schema import JsonSchema
import _pulsar
import aiopulsar
from trustgraph.clients.llm_client import LlmClient
from trustgraph.clients.prompt_client import PromptClient
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
from trustgraph.schema import text_completion_request_queue
from trustgraph.schema import text_completion_response_queue
from trustgraph.schema import PromptRequest, PromptResponse
from trustgraph.schema import prompt_request_queue
from trustgraph.schema import prompt_response_queue
from trustgraph.schema import GraphRagQuery, GraphRagResponse
from trustgraph.schema import graph_rag_request_queue
from trustgraph.schema import graph_rag_response_queue
from trustgraph.schema import TriplesQueryRequest, TriplesQueryResponse, Value
from trustgraph.schema import triples_request_queue
from trustgraph.schema import triples_response_queue
from trustgraph.schema import AgentRequest, AgentResponse
from trustgraph.schema import agent_request_queue
from trustgraph.schema import agent_response_queue
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse
from trustgraph.schema import embeddings_request_queue
from trustgraph.schema import embeddings_response_queue
logger = logging.getLogger("api")
logger.setLevel(logging.INFO)
pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
TIME_OUT = 600
class Publisher:
def __init__(self, pulsar_host, topic, schema=None, max_size=10):
self.pulsar_host = pulsar_host
self.topic = topic
self.schema = schema
self.q = asyncio.Queue(maxsize=max_size)
async def run(self):
try:
async with aiopulsar.connect(self.pulsar_host) as client:
async with client.create_producer(
topic=self.topic,
schema=self.schema,
) as producer:
while True:
id, item = await self.q.get()
await producer.send(item, { "id": id })
# print("message out")
except Exception as e:
print("Exception:", e, flush=True)
async def send(self, id, msg):
await self.q.put((id, msg))
class Subscriber:
def __init__(self, pulsar_host, topic, subscription, consumer_name,
schema=None, max_size=10):
self.pulsar_host = pulsar_host
self.topic = topic
self.subscription = subscription
self.consumer_name = consumer_name
self.schema = schema
self.q = {}
async def run(self):
try:
async with aiopulsar.connect(pulsar_host) as client:
async with client.subscribe(
topic=self.topic,
subscription_name=self.subscription,
consumer_name=self.consumer_name,
schema=self.schema,
) as consumer:
while True:
msg = await consumer.receive()
# print("message in", self.topic)
id = msg.properties()["id"]
value = msg.value()
if id in self.q:
await self.q[id].put(value)
except Exception as e:
print("Exception:", e, flush=True)
async def subscribe(self, id):
q = asyncio.Queue()
self.q[id] = q
return q
async def unsubscribe(self, id):
if id in self.q:
del self.q[id]
class Api:
def __init__(self, **config):
self.port = int(config.get("port", "8088"))
self.app = web.Application(middlewares=[])
self.llm_out = Publisher(
pulsar_host, text_completion_request_queue,
schema=JsonSchema(TextCompletionRequest)
)
self.llm_in = Subscriber(
pulsar_host, text_completion_response_queue,
"api-gateway", "api-gateway",
JsonSchema(TextCompletionResponse)
)
self.prompt_out = Publisher(
pulsar_host, prompt_request_queue,
schema=JsonSchema(PromptRequest)
)
self.prompt_in = Subscriber(
pulsar_host, prompt_response_queue,
"api-gateway", "api-gateway",
JsonSchema(PromptResponse)
)
self.graph_rag_out = Publisher(
pulsar_host, graph_rag_request_queue,
schema=JsonSchema(GraphRagQuery)
)
self.graph_rag_in = Subscriber(
pulsar_host, graph_rag_response_queue,
"api-gateway", "api-gateway",
JsonSchema(GraphRagResponse)
)
self.triples_query_out = Publisher(
pulsar_host, triples_request_queue,
schema=JsonSchema(TriplesQueryRequest)
)
self.triples_query_in = Subscriber(
pulsar_host, triples_response_queue,
"api-gateway", "api-gateway",
JsonSchema(TriplesQueryResponse)
)
self.agent_out = Publisher(
pulsar_host, agent_request_queue,
schema=JsonSchema(AgentRequest)
)
self.agent_in = Subscriber(
pulsar_host, agent_response_queue,
"api-gateway", "api-gateway",
JsonSchema(AgentResponse)
)
self.embeddings_out = Publisher(
pulsar_host, embeddings_request_queue,
schema=JsonSchema(EmbeddingsRequest)
)
self.embeddings_in = Subscriber(
pulsar_host, embeddings_response_queue,
"api-gateway", "api-gateway",
JsonSchema(EmbeddingsResponse)
)
self.app.add_routes([
web.post("/api/v1/text-completion", self.llm),
web.post("/api/v1/prompt", self.prompt),
web.post("/api/v1/graph-rag", self.graph_rag),
web.post("/api/v1/triples-query", self.triples_query),
web.post("/api/v1/agent", self.agent),
web.post("/api/v1/embeddings", self.embeddings),
])
async def llm(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.llm_in.subscribe(id)
await self.llm_out.send(
id,
TextCompletionRequest(
system=data["system"],
prompt=data["prompt"]
)
)
try:
resp = await asyncio.wait_for(q.get(), TIME_OUT)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
return web.json_response(
{ "response": resp.response }
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.llm_in.unsubscribe(id)
async def prompt(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.prompt_in.subscribe(id)
terms = {
k: json.dumps(v)
for k, v in data["variables"].items()
}
await self.prompt_out.send(
id,
PromptRequest(
id=data["id"],
terms=terms
)
)
try:
resp = await asyncio.wait_for(q.get(), TIME_OUT)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
if resp.object:
return web.json_response(
{ "object": resp.object }
)
return web.json_response(
{ "text": resp.text }
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.prompt_in.unsubscribe(id)
async def graph_rag(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.graph_rag_in.subscribe(id)
await self.graph_rag_out.send(
id,
GraphRagQuery(
query=data["query"],
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
)
)
try:
resp = await asyncio.wait_for(q.get(), TIME_OUT)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
return web.json_response(
{ "response": resp.response }
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.graph_rag_in.unsubscribe(id)
async def triples_query(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.triples_query_in.subscribe(id)
if "s" in data:
if data["s"].startswith("http:") or data["s"].startswith("https:"):
s = Value(value=data["s"], is_uri=True)
else:
s = Value(value=data["s"], is_uri=True)
else:
s = None
if "p" in data:
if data["p"].startswith("http:") or data["p"].startswith("https:"):
p = Value(value=data["p"], is_uri=True)
else:
p = Value(value=data["p"], is_uri=True)
else:
p = None
if "o" in data:
if data["o"].startswith("http:") or data["o"].startswith("https:"):
o = Value(value=data["o"], is_uri=True)
else:
o = Value(value=data["o"], is_uri=True)
else:
o = None
limit = int(data.get("limit", 10000))
await self.triples_query_out.send(
id,
TriplesQueryRequest(
s = s, p = p, o = o,
limit = limit,
user = data.get("user", "trustgraph"),
collection = data.get("collection", "default"),
)
)
try:
resp = await asyncio.wait_for(q.get(), TIME_OUT)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
return web.json_response(
{
"response": [
{
"s": {
"v": t.s.value,
"e": t.s.is_uri,
},
"p": {
"v": t.p.value,
"e": t.p.is_uri,
},
"o": {
"v": t.o.value,
"e": t.o.is_uri,
}
}
for t in resp.triples
]
}
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.graph_rag_in.unsubscribe(id)
async def agent(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.agent_in.subscribe(id)
await self.agent_out.send(
id,
AgentRequest(
question=data["question"],
)
)
while True:
try:
resp = await asyncio.wait_for(q.get(), TIME_OUT)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
if resp.answer: break
if resp.thought: print("thought:", resp.thought)
if resp.observation: print("observation:", resp.observation)
if resp.answer:
return web.json_response(
{ "answer": resp.answer }
)
# Can't happen, ook at the logic
raise RuntimeError("Strange state")
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.agent_in.unsubscribe(id)
async def embeddings(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.embeddings_in.subscribe(id)
await self.embeddings_out.send(
id,
EmbeddingsRequest(
text=data["text"],
)
)
try:
resp = await asyncio.wait_for(q.get(), TIME_OUT)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
return web.json_response(
{ "vectors": resp.vectors }
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.embeddings_in.unsubscribe(id)
async def app_factory(self):
self.llm_pub_task = asyncio.create_task(self.llm_in.run())
self.llm_sub_task = asyncio.create_task(self.llm_out.run())
self.prompt_pub_task = asyncio.create_task(self.prompt_in.run())
self.prompt_sub_task = asyncio.create_task(self.prompt_out.run())
self.graph_rag_pub_task = asyncio.create_task(self.graph_rag_in.run())
self.graph_rag_sub_task = asyncio.create_task(self.graph_rag_out.run())
self.triples_query_pub_task = asyncio.create_task(
self.triples_query_in.run()
)
self.triples_query_sub_task = asyncio.create_task(
self.triples_query_out.run()
)
self.agent_pub_task = asyncio.create_task(self.agent_in.run())
self.agent_sub_task = asyncio.create_task(self.agent_out.run())
self.embeddings_pub_task = asyncio.create_task(
self.embeddings_in.run()
)
self.embeddings_sub_task = asyncio.create_task(
self.embeddings_out.run()
)
return self.app
def run(self):
web.run_app(self.app_factory(), port=self.port)
a = Api()
a.run()
run()