trustgraph/trustgraph-flow/scripts/api-gateway
cybermaggedon 92b84441eb
Feature/api gateway (#164)
* Bare bones API gateway
* Working for LLM + prompt
* RAG query works
* Triples query
* Added agent API
* Embeddings API
* Put API tests in a subdir
2024-11-20 19:55:40 +00:00

540 lines
15 KiB
Python
Executable file

#!/usr/bin/env python3
import asyncio
from aiohttp import web
import json
import logging
import uuid
import pulsar
from pulsar.asyncio import Client
from pulsar.schema import JsonSchema
import _pulsar
import aiopulsar
from trustgraph.clients.llm_client import LlmClient
from trustgraph.clients.prompt_client import PromptClient
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
from trustgraph.schema import text_completion_request_queue
from trustgraph.schema import text_completion_response_queue
from trustgraph.schema import PromptRequest, PromptResponse
from trustgraph.schema import prompt_request_queue
from trustgraph.schema import prompt_response_queue
from trustgraph.schema import GraphRagQuery, GraphRagResponse
from trustgraph.schema import graph_rag_request_queue
from trustgraph.schema import graph_rag_response_queue
from trustgraph.schema import TriplesQueryRequest, TriplesQueryResponse, Value
from trustgraph.schema import triples_request_queue
from trustgraph.schema import triples_response_queue
from trustgraph.schema import AgentRequest, AgentResponse
from trustgraph.schema import agent_request_queue
from trustgraph.schema import agent_response_queue
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse
from trustgraph.schema import embeddings_request_queue
from trustgraph.schema import embeddings_response_queue
# Address of the Pulsar broker used by every publisher and subscriber below.
pulsar_host = "pulsar://localhost:6650"

# Upper bound passed to asyncio.wait_for() while awaiting a service response.
TIME_OUT = 600

# Logger for the gateway; INFO and above are enabled on it.
logger = logging.getLogger("api")
logger.setLevel(logging.INFO)
class Publisher:
    """Forward queued messages to a Pulsar topic.

    Callers hand messages over with :meth:`send`; a long-running
    :meth:`run` task drains the internal queue and publishes each message
    with its id attached as the ``id`` message property.
    """

    def __init__(self, pulsar_host, topic, schema=None, max_size=10):
        """Store connection settings and create the bounded hand-off queue.

        Args:
            pulsar_host: Pulsar service URL handed to ``aiopulsar.connect``.
            topic: Topic the producer is created on.
            schema: Optional schema passed to the producer.
            max_size: Capacity of the internal queue; ``send`` waits while
                the queue is full.
        """
        self.q = asyncio.Queue(maxsize=max_size)
        self.schema = schema
        self.topic = topic
        self.pulsar_host = pulsar_host

    async def run(self):
        """Connect, create the producer and publish queued items forever."""
        async with aiopulsar.connect(self.pulsar_host) as client:
            async with client.create_producer(
                topic=self.topic,
                schema=self.schema,
            ) as producer:
                while True:
                    msg_id, payload = await self.q.get()
                    # The id travels as a message property so that the
                    # response can be matched to its request.
                    await producer.send(payload, {"id": msg_id})

    async def send(self, id, msg):
        """Queue ``msg`` for publication under the identifier ``id``."""
        entry = (id, msg)
        await self.q.put(entry)
class Subscriber:
    """Dispatch messages from a Pulsar topic to per-request queues.

    A long-running :meth:`run` task receives messages and routes each one,
    by its ``id`` message property, to the queue registered for that id
    through :meth:`subscribe`.  Messages whose id has no registered queue
    are dropped.
    """

    def __init__(self, pulsar_host, topic, subscription, consumer_name,
                 schema=None, max_size=10):
        """Store connection settings and create the empty routing table.

        Args:
            pulsar_host: Pulsar service URL handed to ``aiopulsar.connect``.
            topic: Topic to consume from.
            subscription: Pulsar subscription name.
            consumer_name: Pulsar consumer name.
            schema: Optional schema passed to the consumer.
            max_size: Accepted for symmetry with ``Publisher``; currently
                unused (per-request queues are unbounded).
        """
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name
        self.schema = schema
        # Maps request id -> asyncio.Queue of responses for that request.
        self.q = {}

    async def run(self):
        """Connect, subscribe and route incoming messages forever."""
        # Use the host given to the constructor, not the module default.
        async with aiopulsar.connect(self.pulsar_host) as client:
            async with client.subscribe(
                topic=self.topic,
                subscription_name=self.subscription,
                consumer_name=self.consumer_name,
                schema=self.schema,
            ) as consumer:
                while True:
                    msg = await consumer.receive()
                    # A message without an "id" property cannot be routed;
                    # skip it instead of letting KeyError end the loop.
                    request_id = msg.properties().get("id")
                    waiter = self.q.get(request_id)
                    if waiter is not None:
                        await waiter.put(msg.value())

    async def subscribe(self, id):
        """Register and return a new queue that receives messages for ``id``."""
        q = asyncio.Queue()
        self.q[id] = q
        return q

    async def unsubscribe(self, id):
        """Remove the queue registered for ``id``; a no-op if there is none."""
        self.q.pop(id, None)
class Api:
    """HTTP front end that relays JSON requests to Pulsar-backed services.

    Each endpoint publishes a request tagged with a fresh UUID on the
    service's request queue, then waits for the response carrying the same
    id on the matching response queue.  Failures are reported to the HTTP
    client as a JSON body of the form ``{"error": "<message>"}``.
    """

    def __init__(self, **config):
        """Create the web application plus one publisher/subscriber pair
        per service.

        Args:
            **config: Optional settings.  ``port`` selects the HTTP listen
                port (default ``"8088"``).
        """
        self.port = int(config.get("port", "8088"))
        self.app = web.Application(middlewares=[])

        self.llm_out = Publisher(
            pulsar_host, text_completion_request_queue,
            schema=JsonSchema(TextCompletionRequest)
        )
        self.llm_in = Subscriber(
            pulsar_host, text_completion_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(TextCompletionResponse)
        )

        self.prompt_out = Publisher(
            pulsar_host, prompt_request_queue,
            schema=JsonSchema(PromptRequest)
        )
        self.prompt_in = Subscriber(
            pulsar_host, prompt_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(PromptResponse)
        )

        self.graph_rag_out = Publisher(
            pulsar_host, graph_rag_request_queue,
            schema=JsonSchema(GraphRagQuery)
        )
        self.graph_rag_in = Subscriber(
            pulsar_host, graph_rag_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(GraphRagResponse)
        )

        self.triples_query_out = Publisher(
            pulsar_host, triples_request_queue,
            schema=JsonSchema(TriplesQueryRequest)
        )
        self.triples_query_in = Subscriber(
            pulsar_host, triples_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(TriplesQueryResponse)
        )

        self.agent_out = Publisher(
            pulsar_host, agent_request_queue,
            schema=JsonSchema(AgentRequest)
        )
        self.agent_in = Subscriber(
            pulsar_host, agent_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(AgentResponse)
        )

        self.embeddings_out = Publisher(
            pulsar_host, embeddings_request_queue,
            schema=JsonSchema(EmbeddingsRequest)
        )
        self.embeddings_in = Subscriber(
            pulsar_host, embeddings_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(EmbeddingsResponse)
        )

        self.app.add_routes([
            web.post("/api/v1/text-completion", self.llm),
            web.post("/api/v1/prompt", self.prompt),
            web.post("/api/v1/graph-rag", self.graph_rag),
            web.post("/api/v1/triples-query", self.triples_query),
            web.post("/api/v1/agent", self.agent),
            web.post("/api/v1/embeddings", self.embeddings),
        ])

    @staticmethod
    async def _next_response(q):
        """Return the next item from ``q``, waiting at most ``TIME_OUT``.

        Raises:
            RuntimeError: If nothing arrives in time.  Only the timeout is
                translated; task cancellation propagates unchanged.
        """
        try:
            return await asyncio.wait_for(q.get(), TIME_OUT)
        except asyncio.TimeoutError:
            raise RuntimeError("Timeout waiting for response") from None

    @staticmethod
    def _to_value(data, key):
        """Build a ``Value`` from ``data[key]``, or ``None`` if absent.

        Strings starting with ``http:`` or ``https:`` are flagged as URIs;
        anything else is treated as a literal.
        """
        if key not in data:
            return None
        text = data[key]
        return Value(
            value=text, is_uri=text.startswith(("http:", "https:"))
        )

    @staticmethod
    def _from_value(v):
        """Serialise one triple term as ``{"v": value, "e": is_uri}``."""
        return {"v": v.value, "e": v.is_uri}

    async def llm(self, request):
        """Handle POST /api/v1/text-completion.

        Expects JSON with ``system`` and ``prompt``; replies with
        ``{"response": ...}`` or ``{"error": ...}``.
        """
        request_id = str(uuid.uuid4())
        try:
            data = await request.json()
            q = await self.llm_in.subscribe(request_id)
            await self.llm_out.send(
                request_id,
                TextCompletionRequest(
                    system=data["system"],
                    prompt=data["prompt"]
                )
            )
            resp = await self._next_response(q)
            if resp.error:
                return web.json_response({"error": resp.error.message})
            return web.json_response({"response": resp.response})
        except Exception as e:
            logger.error("Exception: %s", e)
            return web.json_response({"error": str(e)})
        finally:
            await self.llm_in.unsubscribe(request_id)

    async def prompt(self, request):
        """Handle POST /api/v1/prompt.

        Expects JSON with ``id`` (the prompt identifier) and ``variables``
        (a mapping); each variable is JSON-encoded before being sent.
        Replies with ``{"object": ...}`` when the response carries an
        object, otherwise ``{"text": ...}``, or ``{"error": ...}``.
        """
        request_id = str(uuid.uuid4())
        try:
            data = await request.json()
            q = await self.prompt_in.subscribe(request_id)
            terms = {
                k: json.dumps(v)
                for k, v in data["variables"].items()
            }
            await self.prompt_out.send(
                request_id,
                PromptRequest(
                    id=data["id"],
                    terms=terms
                )
            )
            resp = await self._next_response(q)
            if resp.error:
                return web.json_response({"error": resp.error.message})
            if resp.object:
                return web.json_response({"object": resp.object})
            return web.json_response({"text": resp.text})
        except Exception as e:
            logger.error("Exception: %s", e)
            return web.json_response({"error": str(e)})
        finally:
            await self.prompt_in.unsubscribe(request_id)

    async def graph_rag(self, request):
        """Handle POST /api/v1/graph-rag.

        Expects JSON with ``query`` and optional ``user`` (default
        ``"trustgraph"``) and ``collection`` (default ``"default"``).
        Replies with ``{"response": ...}`` or ``{"error": ...}``.
        """
        request_id = str(uuid.uuid4())
        try:
            data = await request.json()
            q = await self.graph_rag_in.subscribe(request_id)
            await self.graph_rag_out.send(
                request_id,
                GraphRagQuery(
                    query=data["query"],
                    user=data.get("user", "trustgraph"),
                    collection=data.get("collection", "default"),
                )
            )
            resp = await self._next_response(q)
            if resp.error:
                return web.json_response({"error": resp.error.message})
            return web.json_response({"response": resp.response})
        except Exception as e:
            logger.error("Exception: %s", e)
            return web.json_response({"error": str(e)})
        finally:
            await self.graph_rag_in.unsubscribe(request_id)

    async def triples_query(self, request):
        """Handle POST /api/v1/triples-query.

        Expects JSON with any of ``s``, ``p``, ``o`` (strings; omitted
        terms are sent as ``None``), plus optional ``limit`` (default
        10000), ``user`` and ``collection``.  Replies with
        ``{"response": [{"s": ..., "p": ..., "o": ...}, ...]}`` where each
        term is ``{"v": value, "e": is_uri}``, or ``{"error": ...}``.
        """
        request_id = str(uuid.uuid4())
        try:
            data = await request.json()
            q = await self.triples_query_in.subscribe(request_id)
            await self.triples_query_out.send(
                request_id,
                TriplesQueryRequest(
                    s=self._to_value(data, "s"),
                    p=self._to_value(data, "p"),
                    o=self._to_value(data, "o"),
                    limit=int(data.get("limit", 10000)),
                    user=data.get("user", "trustgraph"),
                    collection=data.get("collection", "default"),
                )
            )
            resp = await self._next_response(q)
            if resp.error:
                return web.json_response({"error": resp.error.message})
            return web.json_response(
                {
                    "response": [
                        {
                            "s": self._from_value(t.s),
                            "p": self._from_value(t.p),
                            "o": self._from_value(t.o),
                        }
                        for t in resp.triples
                    ]
                }
            )
        except Exception as e:
            logger.error("Exception: %s", e)
            return web.json_response({"error": str(e)})
        finally:
            # Must release the queue on the subscriber it was registered
            # with, otherwise one queue leaks per request.
            await self.triples_query_in.unsubscribe(request_id)

    async def agent(self, request):
        """Handle POST /api/v1/agent.

        Expects JSON with ``question``.  The agent may emit several
        responses; intermediate thoughts and observations are echoed to
        stdout and the first response carrying an answer is returned as
        ``{"answer": ...}``.  Errors are returned as ``{"error": ...}``.
        """
        request_id = str(uuid.uuid4())
        try:
            data = await request.json()
            q = await self.agent_in.subscribe(request_id)
            await self.agent_out.send(
                request_id,
                AgentRequest(
                    question=data["question"],
                )
            )
            # The timeout applies to each individual response, not to the
            # whole exchange.
            while True:
                resp = await self._next_response(q)
                if resp.error:
                    return web.json_response({"error": resp.error.message})
                if resp.answer:
                    return web.json_response({"answer": resp.answer})
                if resp.thought:
                    print("thought:", resp.thought)
                if resp.observation:
                    print("observation:", resp.observation)
        except Exception as e:
            logger.error("Exception: %s", e)
            return web.json_response({"error": str(e)})
        finally:
            await self.agent_in.unsubscribe(request_id)

    async def embeddings(self, request):
        """Handle POST /api/v1/embeddings.

        Expects JSON with ``text``; replies with ``{"vectors": ...}`` or
        ``{"error": ...}``.
        """
        request_id = str(uuid.uuid4())
        try:
            data = await request.json()
            q = await self.embeddings_in.subscribe(request_id)
            await self.embeddings_out.send(
                request_id,
                EmbeddingsRequest(
                    text=data["text"],
                )
            )
            resp = await self._next_response(q)
            if resp.error:
                return web.json_response({"error": resp.error.message})
            return web.json_response({"vectors": resp.vectors})
        except Exception as e:
            logger.error("Exception: %s", e)
            return web.json_response({"error": str(e)})
        finally:
            await self.embeddings_in.unsubscribe(request_id)

    async def app_factory(self):
        """Start every publisher/subscriber task and return the web app.

        The tasks are stored on ``self`` so that a reference to each one
        is retained for the lifetime of the gateway.  ``*_sub_task`` holds
        the subscriber (response) loop and ``*_pub_task`` the publisher
        (request) loop of each service.
        """
        self.llm_sub_task = asyncio.create_task(self.llm_in.run())
        self.llm_pub_task = asyncio.create_task(self.llm_out.run())

        self.prompt_sub_task = asyncio.create_task(self.prompt_in.run())
        self.prompt_pub_task = asyncio.create_task(self.prompt_out.run())

        self.graph_rag_sub_task = asyncio.create_task(
            self.graph_rag_in.run()
        )
        self.graph_rag_pub_task = asyncio.create_task(
            self.graph_rag_out.run()
        )

        self.triples_query_sub_task = asyncio.create_task(
            self.triples_query_in.run()
        )
        self.triples_query_pub_task = asyncio.create_task(
            self.triples_query_out.run()
        )

        self.agent_sub_task = asyncio.create_task(self.agent_in.run())
        self.agent_pub_task = asyncio.create_task(self.agent_out.run())

        self.embeddings_sub_task = asyncio.create_task(
            self.embeddings_in.run()
        )
        self.embeddings_pub_task = asyncio.create_task(
            self.embeddings_out.run()
        )

        return self.app

    def run(self):
        """Serve the application on ``self.port`` until interrupted."""
        web.run_app(self.app_factory(), port=self.port)
# Start the gateway only when executed as a script, so that importing this
# module (e.g. from tests) does not open a listening socket.
if __name__ == "__main__":
    a = Api()
    a.run()