#!/usr/bin/env python3

"""
HTTP API gateway in front of Pulsar request/response services.

Each REST endpoint publishes a request message to a Pulsar topic, tagged with
a freshly generated correlation id carried in the message's "id" property,
and then waits for a response carrying the same id on the matching response
topic.
"""

# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners

# FIXME: Connection errors in publishers / subscribers cause those threads
# to fail and are not failed or retried

import asyncio
from aiohttp import web
import json
import logging
import uuid
import os

import pulsar
from pulsar.asyncio import Client
from pulsar.schema import JsonSchema
import _pulsar
import aiopulsar

from trustgraph.clients.llm_client import LlmClient
from trustgraph.clients.prompt_client import PromptClient

from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
from trustgraph.schema import text_completion_request_queue
from trustgraph.schema import text_completion_response_queue

from trustgraph.schema import PromptRequest, PromptResponse
from trustgraph.schema import prompt_request_queue
from trustgraph.schema import prompt_response_queue

from trustgraph.schema import GraphRagQuery, GraphRagResponse
from trustgraph.schema import graph_rag_request_queue
from trustgraph.schema import graph_rag_response_queue

from trustgraph.schema import TriplesQueryRequest, TriplesQueryResponse, Value
from trustgraph.schema import triples_request_queue
from trustgraph.schema import triples_response_queue

from trustgraph.schema import AgentRequest, AgentResponse
from trustgraph.schema import agent_request_queue
from trustgraph.schema import agent_response_queue

from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse
from trustgraph.schema import embeddings_request_queue
from trustgraph.schema import embeddings_response_queue

logger = logging.getLogger("api")
logger.setLevel(logging.INFO)

pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")

# Maximum time, in seconds, a handler waits for a single response message.
TIME_OUT = 600


async def _wait_for_response(q):
    """Return the next item from queue *q*, waiting at most TIME_OUT seconds.

    Raises:
        RuntimeError: if nothing arrives before the timeout expires.

    Only the timeout is translated; task cancellation (e.g. the HTTP client
    going away) propagates unchanged instead of being reported as a timeout.
    """
    try:
        return await asyncio.wait_for(q.get(), TIME_OUT)
    except asyncio.TimeoutError as e:
        raise RuntimeError("Timeout waiting for response") from e


def _to_value(text):
    """Build a schema Value from a string taken from a request body.

    Strings starting with "http:" or "https:" are flagged as URIs; anything
    else is flagged as a non-URI value.
    """
    return Value(value=text, is_uri=text.startswith(("http:", "https:")))


class Publisher:
    """Buffers outgoing messages and forwards them to one Pulsar topic."""

    def __init__(self, pulsar_host, topic, schema=None, max_size=10):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.schema = schema
        # Bounded queue of (id, message) pairs: send() waits while it is full.
        self.q = asyncio.Queue(maxsize=max_size)

    async def run(self):
        """Connect to Pulsar and send queued messages until an error occurs.

        Each message is sent with its correlation id in the "id" property.
        Any exception is printed and ends the loop; nothing is retried.
        """
        try:
            async with aiopulsar.connect(self.pulsar_host) as client:
                async with client.create_producer(
                    topic=self.topic,
                    schema=self.schema,
                ) as producer:
                    while True:
                        msg_id, item = await self.q.get()
                        await producer.send(item, {"id": msg_id})
        except Exception as e:
            print("Exception:", e, flush=True)

    async def send(self, id, msg):
        """Queue *msg* for publication, tagged with correlation id *id*."""
        await self.q.put((id, msg))


class Subscriber:
    """Receives messages from one Pulsar topic and routes them by id."""

    def __init__(self, pulsar_host, topic, subscription, consumer_name,
                 schema=None, max_size=10):
        # max_size is accepted but not used: per-id queues are unbounded.
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name
        self.schema = schema
        # Maps correlation id -> asyncio.Queue of that id's response values.
        self.q = {}

    async def run(self):
        """Connect to Pulsar and dispatch incoming messages until an error.

        A message's value is put on the queue registered for its "id"
        property; messages whose id has no registered queue are dropped.
        Any exception is printed and ends the loop; nothing is retried.
        """
        try:
            # Use the host given to this instance, not the module default.
            async with aiopulsar.connect(self.pulsar_host) as client:
                async with client.subscribe(
                    topic=self.topic,
                    subscription_name=self.subscription,
                    consumer_name=self.consumer_name,
                    schema=self.schema,
                ) as consumer:
                    while True:
                        msg = await consumer.receive()
                        msg_id = msg.properties()["id"]
                        value = msg.value()
                        if msg_id in self.q:
                            await self.q[msg_id].put(value)
        except Exception as e:
            print("Exception:", e, flush=True)

    async def subscribe(self, id):
        """Register a new queue for correlation id *id* and return it."""
        q = asyncio.Queue()
        self.q[id] = q
        return q

    async def unsubscribe(self, id):
        """Remove the queue registered for *id*; a no-op if there is none."""
        if id in self.q:
            del self.q[id]


