mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
553 lines
16 KiB
Python
Executable file
553 lines
16 KiB
Python
Executable file
#!/usr/bin/env python3
|
|
|
|
# FIXME: Subscribes to Pulsar unnecessarily; it should only subscribe when
# there are active listeners.

# FIXME: Connection errors in publishers / subscribers cause those tasks to
# exit, and they are neither reported as failed nor retried.
|
|
|
|
import asyncio
|
|
from aiohttp import web
|
|
import json
|
|
import logging
|
|
import uuid
|
|
import os
|
|
|
|
import pulsar
|
|
from pulsar.asyncio import Client
|
|
from pulsar.schema import JsonSchema
|
|
import _pulsar
|
|
import aiopulsar
|
|
|
|
from trustgraph.clients.llm_client import LlmClient
|
|
from trustgraph.clients.prompt_client import PromptClient
|
|
|
|
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
|
|
from trustgraph.schema import text_completion_request_queue
|
|
from trustgraph.schema import text_completion_response_queue
|
|
|
|
from trustgraph.schema import PromptRequest, PromptResponse
|
|
from trustgraph.schema import prompt_request_queue
|
|
from trustgraph.schema import prompt_response_queue
|
|
|
|
from trustgraph.schema import GraphRagQuery, GraphRagResponse
|
|
from trustgraph.schema import graph_rag_request_queue
|
|
from trustgraph.schema import graph_rag_response_queue
|
|
|
|
from trustgraph.schema import TriplesQueryRequest, TriplesQueryResponse, Value
|
|
from trustgraph.schema import triples_request_queue
|
|
from trustgraph.schema import triples_response_queue
|
|
|
|
from trustgraph.schema import AgentRequest, AgentResponse
|
|
from trustgraph.schema import agent_request_queue
|
|
from trustgraph.schema import agent_response_queue
|
|
|
|
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse
|
|
from trustgraph.schema import embeddings_request_queue
|
|
from trustgraph.schema import embeddings_response_queue
|
|
|
|
# Named logger for this gateway, set to report INFO and above.
logger = logging.getLogger("api")
logger.setLevel(logging.INFO)

# Pulsar broker URL; the PULSAR_HOST environment variable overrides it.
pulsar_host = os.environ.get("PULSAR_HOST", "pulsar://pulsar:6650")

# Seconds a handler waits for a single backend reply (ten minutes).
TIME_OUT = 10 * 60
|
|
|
|
class Publisher:
    """Publish queued messages to a Pulsar topic from a background task.

    Handlers call :meth:`send`, which enqueues ``(id, message)`` on a bounded
    ``asyncio.Queue``. :meth:`run` drains that queue and publishes each
    message with the id attached as the ``"id"`` message property.
    """

    def __init__(self, pulsar_host, topic, schema=None, max_size=10):
        """
        Args:
            pulsar_host: Pulsar service URL passed to ``aiopulsar.connect``.
            topic: Topic to publish to.
            schema: Optional Pulsar schema for the producer.
            max_size: Capacity of the outbound queue; :meth:`send` waits
                while the queue is full.
        """
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.schema = schema
        self.q = asyncio.Queue(maxsize=max_size)

    async def run(self):
        """Publish queued messages until an error occurs.

        Any exception (for example a connection failure) is logged with its
        traceback, and the task then ends; nothing is retried. After that,
        :meth:`send` blocks once the queue holds ``max_size`` items.
        """
        try:
            async with aiopulsar.connect(self.pulsar_host) as client:
                async with client.create_producer(
                    topic=self.topic,
                    schema=self.schema,
                ) as producer:
                    while True:
                        msg_id, item = await self.q.get()
                        # Presumably the backend echoes this property on its
                        # reply so the response can be routed back; verify.
                        await producer.send(item, {"id": msg_id})
        except Exception:
            # Previously a bare print(): the traceback and topic were lost.
            logging.exception("Publisher for topic %s failed", self.topic)

    async def send(self, id, msg):
        """Queue *msg* for publication, tagged with correlation *id*."""
        await self.q.put((id, msg))
|
|
|
|
class Subscriber:
    """Consume a Pulsar topic and route each message to a per-request queue.

    A caller registers a correlation id with :meth:`subscribe` and gets back
    an ``asyncio.Queue``. :meth:`run` puts the decoded value of every
    incoming message whose ``"id"`` property matches a registered id onto
    that queue. Messages for unknown ids, or with no id, are dropped.
    """

    def __init__(self, pulsar_host, topic, subscription, consumer_name,
                 schema=None, max_size=10):
        """
        Args:
            pulsar_host: Pulsar service URL passed to ``aiopulsar.connect``.
            topic: Topic to consume.
            subscription: Pulsar subscription name.
            consumer_name: Pulsar consumer name.
            schema: Optional Pulsar schema used to decode messages.
            max_size: Currently unused; per-request queues are unbounded.
        """
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name
        self.schema = schema
        # correlation id -> asyncio.Queue receiving decoded message values
        self.q = {}

    async def run(self):
        """Receive messages and dispatch them until an error occurs.

        Any exception is logged with its traceback, and the task then ends;
        nothing is retried.
        """
        try:
            # Bug fix: this used the module-level ``pulsar_host`` and
            # ignored the host passed to the constructor.
            async with aiopulsar.connect(self.pulsar_host) as client:
                async with client.subscribe(
                    topic=self.topic,
                    subscription_name=self.subscription,
                    consumer_name=self.consumer_name,
                    schema=self.schema,
                ) as consumer:
                    while True:
                        msg = await consumer.receive()
                        # NOTE(review): messages are never acknowledged
                        # here; confirm this is intended.
                        await self._deliver(msg)
        except Exception:
            logging.exception("Subscriber for topic %s failed", self.topic)

    async def _deliver(self, msg):
        """Put ``msg.value()`` on the queue registered for its id, if any.

        A message without an ``"id"`` property is ignored instead of raising
        ``KeyError``, which used to terminate the whole consume loop. The
        value is decoded only when someone is waiting for it.
        """
        queue = self.q.get(msg.properties().get("id"))
        if queue is not None:
            await queue.put(msg.value())

    async def subscribe(self, id):
        """Register *id* and return the queue its messages will arrive on."""
        q = asyncio.Queue()
        self.q[id] = q
        return q

    async def unsubscribe(self, id):
        """Forget *id*; a no-op if it is not registered."""
        if id in self.q:
            del self.q[id]
