mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-26 08:56:21 +02:00
API gateway in a proper module, restarting publishers & subscribers as appropriate (#166)
This commit is contained in:
parent
ba6d6c13af
commit
a1e0edd96f
5 changed files with 643 additions and 549 deletions
|
|
@ -1,553 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
|
||||
# are active listeners
|
||||
from trustgraph.api.gateway import run
|
||||
|
||||
# FIXME: Connection errors in publishers / subscribers cause those threads
|
||||
# to fail and are not failed or retried
|
||||
|
||||
import asyncio
|
||||
from aiohttp import web
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
import os
|
||||
|
||||
import pulsar
|
||||
from pulsar.asyncio import Client
|
||||
from pulsar.schema import JsonSchema
|
||||
import _pulsar
|
||||
import aiopulsar
|
||||
|
||||
from trustgraph.clients.llm_client import LlmClient
|
||||
from trustgraph.clients.prompt_client import PromptClient
|
||||
|
||||
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
|
||||
from trustgraph.schema import text_completion_request_queue
|
||||
from trustgraph.schema import text_completion_response_queue
|
||||
|
||||
from trustgraph.schema import PromptRequest, PromptResponse
|
||||
from trustgraph.schema import prompt_request_queue
|
||||
from trustgraph.schema import prompt_response_queue
|
||||
|
||||
from trustgraph.schema import GraphRagQuery, GraphRagResponse
|
||||
from trustgraph.schema import graph_rag_request_queue
|
||||
from trustgraph.schema import graph_rag_response_queue
|
||||
|
||||
from trustgraph.schema import TriplesQueryRequest, TriplesQueryResponse, Value
|
||||
from trustgraph.schema import triples_request_queue
|
||||
from trustgraph.schema import triples_response_queue
|
||||
|
||||
from trustgraph.schema import AgentRequest, AgentResponse
|
||||
from trustgraph.schema import agent_request_queue
|
||||
from trustgraph.schema import agent_response_queue
|
||||
|
||||
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from trustgraph.schema import embeddings_request_queue
|
||||
from trustgraph.schema import embeddings_response_queue
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
||||
TIME_OUT = 600
|
||||
|
||||
class Publisher:
|
||||
|
||||
def __init__(self, pulsar_host, topic, schema=None, max_size=10):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.schema = schema
|
||||
self.q = asyncio.Queue(maxsize=max_size)
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
async with aiopulsar.connect(self.pulsar_host) as client:
|
||||
async with client.create_producer(
|
||||
topic=self.topic,
|
||||
schema=self.schema,
|
||||
) as producer:
|
||||
while True:
|
||||
id, item = await self.q.get()
|
||||
await producer.send(item, { "id": id })
|
||||
# print("message out")
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
async def send(self, id, msg):
|
||||
await self.q.put((id, msg))
|
||||
|
||||
class Subscriber:
|
||||
|
||||
def __init__(self, pulsar_host, topic, subscription, consumer_name,
|
||||
schema=None, max_size=10):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.subscription = subscription
|
||||
self.consumer_name = consumer_name
|
||||
self.schema = schema
|
||||
self.q = {}
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
async with aiopulsar.connect(pulsar_host) as client:
|
||||
async with client.subscribe(
|
||||
topic=self.topic,
|
||||
subscription_name=self.subscription,
|
||||
consumer_name=self.consumer_name,
|
||||
schema=self.schema,
|
||||
) as consumer:
|
||||
while True:
|
||||
msg = await consumer.receive()
|
||||
# print("message in", self.topic)
|
||||
id = msg.properties()["id"]
|
||||
value = msg.value()
|
||||
if id in self.q:
|
||||
await self.q[id].put(value)
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
async def subscribe(self, id):
|
||||
q = asyncio.Queue()
|
||||
self.q[id] = q
|
||||
return q
|
||||
|
||||
async def unsubscribe(self, id):
|
||||
if id in self.q:
|
||||
del self.q[id]
|
||||
|
||||
class Api:
|
||||
|
||||
def __init__(self, **config):
|
||||
|
||||
self.port = int(config.get("port", "8088"))
|
||||
self.app = web.Application(middlewares=[])
|
||||
|
||||
self.llm_out = Publisher(
|
||||
pulsar_host, text_completion_request_queue,
|
||||
schema=JsonSchema(TextCompletionRequest)
|
||||
)
|
||||
|
||||
self.llm_in = Subscriber(
|
||||
pulsar_host, text_completion_response_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
JsonSchema(TextCompletionResponse)
|
||||
)
|
||||
|
||||
self.prompt_out = Publisher(
|
||||
pulsar_host, prompt_request_queue,
|
||||
schema=JsonSchema(PromptRequest)
|
||||
)
|
||||
|
||||
self.prompt_in = Subscriber(
|
||||
pulsar_host, prompt_response_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
JsonSchema(PromptResponse)
|
||||
)
|
||||
|
||||
self.graph_rag_out = Publisher(
|
||||
pulsar_host, graph_rag_request_queue,
|
||||
schema=JsonSchema(GraphRagQuery)
|
||||
)
|
||||
|
||||
self.graph_rag_in = Subscriber(
|
||||
pulsar_host, graph_rag_response_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
JsonSchema(GraphRagResponse)
|
||||
)
|
||||
|
||||
self.triples_query_out = Publisher(
|
||||
pulsar_host, triples_request_queue,
|
||||
schema=JsonSchema(TriplesQueryRequest)
|
||||
)
|
||||
|
||||
self.triples_query_in = Subscriber(
|
||||
pulsar_host, triples_response_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
JsonSchema(TriplesQueryResponse)
|
||||
)
|
||||
|
||||
self.agent_out = Publisher(
|
||||
pulsar_host, agent_request_queue,
|
||||
schema=JsonSchema(AgentRequest)
|
||||
)
|
||||
|
||||
self.agent_in = Subscriber(
|
||||
pulsar_host, agent_response_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
JsonSchema(AgentResponse)
|
||||
)
|
||||
|
||||
self.embeddings_out = Publisher(
|
||||
pulsar_host, embeddings_request_queue,
|
||||
schema=JsonSchema(EmbeddingsRequest)
|
||||
)
|
||||
|
||||
self.embeddings_in = Subscriber(
|
||||
pulsar_host, embeddings_response_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
JsonSchema(EmbeddingsResponse)
|
||||
)
|
||||
|
||||
self.app.add_routes([
|
||||
web.post("/api/v1/text-completion", self.llm),
|
||||
web.post("/api/v1/prompt", self.prompt),
|
||||
web.post("/api/v1/graph-rag", self.graph_rag),
|
||||
web.post("/api/v1/triples-query", self.triples_query),
|
||||
web.post("/api/v1/agent", self.agent),
|
||||
web.post("/api/v1/embeddings", self.embeddings),
|
||||
])
|
||||
|
||||
async def llm(self, request):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
q = await self.llm_in.subscribe(id)
|
||||
|
||||
await self.llm_out.send(
|
||||
id,
|
||||
TextCompletionRequest(
|
||||
system=data["system"],
|
||||
prompt=data["prompt"]
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), TIME_OUT)
|
||||
except:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
|
||||
if resp.error:
|
||||
return web.json_response(
|
||||
{ "error": resp.error.message }
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{ "response": resp.response }
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
finally:
|
||||
await self.llm_in.unsubscribe(id)
|
||||
|
||||
async def prompt(self, request):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
q = await self.prompt_in.subscribe(id)
|
||||
|
||||
terms = {
|
||||
k: json.dumps(v)
|
||||
for k, v in data["variables"].items()
|
||||
}
|
||||
|
||||
await self.prompt_out.send(
|
||||
id,
|
||||
PromptRequest(
|
||||
id=data["id"],
|
||||
terms=terms
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), TIME_OUT)
|
||||
except:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
|
||||
if resp.error:
|
||||
return web.json_response(
|
||||
{ "error": resp.error.message }
|
||||
)
|
||||
|
||||
if resp.object:
|
||||
return web.json_response(
|
||||
{ "object": resp.object }
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{ "text": resp.text }
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
finally:
|
||||
await self.prompt_in.unsubscribe(id)
|
||||
|
||||
async def graph_rag(self, request):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
q = await self.graph_rag_in.subscribe(id)
|
||||
|
||||
await self.graph_rag_out.send(
|
||||
id,
|
||||
GraphRagQuery(
|
||||
query=data["query"],
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), TIME_OUT)
|
||||
except:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
|
||||
if resp.error:
|
||||
return web.json_response(
|
||||
{ "error": resp.error.message }
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{ "response": resp.response }
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
finally:
|
||||
await self.graph_rag_in.unsubscribe(id)
|
||||
|
||||
async def triples_query(self, request):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
q = await self.triples_query_in.subscribe(id)
|
||||
|
||||
if "s" in data:
|
||||
if data["s"].startswith("http:") or data["s"].startswith("https:"):
|
||||
s = Value(value=data["s"], is_uri=True)
|
||||
else:
|
||||
s = Value(value=data["s"], is_uri=True)
|
||||
else:
|
||||
s = None
|
||||
|
||||
if "p" in data:
|
||||
if data["p"].startswith("http:") or data["p"].startswith("https:"):
|
||||
p = Value(value=data["p"], is_uri=True)
|
||||
else:
|
||||
p = Value(value=data["p"], is_uri=True)
|
||||
else:
|
||||
p = None
|
||||
|
||||
if "o" in data:
|
||||
if data["o"].startswith("http:") or data["o"].startswith("https:"):
|
||||
o = Value(value=data["o"], is_uri=True)
|
||||
else:
|
||||
o = Value(value=data["o"], is_uri=True)
|
||||
else:
|
||||
o = None
|
||||
|
||||
limit = int(data.get("limit", 10000))
|
||||
|
||||
await self.triples_query_out.send(
|
||||
id,
|
||||
TriplesQueryRequest(
|
||||
s = s, p = p, o = o,
|
||||
limit = limit,
|
||||
user = data.get("user", "trustgraph"),
|
||||
collection = data.get("collection", "default"),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), TIME_OUT)
|
||||
except:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
|
||||
if resp.error:
|
||||
return web.json_response(
|
||||
{ "error": resp.error.message }
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"response": [
|
||||
{
|
||||
"s": {
|
||||
"v": t.s.value,
|
||||
"e": t.s.is_uri,
|
||||
},
|
||||
"p": {
|
||||
"v": t.p.value,
|
||||
"e": t.p.is_uri,
|
||||
},
|
||||
"o": {
|
||||
"v": t.o.value,
|
||||
"e": t.o.is_uri,
|
||||
}
|
||||
}
|
||||
for t in resp.triples
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
finally:
|
||||
await self.graph_rag_in.unsubscribe(id)
|
||||
|
||||
async def agent(self, request):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
q = await self.agent_in.subscribe(id)
|
||||
|
||||
await self.agent_out.send(
|
||||
id,
|
||||
AgentRequest(
|
||||
question=data["question"],
|
||||
)
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), TIME_OUT)
|
||||
except:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
|
||||
if resp.error:
|
||||
return web.json_response(
|
||||
{ "error": resp.error.message }
|
||||
)
|
||||
|
||||
if resp.answer: break
|
||||
|
||||
if resp.thought: print("thought:", resp.thought)
|
||||
if resp.observation: print("observation:", resp.observation)
|
||||
|
||||
if resp.answer:
|
||||
return web.json_response(
|
||||
{ "answer": resp.answer }
|
||||
)
|
||||
|
||||
# Can't happen, ook at the logic
|
||||
raise RuntimeError("Strange state")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
finally:
|
||||
await self.agent_in.unsubscribe(id)
|
||||
|
||||
async def embeddings(self, request):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
q = await self.embeddings_in.subscribe(id)
|
||||
|
||||
await self.embeddings_out.send(
|
||||
id,
|
||||
EmbeddingsRequest(
|
||||
text=data["text"],
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await asyncio.wait_for(q.get(), TIME_OUT)
|
||||
except:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
|
||||
if resp.error:
|
||||
return web.json_response(
|
||||
{ "error": resp.error.message }
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{ "vectors": resp.vectors }
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
finally:
|
||||
await self.embeddings_in.unsubscribe(id)
|
||||
|
||||
async def app_factory(self):
|
||||
|
||||
self.llm_pub_task = asyncio.create_task(self.llm_in.run())
|
||||
self.llm_sub_task = asyncio.create_task(self.llm_out.run())
|
||||
|
||||
self.prompt_pub_task = asyncio.create_task(self.prompt_in.run())
|
||||
self.prompt_sub_task = asyncio.create_task(self.prompt_out.run())
|
||||
|
||||
self.graph_rag_pub_task = asyncio.create_task(self.graph_rag_in.run())
|
||||
self.graph_rag_sub_task = asyncio.create_task(self.graph_rag_out.run())
|
||||
|
||||
self.triples_query_pub_task = asyncio.create_task(
|
||||
self.triples_query_in.run()
|
||||
)
|
||||
self.triples_query_sub_task = asyncio.create_task(
|
||||
self.triples_query_out.run()
|
||||
)
|
||||
|
||||
self.agent_pub_task = asyncio.create_task(self.agent_in.run())
|
||||
self.agent_sub_task = asyncio.create_task(self.agent_out.run())
|
||||
|
||||
self.embeddings_pub_task = asyncio.create_task(
|
||||
self.embeddings_in.run()
|
||||
)
|
||||
self.embeddings_sub_task = asyncio.create_task(
|
||||
self.embeddings_out.run()
|
||||
)
|
||||
|
||||
return self.app
|
||||
|
||||
def run(self):
|
||||
web.run_app(self.app_factory(), port=self.port)
|
||||
|
||||
a = Api()
|
||||
a.run()
|
||||
run()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue