Feature/api gateway (#164)

* Bare bones API gateway
* Working for LLM + prompt
* RAG query works
* Triples query
* Added agent API
* Embeddings API
* Put API tests in a subdir
This commit is contained in:
cybermaggedon 2024-11-20 19:55:40 +00:00 committed by GitHub
parent b536d78b57
commit 92b84441eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 768 additions and 0 deletions

28
test-api/test-agent-api Executable file
View file

@ -0,0 +1,28 @@
#!/usr/bin/env python3
import requests
import json
import sys
# Root of the gateway REST API; the endpoint name is appended per request.
url = "http://localhost:8088/api/v1/"

############################################################################

# JSON body for the agent endpoint: a single free-text question.
payload = {
    "question": "What is the highest risk aspect of running a space shuttle program? Provide 5 detailed reasons to justify our answer.",
}

# POST the question and decode the JSON reply in one step.
reply = requests.post(url + "agent", json=payload).json()

# The gateway reports failures as an "error" key in the reply body.
if "error" in reply:
    print(f"Error: {reply['error']}")
    sys.exit(1)

print(reply["answer"])

25
test-api/test-embeddings-api Executable file
View file

@ -0,0 +1,25 @@
#!/usr/bin/env python3
import requests
import json
import sys
# Root of the gateway REST API; the endpoint name is appended per request.
url = "http://localhost:8088/api/v1/"

############################################################################

# JSON body for the embeddings endpoint: the text to embed.
input = {
    "text": "What is the highest risk aspect of running a space shuttle program? Provide 5 detailed reasons to justify our answer.",
}

resp = requests.post(
    f"{url}embeddings",
    json=input,
)

resp = resp.json()

# The gateway reports failures as an "error" key in the reply body.
if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

# Show the result so a successful run is distinguishable from a silent
# no-op; the gateway returns the embeddings under the "vectors" key.
print(resp["vectors"])

31
test-api/test-graph-rag-api Executable file
View file

@ -0,0 +1,31 @@
#!/usr/bin/env python3
import requests
import json
import sys
# Root of the gateway REST API; the endpoint name is appended per request.
url = "http://localhost:8088/api/v1/"

############################################################################

# JSON body for the graph-rag endpoint.
payload = {"query": "Give me 10 facts"}

endpoint = url + "graph-rag"
reply = requests.post(endpoint, json=payload).json()

# Dump the raw decoded reply before interpreting it.
print(reply)

# The gateway reports failures as an "error" key in the reply body.
if "error" in reply:
    print(f"Error: {reply['error']}")
    sys.exit(1)

print(reply["response"])
sys.exit(0)

############################################################################

31
test-api/test-llm-api Executable file
View file

@ -0,0 +1,31 @@
#!/usr/bin/env python3
import requests
import json
import sys
# Root of the gateway REST API; the endpoint name is appended per request.
url = "http://localhost:8088/api/v1/"

############################################################################

# JSON body for the text-completion endpoint: a system instruction plus
# the user prompt.
payload = {
    "system": "Respond in French.  Use long word, form of numbers, no digits",
    # Shorter alternative prompt:
    # "prompt": "Add 2 and 12"
    "prompt": "Add 12 and 14, and then make a poem about llamas which incorporates that number.  Then write a joke about llamas",
}

reply = requests.post(url + "text-completion", json=payload).json()

# The gateway reports failures as an "error" key in the reply body.
if "error" in reply:
    print(f"Error: {reply['error']}")
    sys.exit(1)

print(reply["response"])

############################################################################

38
test-api/test-prompt-api Executable file
View file

@ -0,0 +1,38 @@
#!/usr/bin/env python3
import requests
import json
import sys
# Root of the gateway REST API; the endpoint name is appended per request.
url = "http://localhost:8088/api/v1/"

############################################################################

# JSON body for the prompt endpoint: the prompt template id and the
# variables to substitute into it.
variables = {"question": "Write a joke about llamas."}
payload = {"id": "question", "variables": variables}

reply = requests.post(url + "prompt", json=payload).json()

# Dump the raw decoded reply before interpreting it.
print(reply)

# The gateway reports failures as an "error" key in the reply body.
if "error" in reply:
    print(f"Error: {reply['error']}")
    sys.exit(1)

# An "object" reply is printed but the script still exits with status 1.
if "object" in reply:
    print(f"Object: {reply['object']}")
    sys.exit(1)

print(reply["text"])
sys.exit(0)

############################################################################

39
test-api/test-prompt2-api Executable file
View file

@ -0,0 +1,39 @@
#!/usr/bin/env python3
import requests
import json
import sys
# Root of the gateway REST API; the endpoint name is appended per request.
url = "http://localhost:8088/api/v1/"

############################################################################

# JSON body for the prompt endpoint: the prompt template id and the
# variables to substitute into it.
input = {
    "id": "extract-definitions",
    "variables": {
        "text": "A cat is a large mammal."
    }
}

resp = requests.post(
    f"{url}prompt",
    json=input,
)

resp = resp.json()

# Dump the raw decoded reply before interpreting it.
print(resp)

# The gateway reports failures as an "error" key in the reply body.
if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

if "object" in resp:
    # The object arrives as a JSON-encoded string; decode it and
    # pretty-print.  The script decodes and formats this result on purpose,
    # so it is treated as a success and exits with status 0.
    definitions = json.loads(resp["object"])
    print(json.dumps(definitions, indent=4))
    sys.exit(0)

print(resp["text"])
sys.exit(0)

############################################################################

35
test-api/test-triples-query-api Executable file
View file

@ -0,0 +1,35 @@
#!/usr/bin/env python3
import requests
import json
import sys

# Root of the gateway REST API; the endpoint name is appended per request.
url = "http://localhost:8088/api/v1/"

############################################################################

# JSON body for the triples-query endpoint: match on predicate only and
# cap the number of returned triples.
payload = {
    "p": "http://www.w3.org/2000/01/rdf-schema#label",
    "limit": 10,
}

http_reply = requests.post(url + "triples-query", json=payload)

# Show the raw response text first, then the decoded JSON.
print(http_reply.text)
reply = http_reply.json()
print(reply)

# The gateway reports failures as an "error" key in the reply body.
if "error" in reply:
    print(f"Error: {reply['error']}")
    sys.exit(1)

print(reply["response"])
sys.exit(0)

############################################################################

View file

@ -0,0 +1,540 @@
#!/usr/bin/env python3
import asyncio
from aiohttp import web
import json
import logging
import uuid
import pulsar
from pulsar.asyncio import Client
from pulsar.schema import JsonSchema
import _pulsar
import aiopulsar
from trustgraph.clients.llm_client import LlmClient
from trustgraph.clients.prompt_client import PromptClient
from trustgraph.schema import TextCompletionRequest, TextCompletionResponse
from trustgraph.schema import text_completion_request_queue
from trustgraph.schema import text_completion_response_queue
from trustgraph.schema import PromptRequest, PromptResponse
from trustgraph.schema import prompt_request_queue
from trustgraph.schema import prompt_response_queue
from trustgraph.schema import GraphRagQuery, GraphRagResponse
from trustgraph.schema import graph_rag_request_queue
from trustgraph.schema import graph_rag_response_queue
from trustgraph.schema import TriplesQueryRequest, TriplesQueryResponse, Value
from trustgraph.schema import triples_request_queue
from trustgraph.schema import triples_response_queue
from trustgraph.schema import AgentRequest, AgentResponse
from trustgraph.schema import agent_request_queue
from trustgraph.schema import agent_response_queue
from trustgraph.schema import EmbeddingsRequest, EmbeddingsResponse
from trustgraph.schema import embeddings_request_queue
from trustgraph.schema import embeddings_response_queue
# Module logger named "api"; its threshold is set to INFO.
logger = logging.getLogger("api")
logger.setLevel(logging.INFO)
# Pulsar broker URL handed to the gateway's publishers and subscribers.
pulsar_host = "pulsar://localhost:6650"
# Timeout, in seconds, passed to asyncio.wait_for while a handler waits
# for a service response.
TIME_OUT = 600
class Publisher:
    """Queue-backed producer for a single Pulsar topic.

    Callers hand messages to send(); the long-running run() coroutine
    drains the internal queue and publishes each message with its request
    id attached as the "id" message property.
    """

    def __init__(self, pulsar_host, topic, schema=None, max_size=10):
        """Store connection settings and create the bounded outgoing queue.

        Args:
            pulsar_host: Pulsar service URL passed to aiopulsar.connect.
            topic: Topic the producer is created on.
            schema: Optional schema passed to create_producer.
            max_size: Capacity of the outgoing queue.
        """
        self.q = asyncio.Queue(maxsize=max_size)
        self.schema = schema
        self.topic = topic
        self.pulsar_host = pulsar_host

    async def run(self):
        """Connect to Pulsar and publish queued messages forever."""
        connection = aiopulsar.connect(self.pulsar_host)
        async with connection as client:
            producer_ctx = client.create_producer(
                topic=self.topic,
                schema=self.schema,
            )
            async with producer_ctx as producer:
                while True:
                    request_id, payload = await self.q.get()
                    # The request id travels as a message property.
                    properties = { "id": request_id }
                    await producer.send(payload, properties)

    async def send(self, id, msg):
        """Queue msg for publication under id; waits while the queue is full."""
        pending = (id, msg)
        await self.q.put(pending)
class Subscriber:
    """Routes messages from one Pulsar topic to per-request queues.

    A caller registers interest in a request id with subscribe() and gets
    an asyncio.Queue back; run() delivers every received message whose
    "id" property matches a registered id to that queue.  Messages for
    unregistered ids are dropped.
    """

    def __init__(self, pulsar_host, topic, subscription, consumer_name,
                 schema=None, max_size=10):
        """Store connection settings and create the empty routing table.

        Args:
            pulsar_host: Pulsar service URL passed to aiopulsar.connect.
            topic: Topic to consume from.
            subscription: Pulsar subscription name.
            consumer_name: Pulsar consumer name.
            schema: Optional schema passed to subscribe.
            max_size: Accepted but currently unused; the per-request
                queues created by subscribe() are unbounded.
        """
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name
        self.schema = schema
        # Maps request id -> asyncio.Queue awaiting that request's replies.
        self.q = {}

    async def run(self):
        """Connect to Pulsar and dispatch received messages forever."""
        # Use the host this instance was configured with, not the
        # module-level default, so differently-configured subscribers work.
        async with aiopulsar.connect(self.pulsar_host) as client:
            async with client.subscribe(
                topic=self.topic,
                subscription_name=self.subscription,
                consumer_name=self.consumer_name,
                schema=self.schema,
            ) as consumer:
                # NOTE(review): received messages are never acknowledged
                # here; confirm whether the consumer should ack them.
                while True:
                    msg = await consumer.receive()
                    request_id = msg.properties()["id"]
                    value = msg.value()
                    waiter = self.q.get(request_id)
                    if waiter is not None:
                        await waiter.put(value)

    async def subscribe(self, id):
        """Register id and return the queue its responses will arrive on."""
        q = asyncio.Queue()
        self.q[id] = q
        return q

    async def unsubscribe(self, id):
        """Forget id; a no-op when id is not registered."""
        self.q.pop(id, None)
class Api:
    """HTTP front end that relays JSON POST requests to Pulsar services.

    Every endpoint pairs a Publisher (request topic) with a Subscriber
    (response topic).  A handler tags its request with a fresh UUID, waits
    on a per-request queue for the response carrying the same id, and maps
    that response to a JSON body.  Failures are reported as
    {"error": ...} in the response body; no explicit HTTP status is set.
    """

    def __init__(self, **config):
        """Create the aiohttp application, Pulsar endpoints and routes.

        Args:
            **config: Optional settings; only "port" (default "8088") is
                read.
        """
        self.port = int(config.get("port", "8088"))
        self.app = web.Application(middlewares=[])

        self.llm_out = Publisher(
            pulsar_host, text_completion_request_queue,
            schema=JsonSchema(TextCompletionRequest)
        )
        self.llm_in = Subscriber(
            pulsar_host, text_completion_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(TextCompletionResponse)
        )

        self.prompt_out = Publisher(
            pulsar_host, prompt_request_queue,
            schema=JsonSchema(PromptRequest)
        )
        self.prompt_in = Subscriber(
            pulsar_host, prompt_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(PromptResponse)
        )

        self.graph_rag_out = Publisher(
            pulsar_host, graph_rag_request_queue,
            schema=JsonSchema(GraphRagQuery)
        )
        self.graph_rag_in = Subscriber(
            pulsar_host, graph_rag_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(GraphRagResponse)
        )

        self.triples_query_out = Publisher(
            pulsar_host, triples_request_queue,
            schema=JsonSchema(TriplesQueryRequest)
        )
        self.triples_query_in = Subscriber(
            pulsar_host, triples_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(TriplesQueryResponse)
        )

        self.agent_out = Publisher(
            pulsar_host, agent_request_queue,
            schema=JsonSchema(AgentRequest)
        )
        self.agent_in = Subscriber(
            pulsar_host, agent_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(AgentResponse)
        )

        self.embeddings_out = Publisher(
            pulsar_host, embeddings_request_queue,
            schema=JsonSchema(EmbeddingsRequest)
        )
        self.embeddings_in = Subscriber(
            pulsar_host, embeddings_response_queue,
            "api-gateway", "api-gateway",
            JsonSchema(EmbeddingsResponse)
        )

        self.app.add_routes([
            web.post("/api/v1/text-completion", self.llm),
            web.post("/api/v1/prompt", self.prompt),
            web.post("/api/v1/graph-rag", self.graph_rag),
            web.post("/api/v1/triples-query", self.triples_query),
            web.post("/api/v1/agent", self.agent),
            web.post("/api/v1/embeddings", self.embeddings),
        ])

    @staticmethod
    async def _wait_for_response(q):
        """Return the next item from q, or raise RuntimeError on timeout.

        Only the timeout is translated; cancellation propagates so that a
        cancelled handler is not misreported as a timeout.
        """
        try:
            return await asyncio.wait_for(q.get(), TIME_OUT)
        except asyncio.TimeoutError:
            raise RuntimeError("Timeout waiting for response") from None

    @staticmethod
    def _error_response(e):
        """Log exception e and return it as an {"error": ...} JSON body."""
        logging.error(f"Exception: {e}")
        return web.json_response(
            { "error": str(e) }
        )

    @staticmethod
    def _to_value(text):
        """Wrap a request term in a Value, flagging http(s) strings as URIs.

        Anything that does not start with "http:" or "https:" is treated
        as a literal (is_uri=False).
        """
        is_uri = text.startswith(("http:", "https:"))
        return Value(value=text, is_uri=is_uri)

    @staticmethod
    def _from_value(term):
        """Serialise a triple term as {"v": value, "e": is_uri}."""
        return { "v": term.value, "e": term.is_uri }

    async def llm(self, request):
        """Handle POST /api/v1/text-completion.

        Body keys: "system" and "prompt" (both required).
        Replies with {"response": ...} or {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.llm_in.subscribe(request_id)

            await self.llm_out.send(
                request_id,
                TextCompletionRequest(
                    system=data["system"],
                    prompt=data["prompt"]
                )
            )

            resp = await self._wait_for_response(q)

            if resp.error:
                return web.json_response(
                    { "error": resp.error.message }
                )

            return web.json_response(
                { "response": resp.response }
            )

        except Exception as e:
            return self._error_response(e)

        finally:
            await self.llm_in.unsubscribe(request_id)

    async def prompt(self, request):
        """Handle POST /api/v1/prompt.

        Body keys: "id" (prompt id) and "variables" (a mapping; each value
        is JSON-encoded before being sent as a term).
        Replies with {"object": ...} when the response carries an object,
        otherwise {"text": ...}, or {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.prompt_in.subscribe(request_id)

            terms = {
                k: json.dumps(v)
                for k, v in data["variables"].items()
            }

            await self.prompt_out.send(
                request_id,
                PromptRequest(
                    id=data["id"],
                    terms=terms
                )
            )

            resp = await self._wait_for_response(q)

            if resp.error:
                return web.json_response(
                    { "error": resp.error.message }
                )

            if resp.object:
                return web.json_response(
                    { "object": resp.object }
                )

            return web.json_response(
                { "text": resp.text }
            )

        except Exception as e:
            return self._error_response(e)

        finally:
            await self.prompt_in.unsubscribe(request_id)

    async def graph_rag(self, request):
        """Handle POST /api/v1/graph-rag.

        Body keys: "query" (required), "user" (default "trustgraph") and
        "collection" (default "default").
        Replies with {"response": ...} or {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.graph_rag_in.subscribe(request_id)

            await self.graph_rag_out.send(
                request_id,
                GraphRagQuery(
                    query=data["query"],
                    user=data.get("user", "trustgraph"),
                    collection=data.get("collection", "default"),
                )
            )

            resp = await self._wait_for_response(q)

            if resp.error:
                return web.json_response(
                    { "error": resp.error.message }
                )

            return web.json_response(
                { "response": resp.response }
            )

        except Exception as e:
            return self._error_response(e)

        finally:
            await self.graph_rag_in.unsubscribe(request_id)

    async def triples_query(self, request):
        """Handle POST /api/v1/triples-query.

        Body keys (all optional): "s", "p", "o" (terms to match; omitted
        terms are sent as None), "limit" (default 10000), "user" (default
        "trustgraph") and "collection" (default "default").
        Replies with {"response": [{"s": ..., "p": ..., "o": ...}, ...]},
        each term being {"v": value, "e": is_uri}, or {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.triples_query_in.subscribe(request_id)

            s = self._to_value(data["s"]) if "s" in data else None
            p = self._to_value(data["p"]) if "p" in data else None
            o = self._to_value(data["o"]) if "o" in data else None

            limit = int(data.get("limit", 10000))

            await self.triples_query_out.send(
                request_id,
                TriplesQueryRequest(
                    s = s, p = p, o = o,
                    limit = limit,
                    user = data.get("user", "trustgraph"),
                    collection = data.get("collection", "default"),
                )
            )

            resp = await self._wait_for_response(q)

            if resp.error:
                return web.json_response(
                    { "error": resp.error.message }
                )

            return web.json_response(
                {
                    "response": [
                        {
                            "s": self._from_value(t.s),
                            "p": self._from_value(t.p),
                            "o": self._from_value(t.o),
                        }
                        for t in resp.triples
                    ]
                }
            )

        except Exception as e:
            return self._error_response(e)

        finally:
            # Release the queue registered on the triples subscriber (the
            # one subscribed to above), so it is not leaked per request.
            await self.triples_query_in.unsubscribe(request_id)

    async def agent(self, request):
        """Handle POST /api/v1/agent.

        Body keys: "question" (required).  Intermediate thought and
        observation responses are printed; the first response carrying an
        answer ends the wait.
        Replies with {"answer": ...} or {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.agent_in.subscribe(request_id)

            await self.agent_out.send(
                request_id,
                AgentRequest(
                    question=data["question"],
                )
            )

            # Keep reading until an error or a final answer arrives; the
            # timeout applies to each individual response.
            while True:
                resp = await self._wait_for_response(q)

                if resp.error:
                    return web.json_response(
                        { "error": resp.error.message }
                    )

                if resp.answer:
                    return web.json_response(
                        { "answer": resp.answer }
                    )

                if resp.thought:
                    print("thought:", resp.thought)
                if resp.observation:
                    print("observation:", resp.observation)

        except Exception as e:
            return self._error_response(e)

        finally:
            await self.agent_in.unsubscribe(request_id)

    async def embeddings(self, request):
        """Handle POST /api/v1/embeddings.

        Body keys: "text" (required).
        Replies with {"vectors": ...} or {"error": ...}.
        """
        request_id = str(uuid.uuid4())

        try:
            data = await request.json()
            q = await self.embeddings_in.subscribe(request_id)

            await self.embeddings_out.send(
                request_id,
                EmbeddingsRequest(
                    text=data["text"],
                )
            )

            resp = await self._wait_for_response(q)

            if resp.error:
                return web.json_response(
                    { "error": resp.error.message }
                )

            return web.json_response(
                { "vectors": resp.vectors }
            )

        except Exception as e:
            return self._error_response(e)

        finally:
            await self.embeddings_in.unsubscribe(request_id)

    async def app_factory(self):
        """Start the background Pulsar tasks and return the application.

        Note the attribute naming: each *_pub_task holds the task running
        the subscriber (*_in) and each *_sub_task the task running the
        publisher (*_out).  The names are kept unchanged for compatibility.
        """
        self.llm_pub_task = asyncio.create_task(self.llm_in.run())
        self.llm_sub_task = asyncio.create_task(self.llm_out.run())

        self.prompt_pub_task = asyncio.create_task(self.prompt_in.run())
        self.prompt_sub_task = asyncio.create_task(self.prompt_out.run())

        self.graph_rag_pub_task = asyncio.create_task(
            self.graph_rag_in.run()
        )
        self.graph_rag_sub_task = asyncio.create_task(
            self.graph_rag_out.run()
        )

        self.triples_query_pub_task = asyncio.create_task(
            self.triples_query_in.run()
        )
        self.triples_query_sub_task = asyncio.create_task(
            self.triples_query_out.run()
        )

        self.agent_pub_task = asyncio.create_task(self.agent_in.run())
        self.agent_sub_task = asyncio.create_task(self.agent_out.run())

        self.embeddings_pub_task = asyncio.create_task(
            self.embeddings_in.run()
        )
        self.embeddings_sub_task = asyncio.create_task(
            self.embeddings_out.run()
        )

        return self.app

    def run(self):
        """Serve the application on the configured port until stopped."""
        web.run_app(self.app_factory(), port=self.port)
# Start the gateway only when executed as a script, so that importing this
# module does not open network connections or block in the server loop.
if __name__ == "__main__":
    a = Api()
    a.run()

View file

@ -58,6 +58,7 @@ setuptools.setup(
"google-generativeai", "google-generativeai",
"ibis", "ibis",
"jsonschema", "jsonschema",
"aiohttp",
], ],
scripts=[ scripts=[
"scripts/agent-manager-react", "scripts/agent-manager-react",