|
|
|
|
class Api:
    """aiohttp application exposing TrustGraph services over HTTP.

    Each ``POST /api/v1/<service>`` handler works the same way:
    - it publishes a request on that service's Pulsar request queue, tagged
      with a fresh UUID;
    - it waits up to ``TIME_OUT`` seconds for the reply with the same id on
      the service's response queue.

    Failures are returned as a JSON body ``{"error": "..."}``.
    """

    def __init__(self, **config):
        """Create the Pulsar publishers/subscribers and register the routes.

        Args:
            **config: Only ``port`` is read (default ``8088``).
        """
        self.port = int(config.get("port", "8088"))
        self.app = web.Application(middlewares=[])

        self.llm_out = Publisher(
            pulsar_host, text_completion_request_queue,
            schema=JsonSchema(TextCompletionRequest)
        )
        self.llm_in = Subscriber(
            pulsar_host, text_completion_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(TextCompletionResponse)
        )

        self.prompt_out = Publisher(
            pulsar_host, prompt_request_queue,
            schema=JsonSchema(PromptRequest)
        )
        self.prompt_in = Subscriber(
            pulsar_host, prompt_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(PromptResponse)
        )

        self.graph_rag_out = Publisher(
            pulsar_host, graph_rag_request_queue,
            schema=JsonSchema(GraphRagQuery)
        )
        self.graph_rag_in = Subscriber(
            pulsar_host, graph_rag_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(GraphRagResponse)
        )

        self.triples_query_out = Publisher(
            pulsar_host, triples_request_queue,
            schema=JsonSchema(TriplesQueryRequest)
        )
        self.triples_query_in = Subscriber(
            pulsar_host, triples_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(TriplesQueryResponse)
        )

        self.agent_out = Publisher(
            pulsar_host, agent_request_queue,
            schema=JsonSchema(AgentRequest)
        )
        self.agent_in = Subscriber(
            pulsar_host, agent_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(AgentResponse)
        )

        self.embeddings_out = Publisher(
            pulsar_host, embeddings_request_queue,
            schema=JsonSchema(EmbeddingsRequest)
        )
        self.embeddings_in = Subscriber(
            pulsar_host, embeddings_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(EmbeddingsResponse)
        )

        self.app.add_routes([
            web.post("/api/v1/text-completion", self.llm),
            web.post("/api/v1/prompt", self.prompt),
            web.post("/api/v1/graph-rag", self.graph_rag),
            web.post("/api/v1/triples-query", self.triples_query),
            web.post("/api/v1/agent", self.agent),
            web.post("/api/v1/embeddings", self.embeddings),
        ])

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    async def _wait_for_response(self, q):
        """Return the next reply from *q*, or raise if ``TIME_OUT`` elapses.

        Raises:
            RuntimeError: no reply arrived within ``TIME_OUT`` seconds.

        Only a timeout is translated. The old bare ``except:`` also turned
        task cancellation into a fake timeout, which was then swallowed.
        """
        try:
            return await asyncio.wait_for(q.get(), TIME_OUT)
        except asyncio.TimeoutError:
            raise RuntimeError("Timeout waiting for response") from None

    async def _round_trip(self, request, publisher, subscriber,
                          build_request, build_response):
        """Serve one single-reply request/response exchange.

        Args:
            request: Incoming aiohttp request; its body is parsed as JSON.
            publisher: ``Publisher`` for the service's request queue.
            subscriber: ``Subscriber`` for the service's response queue.
            build_request: Maps the parsed JSON body to the request message.
            build_response: Maps an error-free reply to the JSON body
                returned to the client.

        Returns:
            A JSON response. It is ``{"error": ...}`` when the body is
            invalid, the backend reports an error, or no reply arrives in time.
        """
        request_id = str(uuid.uuid4())
        try:
            data = await request.json()
            q = await subscriber.subscribe(request_id)
            await publisher.send(request_id, build_request(data))
            resp = await self._wait_for_response(q)
            if resp.error:
                return web.json_response({"error": resp.error.message})
            return web.json_response(build_response(resp))
        except Exception as e:
            logging.error(f"Exception: {e}")
            return web.json_response({"error": str(e)})
        finally:
            # Always release the routing entry on the subscriber used for
            # this request; triples_query used to release the wrong one.
            await subscriber.unsubscribe(request_id)

    @staticmethod
    def _term(data, key):
        """Return a ``Value`` for ``data[key]``, or ``None`` when absent.

        Raises:
            TypeError: the term is present but is not a string.
        """
        if key not in data:
            return None
        term = data[key]
        if not isinstance(term, str):
            raise TypeError(f"'{key}' must be a string")
        # NOTE(review): the original checked for an http(s) prefix, but both
        # branches set is_uri=True, so every term is sent as a URI. That is
        # preserved here; confirm whether literals should use is_uri=False.
        return Value(value=term, is_uri=True)

    @staticmethod
    def _value_json(value):
        """Serialise a ``Value`` as ``{"v": value, "e": is_uri}``."""
        return {"v": value.value, "e": value.is_uri}

    # ------------------------------------------------------------------
    # HTTP handlers
    # ------------------------------------------------------------------

    async def llm(self, request):
        """POST /api/v1/text-completion; body needs ``system`` and ``prompt``."""
        return await self._round_trip(
            request, self.llm_out, self.llm_in,
            lambda data: TextCompletionRequest(
                system=data["system"],
                prompt=data["prompt"],
            ),
            lambda resp: {"response": resp.response},
        )

    async def prompt(self, request):
        """POST /api/v1/prompt; body needs ``id`` and a ``variables`` object.

        Replies with ``{"object": ...}`` when the reply has an object,
        otherwise with ``{"text": ...}``.
        """
        def build(data):
            # Each variable value is JSON-encoded into a string term.
            terms = {k: json.dumps(v) for k, v in data["variables"].items()}
            return PromptRequest(id=data["id"], terms=terms)

        def respond(resp):
            if resp.object:
                return {"object": resp.object}
            return {"text": resp.text}

        return await self._round_trip(
            request, self.prompt_out, self.prompt_in, build, respond,
        )

    async def graph_rag(self, request):
        """POST /api/v1/graph-rag; body needs ``query``.

        ``user`` (default ``"trustgraph"``) and ``collection`` (default
        ``"default"``) are optional.
        """
        return await self._round_trip(
            request, self.graph_rag_out, self.graph_rag_in,
            lambda data: GraphRagQuery(
                query=data["query"],
                user=data.get("user", "trustgraph"),
                collection=data.get("collection", "default"),
            ),
            lambda resp: {"response": resp.response},
        )

    async def triples_query(self, request):
        """POST /api/v1/triples-query.

        The body contains:
        - optional ``s``, ``p``, ``o`` terms;
        - ``limit`` (default 10000);
        - ``user`` (default ``"trustgraph"``);
        - ``collection`` (default ``"default"``).
        """
        def build(data):
            return TriplesQueryRequest(
                s=self._term(data, "s"),
                p=self._term(data, "p"),
                o=self._term(data, "o"),
                limit=int(data.get("limit", 10000)),
                user=data.get("user", "trustgraph"),
                collection=data.get("collection", "default"),
            )

        def respond(resp):
            return {
                "response": [
                    {
                        "s": self._value_json(t.s),
                        "p": self._value_json(t.p),
                        "o": self._value_json(t.o),
                    }
                    for t in resp.triples
                ]
            }

        return await self._round_trip(
            request, self.triples_query_out, self.triples_query_in,
            build, respond,
        )

    async def agent(self, request):
        """POST /api/v1/agent; body needs ``question``.

        Replies arrive until one carries an answer or an error. Each reply
        gets its own ``TIME_OUT``. Intermediate thought/observation fields
        are printed to stdout.
        """
        request_id = str(uuid.uuid4())
        try:
            data = await request.json()
            q = await self.agent_in.subscribe(request_id)
            await self.agent_out.send(
                request_id,
                AgentRequest(question=data["question"]),
            )
            while True:
                resp = await self._wait_for_response(q)
                if resp.error:
                    return web.json_response({"error": resp.error.message})
                if resp.answer:
                    return web.json_response({"answer": resp.answer})
                if resp.thought:
                    print("thought:", resp.thought)
                if resp.observation:
                    print("observation:", resp.observation)
        except Exception as e:
            logging.error(f"Exception: {e}")
            return web.json_response({"error": str(e)})
        finally:
            await self.agent_in.unsubscribe(request_id)

    async def embeddings(self, request):
        """POST /api/v1/embeddings; body needs ``text``."""
        return await self._round_trip(
            request, self.embeddings_out, self.embeddings_in,
            lambda data: EmbeddingsRequest(text=data["text"]),
            lambda resp: {"vectors": resp.vectors},
        )

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def app_factory(self):
        """Start a task per publisher/subscriber and return the app.

        ``asyncio.create_task`` needs a running loop, so this is passed to
        ``web.run_app`` as a coroutine rather than called from ``__init__``.
        Keeping the tasks as attributes keeps them referenced. The
        ``*_pub_task`` names previously held subscriber tasks and vice versa.
        """
        self.llm_pub_task = asyncio.create_task(self.llm_out.run())
        self.llm_sub_task = asyncio.create_task(self.llm_in.run())

        self.prompt_pub_task = asyncio.create_task(self.prompt_out.run())
        self.prompt_sub_task = asyncio.create_task(self.prompt_in.run())

        self.graph_rag_pub_task = asyncio.create_task(
            self.graph_rag_out.run()
        )
        self.graph_rag_sub_task = asyncio.create_task(
            self.graph_rag_in.run()
        )

        self.triples_query_pub_task = asyncio.create_task(
            self.triples_query_out.run()
        )
        self.triples_query_sub_task = asyncio.create_task(
            self.triples_query_in.run()
        )

        self.agent_pub_task = asyncio.create_task(self.agent_out.run())
        self.agent_sub_task = asyncio.create_task(self.agent_in.run())

        self.embeddings_pub_task = asyncio.create_task(
            self.embeddings_out.run()
        )
        self.embeddings_sub_task = asyncio.create_task(
            self.embeddings_in.run()
        )

        return self.app

    def run(self):
        """Serve the application on ``self.port``; blocks until shutdown."""
        web.run_app(self.app_factory(), port=self.port)
|
|
|
|
# Start the gateway only when executed as a script. Previously, importing
# this module started the server as a side effect.
if __name__ == "__main__":
    a = Api()
    a.run()
|
|
|