class Api:
    """aiohttp application exposing the Pulsar services as POST endpoints.

    Every handler responds with a JSON object. Failures of any kind are
    reported as {"error": <message>} rather than as an HTTP error status.
    """

    def __init__(self, **config):
        """Create the publishers, subscribers and routes.

        Recognised config keys:
            port: TCP port to listen on (default "8088").
        """
        self.port = int(config.get("port", "8088"))
        self.app = web.Application(middlewares=[])

        self.llm_out = Publisher(
            pulsar_host, text_completion_request_queue,
            schema=JsonSchema(TextCompletionRequest)
        )
        self.llm_in = Subscriber(
            pulsar_host, text_completion_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(TextCompletionResponse)
        )

        self.prompt_out = Publisher(
            pulsar_host, prompt_request_queue,
            schema=JsonSchema(PromptRequest)
        )
        self.prompt_in = Subscriber(
            pulsar_host, prompt_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(PromptResponse)
        )

        self.graph_rag_out = Publisher(
            pulsar_host, graph_rag_request_queue,
            schema=JsonSchema(GraphRagQuery)
        )
        self.graph_rag_in = Subscriber(
            pulsar_host, graph_rag_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(GraphRagResponse)
        )

        self.triples_query_out = Publisher(
            pulsar_host, triples_request_queue,
            schema=JsonSchema(TriplesQueryRequest)
        )
        self.triples_query_in = Subscriber(
            pulsar_host, triples_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(TriplesQueryResponse)
        )

        self.agent_out = Publisher(
            pulsar_host, agent_request_queue,
            schema=JsonSchema(AgentRequest)
        )
        self.agent_in = Subscriber(
            pulsar_host, agent_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(AgentResponse)
        )

        self.embeddings_out = Publisher(
            pulsar_host, embeddings_request_queue,
            schema=JsonSchema(EmbeddingsRequest)
        )
        self.embeddings_in = Subscriber(
            pulsar_host, embeddings_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(EmbeddingsResponse)
        )

        self.app.add_routes([
            web.post("/api/v1/text-completion", self.llm),
            web.post("/api/v1/prompt", self.prompt),
            web.post("/api/v1/graph-rag", self.graph_rag),
            web.post("/api/v1/triples-query", self.triples_query),
            web.post("/api/v1/agent", self.agent),
            web.post("/api/v1/embeddings", self.embeddings),
        ])

    async def llm(self, request):
        """Handle POST /api/v1/text-completion.

        Reads "system" and "prompt" from the JSON body and responds with
        {"response": ...} or {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.llm_in.subscribe(request_id)

            await self.llm_out.send(
                request_id,
                TextCompletionRequest(
                    system=data["system"],
                    prompt=data["prompt"]
                )
            )

            resp = await _wait_for_response(q)

            if resp.error:
                return web.json_response({"error": resp.error.message})

            return web.json_response({"response": resp.response})

        except Exception as e:
            logger.error(f"Exception: {e}")
            return web.json_response({"error": str(e)})

        finally:
            await self.llm_in.unsubscribe(request_id)

    async def prompt(self, request):
        """Handle POST /api/v1/prompt.

        Reads "id" and "variables" from the JSON body; each variable value is
        JSON-encoded before being sent. Responds with {"object": ...} when
        the response has an object, otherwise {"text": ...}, or {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.prompt_in.subscribe(request_id)

            terms = {
                k: json.dumps(v)
                for k, v in data["variables"].items()
            }

            await self.prompt_out.send(
                request_id,
                PromptRequest(
                    id=data["id"],
                    terms=terms
                )
            )

            resp = await _wait_for_response(q)

            if resp.error:
                return web.json_response({"error": resp.error.message})

            if resp.object:
                return web.json_response({"object": resp.object})

            return web.json_response({"text": resp.text})

        except Exception as e:
            logger.error(f"Exception: {e}")
            return web.json_response({"error": str(e)})

        finally:
            await self.prompt_in.unsubscribe(request_id)

    async def graph_rag(self, request):
        """Handle POST /api/v1/graph-rag.

        Reads "query" plus optional "user" (default "trustgraph") and
        "collection" (default "default") from the JSON body and responds with
        {"response": ...} or {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.graph_rag_in.subscribe(request_id)

            await self.graph_rag_out.send(
                request_id,
                GraphRagQuery(
                    query=data["query"],
                    user=data.get("user", "trustgraph"),
                    collection=data.get("collection", "default"),
                )
            )

            resp = await _wait_for_response(q)

            if resp.error:
                return web.json_response({"error": resp.error.message})

            return web.json_response({"response": resp.response})

        except Exception as e:
            logger.error(f"Exception: {e}")
            return web.json_response({"error": str(e)})

        finally:
            await self.graph_rag_in.unsubscribe(request_id)

    async def triples_query(self, request):
        """Handle POST /api/v1/triples-query.

        Reads optional "s", "p", "o" strings, "limit" (default 10000), "user"
        (default "trustgraph") and "collection" (default "default") from the
        JSON body. Responds with {"response": [...]} where each entry has
        "s", "p" and "o" objects of the form {"v": value, "e": is_uri}, or
        with {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.triples_query_in.subscribe(request_id)

            # A term absent from the request is sent as None.
            s = _to_value(data["s"]) if "s" in data else None
            p = _to_value(data["p"]) if "p" in data else None
            o = _to_value(data["o"]) if "o" in data else None

            limit = int(data.get("limit", 10000))

            await self.triples_query_out.send(
                request_id,
                TriplesQueryRequest(
                    s=s, p=p, o=o,
                    limit=limit,
                    user=data.get("user", "trustgraph"),
                    collection=data.get("collection", "default"),
                )
            )

            resp = await _wait_for_response(q)

            if resp.error:
                return web.json_response({"error": resp.error.message})

            return web.json_response({
                "response": [
                    {
                        "s": {"v": t.s.value, "e": t.s.is_uri},
                        "p": {"v": t.p.value, "e": t.p.is_uri},
                        "o": {"v": t.o.value, "e": t.o.is_uri},
                    }
                    for t in resp.triples
                ]
            })

        except Exception as e:
            logger.error(f"Exception: {e}")
            return web.json_response({"error": str(e)})

        finally:
            # Must release the queue on the subscriber it was registered
            # with, otherwise one entry is leaked per request.
            await self.triples_query_in.unsubscribe(request_id)

    async def agent(self, request):
        """Handle POST /api/v1/agent.

        Reads "question" from the JSON body, then consumes response messages
        until one carries an answer or an error. Intermediate thoughts and
        observations are printed. Responds with {"answer": ...} or
        {"error": ...}. The timeout applies to each message separately.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.agent_in.subscribe(request_id)

            await self.agent_out.send(
                request_id,
                AgentRequest(
                    question=data["question"],
                )
            )

            while True:
                resp = await _wait_for_response(q)

                if resp.error:
                    return web.json_response({"error": resp.error.message})

                if resp.answer:
                    return web.json_response({"answer": resp.answer})

                if resp.thought:
                    print("thought:", resp.thought)

                if resp.observation:
                    print("observation:", resp.observation)

        except Exception as e:
            logger.error(f"Exception: {e}")
            return web.json_response({"error": str(e)})

        finally:
            await self.agent_in.unsubscribe(request_id)

    async def embeddings(self, request):
        """Handle POST /api/v1/embeddings.

        Reads "text" from the JSON body and responds with {"vectors": ...}
        or {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.embeddings_in.subscribe(request_id)

            await self.embeddings_out.send(
                request_id,
                EmbeddingsRequest(
                    text=data["text"],
                )
            )

            resp = await _wait_for_response(q)

            if resp.error:
                return web.json_response({"error": resp.error.message})

            return web.json_response({"vectors": resp.vectors})

        except Exception as e:
            logger.error(f"Exception: {e}")
            return web.json_response({"error": str(e)})

        finally:
            await self.embeddings_in.unsubscribe(request_id)

    async def app_factory(self):
        """Start every publisher/subscriber loop and return the web app.

        The tasks are kept as attributes so they stay referenced for the
        lifetime of the Api object. NOTE: the attribute names are historical
        and read backwards -- each ``*_pub_task`` runs the subscriber
        (``*_in``) and each ``*_sub_task`` runs the publisher (``*_out``).
        They are left unchanged so existing references keep working.
        """
        self.llm_pub_task = asyncio.create_task(self.llm_in.run())
        self.llm_sub_task = asyncio.create_task(self.llm_out.run())

        self.prompt_pub_task = asyncio.create_task(self.prompt_in.run())
        self.prompt_sub_task = asyncio.create_task(self.prompt_out.run())

        self.graph_rag_pub_task = asyncio.create_task(self.graph_rag_in.run())
        self.graph_rag_sub_task = asyncio.create_task(self.graph_rag_out.run())

        self.triples_query_pub_task = asyncio.create_task(
            self.triples_query_in.run()
        )
        self.triples_query_sub_task = asyncio.create_task(
            self.triples_query_out.run()
        )

        self.agent_pub_task = asyncio.create_task(self.agent_in.run())
        self.agent_sub_task = asyncio.create_task(self.agent_out.run())

        self.embeddings_pub_task = asyncio.create_task(
            self.embeddings_in.run()
        )
        self.embeddings_sub_task = asyncio.create_task(
            self.embeddings_out.run()
        )

        return self.app

    def run(self):
        """Serve the application on the configured port until stopped."""
        web.run_app(self.app_factory(), port=self.port)


a = Api()
a.run()