Feature/wikipedia ddg (#185)

API-side support for Wikipedia, DBpedia and internet search functions  This incorporates a refactor of the API code to break it up, separate classes for endpoints to reduce duplication
This commit is contained in:
cybermaggedon 2024-12-02 17:41:30 +00:00 committed by GitHub
parent 212102c61c
commit 6d200c79c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
50 changed files with 1287 additions and 826 deletions

View file

@ -48,7 +48,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -46,7 +46,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -53,7 +53,7 @@ local chunker = import "chunker-recursive.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -45,7 +45,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -43,7 +43,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_limits("0.5", "128M") .with_limits("0.5", "128M")
.with_reservations("0.1", "128M"); .with_reservations("0.1", "128M");

View file

@ -19,7 +19,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"--prompt-request-queue", "--prompt-request-queue",
"non-persistent://tg/request/prompt-rag", "non-persistent://tg/request/prompt-rag",
"--prompt-response-queue", "--prompt-response-queue",
"non-persistent://tg/response/prompt-rag-response", "non-persistent://tg/response/prompt-rag",
]) ])
.with_limits("0.5", "128M") .with_limits("0.5", "128M")
.with_reservations("0.1", "128M"); .with_reservations("0.1", "128M");

View file

@ -50,7 +50,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -112,7 +112,7 @@ local url = import "values/url.jsonnet";
"--prompt-request-queue", "--prompt-request-queue",
"non-persistent://tg/request/prompt-rag", "non-persistent://tg/request/prompt-rag",
"--prompt-response-queue", "--prompt-response-queue",
"non-persistent://tg/response/prompt-rag-response", "non-persistent://tg/response/prompt-rag",
"--entity-limit", "--entity-limit",
std.toString($["graph-rag-entity-limit"]), std.toString($["graph-rag-entity-limit"]),
"--triple-limit", "--triple-limit",

View file

@ -40,7 +40,7 @@ local prompts = import "prompts/slm.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -40,7 +40,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -50,7 +50,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_env_var_secrets(envSecrets) .with_env_var_secrets(envSecrets)
.with_limits("0.5", "128M") .with_limits("0.5", "128M")

View file

@ -53,7 +53,7 @@ local default_prompts = import "prompts/default-prompts.jsonnet";
"--text-completion-request-queue", "--text-completion-request-queue",
"non-persistent://tg/request/text-completion", "non-persistent://tg/request/text-completion",
"--text-completion-response-queue", "--text-completion-response-queue",
"non-persistent://tg/response/text-completion-response", "non-persistent://tg/response/text-completion",
"--system-prompt", "--system-prompt",
$["prompts"]["system-template"], $["prompts"]["system-template"],
@ -92,11 +92,11 @@ local default_prompts = import "prompts/default-prompts.jsonnet";
"-i", "-i",
"non-persistent://tg/request/prompt-rag", "non-persistent://tg/request/prompt-rag",
"-o", "-o",
"non-persistent://tg/response/prompt-rag-response", "non-persistent://tg/response/prompt-rag",
"--text-completion-request-queue", "--text-completion-request-queue",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"--text-completion-response-queue", "--text-completion-response-queue",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
"--system-prompt", "--system-prompt",
$["prompts"]["system-template"], $["prompts"]["system-template"],

View file

@ -186,7 +186,7 @@ local prompt = import "prompt-template.jsonnet";
"-p", "-p",
url.pulsar, url.pulsar,
"-i", "-i",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_limits("0.5", "128M") .with_limits("0.5", "128M")
.with_reservations("0.1", "128M"); .with_reservations("0.1", "128M");

View file

@ -93,7 +93,7 @@ local prompts = import "prompts/mixtral.jsonnet";
"-i", "-i",
"non-persistent://tg/request/text-completion-rag", "non-persistent://tg/request/text-completion-rag",
"-o", "-o",
"non-persistent://tg/response/text-completion-rag-response", "non-persistent://tg/response/text-completion-rag",
]) ])
.with_limits("0.5", "256M") .with_limits("0.5", "256M")
.with_reservations("0.1", "256M") .with_reservations("0.1", "256M")

28
test-api/test-agent2-api Executable file
View file

@ -0,0 +1,28 @@
#!/usr/bin/env python3

"""Smoke-test the agent API endpoint with a simple arithmetic question."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Named 'payload' rather than 'input' so the builtin is not shadowed
payload = {
    "question": "What is 14 plus 12. Justify your answer.",
}

resp = requests.post(
    f"{url}agent",
    json=payload,
)

# Fail loudly on transport-level errors before trying to parse a body
resp.raise_for_status()

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["answer"])

30
test-api/test-dbpedia Executable file
View file

@ -0,0 +1,30 @@
#!/usr/bin/env python3

"""Smoke-test the dbpedia lookup API with a sample term."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Named 'payload' rather than 'input' so the builtin is not shadowed
payload = {
    "term": "Cornwall",
}

resp = requests.post(
    f"{url}dbpedia",
    json=payload,
)

# Fail loudly on transport-level errors before trying to parse a body
resp.raise_for_status()

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["text"])

sys.exit(0)

############################################################################

30
test-api/test-encyclopedia Executable file
View file

@ -0,0 +1,30 @@
#!/usr/bin/env python3

"""Smoke-test the encyclopedia lookup API with a sample term."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Named 'payload' rather than 'input' so the builtin is not shadowed
payload = {
    "term": "Cornwall",
}

resp = requests.post(
    f"{url}encyclopedia",
    json=payload,
)

# Fail loudly on transport-level errors before trying to parse a body
resp.raise_for_status()

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["text"])

sys.exit(0)

############################################################################

30
test-api/test-internet-search Executable file
View file

@ -0,0 +1,30 @@
#!/usr/bin/env python3

"""Smoke-test the internet-search API with a sample term."""

import requests
import json
import sys

url = "http://localhost:8088/api/v1/"

############################################################################

# Named 'payload' rather than 'input' so the builtin is not shadowed
payload = {
    "term": "Cornwall",
}

resp = requests.post(
    f"{url}internet-search",
    json=payload,
)

# Fail loudly on transport-level errors before trying to parse a body
resp.raise_for_status()

resp = resp.json()

if "error" in resp:
    print(f"Error: {resp['error']}")
    sys.exit(1)

print(resp["text"])

sys.exit(0)

############################################################################

View file

@ -22,7 +22,6 @@ resp = requests.post(
resp = resp.json() resp = resp.json()
print(resp)
if "error" in resp: if "error" in resp:
print(f"Error: {resp['error']}") print(f"Error: {resp['error']}")
sys.exit(1) sys.exit(1)

View file

@ -22,7 +22,6 @@ resp = requests.post(
resp = resp.json() resp = resp.json()
print(resp)
if "error" in resp: if "error" in resp:
print(f"Error: {resp['error']}") print(f"Error: {resp['error']}")
sys.exit(1) sys.exit(1)

View file

@ -9,7 +9,10 @@ url = "http://localhost:8088/api/v1/"
############################################################################ ############################################################################
input = { input = {
"p": "http://www.w3.org/2000/01/rdf-schema#label", "p": {
"v": "http://www.w3.org/2000/01/rdf-schema#label",
"e": True,
},
"limit": 10 "limit": 10
} }

View file

@ -9,4 +9,6 @@ from . graph import *
from . retrieval import * from . retrieval import *
from . metadata import * from . metadata import *
from . agent import * from . agent import *
from . lookup import *

View file

@ -0,0 +1,42 @@
from pulsar.schema import Record, String
from . types import Error, Value, Triple
from . topic import topic
from . metadata import Metadata

############################################################################
# Lookups

# Request for an external knowledge lookup (encyclopedia / DBpedia /
# internet search services all share this shape).
class LookupRequest(Record):
    # Optional discriminator interpreted by the lookup service
    kind = String()
    # The term to look up
    term = String()

# Response carrying the looked-up text, or an error on failure
class LookupResponse(Record):
    # Retrieved text on success
    text = String()
    # Populated when the lookup failed
    error = Error()

# Request/response queue pairs for each lookup service; all are
# non-persistent topics.
encyclopedia_lookup_request_queue = topic(
    'encyclopedia', kind='non-persistent', namespace='request'
)
encyclopedia_lookup_response_queue = topic(
    'encyclopedia', kind='non-persistent', namespace='response',
)

dbpedia_lookup_request_queue = topic(
    'dbpedia', kind='non-persistent', namespace='request'
)
dbpedia_lookup_response_queue = topic(
    'dbpedia', kind='non-persistent', namespace='response',
)

internet_search_request_queue = topic(
    'internet-search', kind='non-persistent', namespace='request'
)
internet_search_response_queue = topic(
    'internet-search', kind='non-persistent', namespace='response',
)

############################################################################

View file

@ -93,7 +93,6 @@ async def loader(ge_queue, t_queue, path, format, user, collection):
if collection: if collection:
unpacked["metadata"]["collection"] = collection unpacked["metadata"]["collection"] = collection
if unpacked[0] == "t": if unpacked[0] == "t":
await t_queue.put(unpacked[1]) await t_queue.put(unpacked[1])
t_counts += 1 t_counts += 1

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3

"""Entry point script for the Wikipedia lookup service."""

from trustgraph.external.wikipedia import run

# Guard so importing this module does not start the service
if __name__ == "__main__":
    run()

View file

@ -106,5 +106,6 @@ setuptools.setup(
"scripts/triples-query-neo4j", "scripts/triples-query-neo4j",
"scripts/triples-write-cassandra", "scripts/triples-write-cassandra",
"scripts/triples-write-neo4j", "scripts/triples-write-neo4j",
"scripts/wikipedia-lookup",
] ]
) )

View file

@ -0,0 +1,30 @@
from ... schema import AgentRequest, AgentResponse
from ... schema import agent_request_queue
from ... schema import agent_response_queue

from . endpoint import MultiResponseServiceEndpoint

class AgentEndpoint(MultiResponseServiceEndpoint):
    """Bridges POST /api/v1/agent to the agent request/response queues.

    Uses the multi-response base class because the agent service emits
    interim messages before the final answer.
    """

    def __init__(self, pulsar_host, timeout):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=agent_request_queue,
            response_queue=agent_response_queue,
            request_schema=AgentRequest,
            response_schema=AgentResponse,
            endpoint_path="/api/v1/agent",
            timeout=timeout,
        )

    def to_request(self, body):
        """Translate the JSON body into an AgentRequest message."""
        return AgentRequest(question=body["question"])

    def from_response(self, message):
        """Return (payload, finished); finished once an answer arrives."""
        if not message.answer:
            # Interim message: nothing to report yet, keep waiting
            return {}, False
        return { "answer": message.answer }, True

View file

@ -0,0 +1,29 @@
from ... schema import LookupRequest, LookupResponse
from ... schema import dbpedia_lookup_request_queue
from ... schema import dbpedia_lookup_response_queue

from . endpoint import ServiceEndpoint

class DbpediaEndpoint(ServiceEndpoint):
    """Bridges POST /api/v1/dbpedia to the DBpedia lookup queues."""

    def __init__(self, pulsar_host, timeout):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=dbpedia_lookup_request_queue,
            response_queue=dbpedia_lookup_response_queue,
            request_schema=LookupRequest,
            response_schema=LookupResponse,
            endpoint_path="/api/v1/dbpedia",
            timeout=timeout,
        )

    def to_request(self, body):
        """Translate the JSON body into a LookupRequest; kind is optional."""
        kind = body.get("kind", None)
        return LookupRequest(term=body["term"], kind=kind)

    def from_response(self, message):
        """Extract the looked-up text from the response message."""
        return { "text": message.text }

View file

@ -0,0 +1,27 @@
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue

from . endpoint import ServiceEndpoint

class EmbeddingsEndpoint(ServiceEndpoint):
    """Bridges POST /api/v1/embeddings to the embeddings service queues."""

    def __init__(self, pulsar_host, timeout):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=embeddings_request_queue,
            response_queue=embeddings_response_queue,
            request_schema=EmbeddingsRequest,
            response_schema=EmbeddingsResponse,
            endpoint_path="/api/v1/embeddings",
            timeout=timeout,
        )

    def to_request(self, body):
        """Translate the JSON body into an EmbeddingsRequest message."""
        return EmbeddingsRequest(text=body["text"])

    def from_response(self, message):
        """Extract the embedding vectors from the response message."""
        return { "vectors": message.vectors }

View file

@ -0,0 +1,29 @@
from ... schema import LookupRequest, LookupResponse
from ... schema import encyclopedia_lookup_request_queue
from ... schema import encyclopedia_lookup_response_queue

from . endpoint import ServiceEndpoint

class EncyclopediaEndpoint(ServiceEndpoint):
    """Bridges POST /api/v1/encyclopedia to the encyclopedia lookup queues."""

    def __init__(self, pulsar_host, timeout):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=encyclopedia_lookup_request_queue,
            response_queue=encyclopedia_lookup_response_queue,
            request_schema=LookupRequest,
            response_schema=LookupResponse,
            endpoint_path="/api/v1/encyclopedia",
            timeout=timeout,
        )

    def to_request(self, body):
        """Translate the JSON body into a LookupRequest; kind is optional."""
        kind = body.get("kind", None)
        return LookupRequest(term=body["term"], kind=kind)

    def from_response(self, message):
        """Extract the looked-up text from the response message."""
        return { "text": message.text }

View file

@ -0,0 +1,153 @@
import asyncio
from pulsar.schema import JsonSchema
from aiohttp import web
import uuid
import logging
from . publisher import Publisher
from . subscriber import Subscriber
logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO)
class ServiceEndpoint:
    """Bridges an HTTP POST endpoint to a Pulsar request/response queue pair.

    Each request is assigned a UUID, published on the request queue, and the
    handler waits (up to `timeout` seconds) for the correlated message on
    the response queue.  Subclasses implement to_request()/from_response()
    to translate between JSON bodies and Pulsar schema records.
    """

    def __init__(
            self,
            pulsar_host,
            request_queue, request_schema,
            response_queue, response_schema,
            endpoint_path,
            subscription="api-gateway", consumer_name="api-gateway",
            timeout=600,
    ):

        self.pub = Publisher(
            pulsar_host, request_queue,
            schema=JsonSchema(request_schema)
        )

        self.sub = Subscriber(
            pulsar_host, response_queue,
            subscription, consumer_name,
            JsonSchema(response_schema)
        )

        self.path = endpoint_path
        self.timeout = timeout

    async def start(self):
        """Launch the publisher and subscriber background loops."""
        self.pub_task = asyncio.create_task(self.pub.run())
        self.sub_task = asyncio.create_task(self.sub.run())

    def add_routes(self, app):
        """Register this endpoint's POST route on the aiohttp app."""
        app.add_routes([
            web.post(self.path, self.handle),
        ])

    def to_request(self, request):
        """Translate a JSON body to a request record; subclasses override."""
        raise RuntimeError("Not defined")

    def from_response(self, response):
        """Translate a response record to a JSON payload; subclasses override."""
        raise RuntimeError("Not defined")

    async def handle(self, request):
        """Handle one HTTP request: publish it, await the correlated response."""

        id = str(uuid.uuid4())

        try:

            data = await request.json()

            # Subscribe before publishing so the response cannot be missed
            q = await self.sub.subscribe(id)

            await self.pub.send(
                id,
                self.to_request(data),
            )

            try:
                resp = await asyncio.wait_for(q.get(), self.timeout)
            except asyncio.TimeoutError:
                # Catch only the timeout: a bare except here would also
                # swallow task cancellation
                raise RuntimeError("Timeout waiting for response")

            if resp.error:
                return web.json_response(
                    { "error": resp.error.message }
                )

            return web.json_response(
                self.from_response(resp)
            )

        except Exception as e:
            # Module logger, not the root logger
            logger.error(f"Exception: {e}")
            return web.json_response(
                { "error": str(e) }
            )

        finally:
            # Always release the per-request subscription
            await self.sub.unsubscribe(id)
class MultiResponseServiceEndpoint(ServiceEndpoint):
    """Variant of ServiceEndpoint for services that emit several interim
    responses; keeps reading until from_response() reports the answer is
    finished."""

    async def handle(self, request):
        """Handle one HTTP request, looping over interim responses until
        from_response() signals completion."""

        id = str(uuid.uuid4())

        try:

            data = await request.json()

            # Subscribe before publishing so responses cannot be missed
            q = await self.sub.subscribe(id)

            await self.pub.send(
                id,
                self.to_request(data),
            )

            # Keeps looking at responses...
            while True:

                try:
                    resp = await asyncio.wait_for(q.get(), self.timeout)
                except asyncio.TimeoutError:
                    # Catch only the timeout: a bare except here would
                    # also swallow task cancellation
                    raise RuntimeError("Timeout waiting for response")

                if resp.error:
                    return web.json_response(
                        { "error": resp.error.message }
                    )

                # Until from_response says we have a finished answer
                resp, fin = self.from_response(resp)

                if fin:
                    return web.json_response(resp)

                # Not finished, so loop round and continue

        except Exception as e:
            # Module logger, not the root logger
            logger.error(f"Exception: {e}")
            return web.json_response(
                { "error": str(e) }
            )

        finally:
            # Always release the per-request subscription
            await self.sub.unsubscribe(id)

View file

@ -0,0 +1,60 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import GraphEmbeddings
from ... schema import graph_embeddings_store_queue
from . publisher import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph, to_value
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
    """Websocket endpoint accepting graph-embedding records as JSON and
    publishing them to the graph-embeddings store queue."""

    def __init__(self, pulsar_host, path="/api/v1/load/graph-embeddings"):

        super(GraphEmbeddingsLoadEndpoint, self).__init__(
            endpoint_path=path
        )

        self.pulsar_host=pulsar_host

        self.publisher = Publisher(
            self.pulsar_host, graph_embeddings_store_queue,
            schema=JsonSchema(GraphEmbeddings)
        )

    async def start(self):
        """Kick off the publisher's background loop."""
        self.task = asyncio.create_task(
            self.publisher.run()
        )

    async def listener(self, ws, running):
        """Forward each websocket message as a GraphEmbeddings record;
        finish when the socket errors or closes."""

        async for msg in ws:

            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break

            payload = msg.json()
            meta = payload["metadata"]

            record = GraphEmbeddings(
                metadata=Metadata(
                    id=meta["id"],
                    metadata=to_subgraph(meta["metadata"]),
                    user=meta["user"],
                    collection=meta["collection"],
                ),
                entity=to_value(payload["entity"]),
                vectors=payload["vectors"],
            )

            # No correlation id: loads are fire-and-forget
            await self.publisher.send(None, record)

        running.stop()

View file

@ -0,0 +1,56 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from ... schema import GraphEmbeddings
from ... schema import graph_embeddings_store_queue
from . subscriber import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_graph_embeddings
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
    """Websocket endpoint streaming every graph-embeddings message from the
    store queue to the connected client."""

    def __init__(self, pulsar_host, path="/api/v1/stream/graph-embeddings"):

        super(GraphEmbeddingsStreamEndpoint, self).__init__(
            endpoint_path=path
        )

        self.pulsar_host=pulsar_host

        self.subscriber = Subscriber(
            self.pulsar_host, graph_embeddings_store_queue,
            "api-gateway", "api-gateway",
            schema=JsonSchema(GraphEmbeddings)
        )

    async def start(self):
        """Kick off the subscriber's background loop."""
        self.task = asyncio.create_task(
            self.subscriber.run()
        )

    async def async_thread(self, ws, running):
        """Relay queue messages to the websocket until stopped or the
        socket fails."""

        id = str(uuid.uuid4())

        q = await self.subscriber.subscribe_all(id)

        while running.get():
            try:
                resp = await asyncio.wait_for(q.get(), 0.5)
                await ws.send_json(serialize_graph_embeddings(resp))
            except asyncio.TimeoutError:
                # asyncio.wait_for raises asyncio.TimeoutError, which is
                # NOT the builtin TimeoutError before Python 3.11; catching
                # the asyncio one keeps the poll loop running instead of
                # falling through to the generic handler and breaking out
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break

        await self.subscriber.unsubscribe_all(id)
        running.stop()

View file

@ -0,0 +1,30 @@
from ... schema import GraphRagQuery, GraphRagResponse
from ... schema import graph_rag_request_queue
from ... schema import graph_rag_response_queue

from . endpoint import ServiceEndpoint

class GraphRagEndpoint(ServiceEndpoint):
    """Bridges POST /api/v1/graph-rag to the graph RAG service queues."""

    def __init__(self, pulsar_host, timeout):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=graph_rag_request_queue,
            response_queue=graph_rag_response_queue,
            request_schema=GraphRagQuery,
            response_schema=GraphRagResponse,
            endpoint_path="/api/v1/graph-rag",
            timeout=timeout,
        )

    def to_request(self, body):
        """Build a GraphRagQuery; user/collection default when absent."""
        user = body.get("user", "trustgraph")
        collection = body.get("collection", "default")
        return GraphRagQuery(
            query=body["query"],
            user=user,
            collection=collection,
        )

    def from_response(self, message):
        """Extract the RAG answer text from the response message."""
        return { "response": message.response }

View file

@ -0,0 +1,29 @@
from ... schema import LookupRequest, LookupResponse
from ... schema import internet_search_request_queue
from ... schema import internet_search_response_queue

from . endpoint import ServiceEndpoint

class InternetSearchEndpoint(ServiceEndpoint):
    """Bridges POST /api/v1/internet-search to the internet search queues."""

    def __init__(self, pulsar_host, timeout):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=internet_search_request_queue,
            response_queue=internet_search_response_queue,
            request_schema=LookupRequest,
            response_schema=LookupResponse,
            endpoint_path="/api/v1/internet-search",
            timeout=timeout,
        )

    def to_request(self, body):
        """Translate the JSON body into a LookupRequest; kind is optional."""
        kind = body.get("kind", None)
        return LookupRequest(term=body["term"], kind=kind)

    def from_response(self, message):
        """Extract the search result text from the response message."""
        return { "text": message.text }

View file

@ -0,0 +1,41 @@
import json

from ... schema import PromptRequest, PromptResponse
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue

from . endpoint import ServiceEndpoint

class PromptEndpoint(ServiceEndpoint):
    """Bridges POST /api/v1/prompt to the prompt service queues."""

    def __init__(self, pulsar_host, timeout):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=prompt_request_queue,
            response_queue=prompt_response_queue,
            request_schema=PromptRequest,
            response_schema=PromptResponse,
            endpoint_path="/api/v1/prompt",
            timeout=timeout,
        )

    def to_request(self, body):
        """Build a PromptRequest; each variable value is JSON-encoded."""
        terms = {
            key: json.dumps(value)
            for key, value in body["variables"].items()
        }
        return PromptRequest(id=body["id"], terms=terms)

    def from_response(self, message):
        """Prefer the structured object field; fall back to plain text."""
        if message.object:
            return { "object": message.object }
        return { "text": message.text }

View file

@ -0,0 +1,41 @@
import asyncio
import aiopulsar
class Publisher:
    """Buffered Pulsar producer.

    Callers enqueue (id, message) pairs via send(); the run() loop drains
    the queue into a Pulsar producer, reconnecting after any failure.
    """

    def __init__(self, pulsar_host, topic, schema=None, max_size=10,
                 chunking_enabled=False):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.schema = schema
        self.q = asyncio.Queue(maxsize=max_size)
        self.chunking_enabled = chunking_enabled

    async def run(self):
        """Publish queued messages forever, retrying on connection loss."""

        while True:

            try:

                async with aiopulsar.connect(self.pulsar_host) as client:

                    async with client.create_producer(
                        topic=self.topic,
                        schema=self.schema,
                        chunking_enabled=self.chunking_enabled,
                    ) as producer:

                        while True:

                            id, item = await self.q.get()

                            # A correlation id travels as a message property
                            if id:
                                await producer.send(item, { "id": id })
                            else:
                                await producer.send(item)

            except Exception as e:
                print("Exception:", e, flush=True)

            # If handler drops out, sleep a retry
            await asyncio.sleep(2)

    async def send(self, id, msg):
        """Queue (id, msg) for publication; blocks while the queue is full."""
        await self.q.put((id, msg))

View file

@ -0,0 +1,5 @@
class Running:
    """Tiny mutable flag used to tell a streaming loop to stop."""

    def __init__(self):
        self.running = True

    def get(self):
        """True while the loop should keep going."""
        return self.running

    def stop(self):
        """Request the loop to stop."""
        self.running = False

View file

@ -0,0 +1,57 @@
from ... schema import Value, Triple
def to_value(x):
    # Build a Value record from its wire form: "v" carries the value,
    # "e" flags whether it is a URI.
    return Value(value=x["v"], is_uri=x["e"])
def to_subgraph(x):
    """Convert a list of wire-form triples into Triple records."""
    def _triple(t):
        # Each of s/p/o is a wire-form value
        return Triple(
            s=to_value(t["s"]),
            p=to_value(t["p"]),
            o=to_value(t["o"]),
        )
    return [_triple(t) for t in x]
def serialize_value(v):
    """Wire form of a Value: "v" is the value, "e" the is-URI flag."""
    return { "v": v.value, "e": v.is_uri }
def serialize_triple(t):
    """Wire form of a Triple: each of s/p/o serialized as a value."""
    return {
        key: serialize_value(part)
        for key, part in (("s", t.s), ("p", t.p), ("o", t.o))
    }
def serialize_subgraph(sg):
    """Wire form of a list of triples."""
    return list(map(serialize_triple, sg))
def serialize_triples(message):
    """Wire form of a Triples message: metadata block plus triple list."""
    meta = message.metadata
    return {
        "metadata": {
            "id": meta.id,
            "metadata": serialize_subgraph(meta.metadata),
            "user": meta.user,
            "collection": meta.collection,
        },
        "triples": serialize_subgraph(message.triples),
    }
def serialize_graph_embeddings(message):
    """Wire form of a GraphEmbeddings message: metadata, vectors, entity."""
    meta = message.metadata
    return {
        "metadata": {
            "id": meta.id,
            "metadata": serialize_subgraph(meta.metadata),
            "user": meta.user,
            "collection": meta.collection,
        },
        "vectors": message.vectors,
        "entity": serialize_value(message.entity),
    }

View file

@ -1,4 +1,3 @@
""" """
API gateway. Offers HTTP services which are translated to interaction on the API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus. Pulsar bus.
@ -14,57 +13,39 @@ module = ".".join(__name__.split(".")[1:-1])
import asyncio import asyncio
import argparse import argparse
from aiohttp import web, WSMsgType from aiohttp import web
import json
import logging import logging
import uuid
import os import os
import base64 import base64
import pulsar import pulsar
from pulsar.asyncio import Client
from pulsar.schema import JsonSchema from pulsar.schema import JsonSchema
import _pulsar
import aiopulsar
from prometheus_client import start_http_server from prometheus_client import start_http_server
from ... log_level import LogLevel from ... log_level import LogLevel
from trustgraph.clients.llm_client import LlmClient from ... schema import Metadata, Document, TextDocument
from trustgraph.clients.prompt_client import PromptClient
from ... schema import Value, Metadata, Document, TextDocument, Triple
from ... schema import TextCompletionRequest, TextCompletionResponse
from ... schema import text_completion_request_queue
from ... schema import text_completion_response_queue
from ... schema import PromptRequest, PromptResponse
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... schema import GraphRagQuery, GraphRagResponse
from ... schema import graph_rag_request_queue
from ... schema import graph_rag_response_queue
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
from ... schema import triples_request_queue
from ... schema import triples_response_queue
from ... schema import triples_store_queue
from ... schema import GraphEmbeddings
from ... schema import graph_embeddings_store_queue
from ... schema import AgentRequest, AgentResponse
from ... schema import agent_request_queue
from ... schema import agent_response_queue
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue
from ... schema import document_ingest_queue, text_ingest_queue from ... schema import document_ingest_queue, text_ingest_queue
from . serialize import to_subgraph
from . running import Running
from . publisher import Publisher
from . subscriber import Subscriber
from . endpoint import ServiceEndpoint, MultiResponseServiceEndpoint
from . text_completion import TextCompletionEndpoint
from . prompt import PromptEndpoint
from . graph_rag import GraphRagEndpoint
from . triples_query import TriplesQueryEndpoint
from . embeddings import EmbeddingsEndpoint
from . encyclopedia import EncyclopediaEndpoint
from . agent import AgentEndpoint
from . dbpedia import DbpediaEndpoint
from . internet_search import InternetSearchEndpoint
from . triples_stream import TriplesStreamEndpoint
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
from . triples_load import TriplesLoadEndpoint
from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
logger = logging.getLogger("api") logger = logging.getLogger("api")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
@ -72,168 +53,6 @@ default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
default_timeout = 600 default_timeout = 600
default_port = 8088 default_port = 8088
def to_value(x):
return Value(value=x["v"], is_uri=x["e"])
def to_subgraph(x):
return [
Triple(
s=to_value(t["s"]),
p=to_value(t["p"]),
o=to_value(t["o"])
)
for t in x
]
class Running:
def __init__(self): self.running = True
def get(self): return self.running
def stop(self): self.running = False
class Publisher:
def __init__(self, pulsar_host, topic, schema=None, max_size=10,
chunking_enabled=False):
self.pulsar_host = pulsar_host
self.topic = topic
self.schema = schema
self.q = asyncio.Queue(maxsize=max_size)
self.chunking_enabled = chunking_enabled
async def run(self):
while True:
try:
async with aiopulsar.connect(self.pulsar_host) as client:
async with client.create_producer(
topic=self.topic,
schema=self.schema,
chunking_enabled=self.chunking_enabled,
) as producer:
while True:
id, item = await self.q.get()
if id:
await producer.send(item, { "id": id })
else:
await producer.send(item)
except Exception as e:
print("Exception:", e, flush=True)
# If handler drops out, sleep a retry
await asyncio.sleep(2)
async def send(self, id, msg):
await self.q.put((id, msg))
class Subscriber:
def __init__(self, pulsar_host, topic, subscription, consumer_name,
schema=None, max_size=10):
self.pulsar_host = pulsar_host
self.topic = topic
self.subscription = subscription
self.consumer_name = consumer_name
self.schema = schema
self.q = {}
self.full = {}
async def run(self):
while True:
try:
async with aiopulsar.connect(self.pulsar_host) as client:
async with client.subscribe(
topic=self.topic,
subscription_name=self.subscription,
consumer_name=self.consumer_name,
schema=self.schema,
) as consumer:
while True:
msg = await consumer.receive()
# Acknowledge successful reception of the message
await consumer.acknowledge(msg)
try:
id = msg.properties()["id"]
except:
id = None
value = msg.value()
if id in self.q:
await self.q[id].put(value)
for q in self.full.values():
await q.put(value)
except Exception as e:
print("Exception:", e, flush=True)
# If handler drops out, sleep a retry
await asyncio.sleep(2)
async def subscribe(self, id):
q = asyncio.Queue()
self.q[id] = q
return q
async def unsubscribe(self, id):
if id in self.q:
del self.q[id]
async def subscribe_all(self, id):
q = asyncio.Queue()
self.full[id] = q
return q
async def unsubscribe_all(self, id):
if id in self.full:
del self.full[id]
def serialize_value(v):
return {
"v": v.value,
"e": v.is_uri,
}
def serialize_triple(t):
return {
"s": serialize_value(t.s),
"p": serialize_value(t.p),
"o": serialize_value(t.o)
}
def serialize_subgraph(sg):
return [
serialize_triple(t)
for t in sg
]
def serialize_triples(message):
return {
"metadata": {
"id": message.metadata.id,
"metadata": serialize_subgraph(message.metadata.metadata),
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"triples": serialize_subgraph(message.triples),
}
def serialize_graph_embeddings(message):
return {
"metadata": {
"id": message.metadata.id,
"metadata": serialize_subgraph(message.metadata.metadata),
"user": message.metadata.user,
"collection": message.metadata.collection,
},
"vectors": message.vectors,
"entity": message.entity,
}
class Api: class Api:
def __init__(self, **config): def __init__(self, **config):
@ -247,93 +66,47 @@ class Api:
self.timeout = int(config.get("timeout", default_timeout)) self.timeout = int(config.get("timeout", default_timeout))
self.pulsar_host = config.get("pulsar_host", default_pulsar_host) self.pulsar_host = config.get("pulsar_host", default_pulsar_host)
self.llm_out = Publisher( self.endpoints = [
self.pulsar_host, text_completion_request_queue, TextCompletionEndpoint(
schema=JsonSchema(TextCompletionRequest) pulsar_host=self.pulsar_host, timeout=self.timeout,
) ),
PromptEndpoint(
self.llm_in = Subscriber( pulsar_host=self.pulsar_host, timeout=self.timeout,
self.pulsar_host, text_completion_response_queue, ),
"api-gateway", "api-gateway", GraphRagEndpoint(
JsonSchema(TextCompletionResponse) pulsar_host=self.pulsar_host, timeout=self.timeout,
) ),
TriplesQueryEndpoint(
self.prompt_out = Publisher( pulsar_host=self.pulsar_host, timeout=self.timeout,
self.pulsar_host, prompt_request_queue, ),
schema=JsonSchema(PromptRequest) EmbeddingsEndpoint(
) pulsar_host=self.pulsar_host, timeout=self.timeout,
),
self.prompt_in = Subscriber( AgentEndpoint(
self.pulsar_host, prompt_response_queue, pulsar_host=self.pulsar_host, timeout=self.timeout,
"api-gateway", "api-gateway", ),
JsonSchema(PromptResponse) EncyclopediaEndpoint(
) pulsar_host=self.pulsar_host, timeout=self.timeout,
),
self.graph_rag_out = Publisher( DbpediaEndpoint(
self.pulsar_host, graph_rag_request_queue, pulsar_host=self.pulsar_host, timeout=self.timeout,
schema=JsonSchema(GraphRagQuery) ),
) InternetSearchEndpoint(
pulsar_host=self.pulsar_host, timeout=self.timeout,
self.graph_rag_in = Subscriber( ),
self.pulsar_host, graph_rag_response_queue, TriplesStreamEndpoint(
"api-gateway", "api-gateway", pulsar_host=self.pulsar_host
JsonSchema(GraphRagResponse) ),
) GraphEmbeddingsStreamEndpoint(
pulsar_host=self.pulsar_host
self.triples_query_out = Publisher( ),
self.pulsar_host, triples_request_queue, TriplesLoadEndpoint(
schema=JsonSchema(TriplesQueryRequest) pulsar_host=self.pulsar_host
) ),
GraphEmbeddingsLoadEndpoint(
self.triples_query_in = Subscriber( pulsar_host=self.pulsar_host
self.pulsar_host, triples_response_queue, ),
"api-gateway", "api-gateway", ]
JsonSchema(TriplesQueryResponse)
)
self.agent_out = Publisher(
self.pulsar_host, agent_request_queue,
schema=JsonSchema(AgentRequest)
)
self.agent_in = Subscriber(
self.pulsar_host, agent_response_queue,
"api-gateway", "api-gateway",
JsonSchema(AgentResponse)
)
self.embeddings_out = Publisher(
self.pulsar_host, embeddings_request_queue,
schema=JsonSchema(EmbeddingsRequest)
)
self.embeddings_in = Subscriber(
self.pulsar_host, embeddings_response_queue,
"api-gateway", "api-gateway",
JsonSchema(EmbeddingsResponse)
)
self.triples_tap = Subscriber(
self.pulsar_host, triples_store_queue,
"api-gateway", "api-gateway",
schema=JsonSchema(Triples)
)
self.triples_pub = Publisher(
self.pulsar_host, triples_store_queue,
schema=JsonSchema(Triples)
)
self.graph_embeddings_tap = Subscriber(
self.pulsar_host, graph_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=JsonSchema(GraphEmbeddings)
)
self.graph_embeddings_pub = Publisher(
self.pulsar_host, graph_embeddings_store_queue,
schema=JsonSchema(GraphEmbeddings)
)
self.document_out = Publisher( self.document_out = Publisher(
self.pulsar_host, document_ingest_queue, self.pulsar_host, document_ingest_queue,
@ -347,323 +120,14 @@ class Api:
chunking_enabled=True, chunking_enabled=True,
) )
for ep in self.endpoints:
ep.add_routes(self.app)
self.app.add_routes([ self.app.add_routes([
web.post("/api/v1/text-completion", self.llm),
web.post("/api/v1/prompt", self.prompt),
web.post("/api/v1/graph-rag", self.graph_rag),
web.post("/api/v1/triples-query", self.triples_query),
web.post("/api/v1/agent", self.agent),
web.post("/api/v1/embeddings", self.embeddings),
web.post("/api/v1/load/document", self.load_document), web.post("/api/v1/load/document", self.load_document),
web.post("/api/v1/load/text", self.load_text), web.post("/api/v1/load/text", self.load_text),
web.get("/api/v1/ws", self.socket),
web.get("/api/v1/stream/triples", self.stream_triples),
web.get(
"/api/v1/stream/graph-embeddings",
self.stream_graph_embeddings
),
web.get("/api/v1/load/triples", self.load_triples),
web.get(
"/api/v1/load/graph-embeddings",
self.load_graph_embeddings
),
]) ])
async def llm(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.llm_in.subscribe(id)
await self.llm_out.send(
id,
TextCompletionRequest(
system=data["system"],
prompt=data["prompt"]
)
)
try:
resp = await asyncio.wait_for(q.get(), self.timeout)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
return web.json_response(
{ "response": resp.response }
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.llm_in.unsubscribe(id)
async def prompt(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.prompt_in.subscribe(id)
terms = {
k: json.dumps(v)
for k, v in data["variables"].items()
}
await self.prompt_out.send(
id,
PromptRequest(
id=data["id"],
terms=terms
)
)
try:
resp = await asyncio.wait_for(q.get(), self.timeout)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
if resp.object:
return web.json_response(
{ "object": resp.object }
)
return web.json_response(
{ "text": resp.text }
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.prompt_in.unsubscribe(id)
async def graph_rag(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.graph_rag_in.subscribe(id)
await self.graph_rag_out.send(
id,
GraphRagQuery(
query=data["query"],
user=data.get("user", "trustgraph"),
collection=data.get("collection", "default"),
)
)
try:
resp = await asyncio.wait_for(q.get(), self.timeout)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
return web.json_response(
{ "response": resp.response }
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.graph_rag_in.unsubscribe(id)
async def triples_query(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.triples_query_in.subscribe(id)
if "s" in data:
s = to_value(data["s"])
else:
s = None
if "p" in data:
p = to_value(data["p"])
else:
p = None
if "o" in data:
o = to_value(data["o"])
else:
o = None
limit = int(data.get("limit", 10000))
await self.triples_query_out.send(
id,
TriplesQueryRequest(
s = s, p = p, o = o,
limit = limit,
user = data.get("user", "trustgraph"),
collection = data.get("collection", "default"),
)
)
try:
resp = await asyncio.wait_for(q.get(), self.timeout)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
return web.json_response(
{
"response": serialize_subgraph(resp.triples),
}
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.graph_rag_in.unsubscribe(id)
async def agent(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.agent_in.subscribe(id)
await self.agent_out.send(
id,
AgentRequest(
question=data["question"],
)
)
while True:
try:
resp = await asyncio.wait_for(q.get(), self.timeout)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
if resp.answer: break
if resp.thought: print("thought:", resp.thought)
if resp.observation: print("observation:", resp.observation)
if resp.answer:
return web.json_response(
{ "answer": resp.answer }
)
# Can't happen, ook at the logic
raise RuntimeError("Strange state")
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.agent_in.unsubscribe(id)
async def embeddings(self, request):
id = str(uuid.uuid4())
try:
data = await request.json()
q = await self.embeddings_in.subscribe(id)
await self.embeddings_out.send(
id,
EmbeddingsRequest(
text=data["text"],
)
)
try:
resp = await asyncio.wait_for(q.get(), self.timeout)
except:
raise RuntimeError("Timeout waiting for response")
if resp.error:
return web.json_response(
{ "error": resp.error.message }
)
return web.json_response(
{ "vectors": resp.vectors }
)
except Exception as e:
logging.error(f"Exception: {e}")
return web.json_response(
{ "error": str(e) }
)
finally:
await self.embeddings_in.unsubscribe(id)
async def load_document(self, request): async def load_document(self, request):
try: try:
@ -750,215 +214,12 @@ class Api:
{ "error": str(e) } { "error": str(e) }
) )
async def socket(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
async for msg in ws:
if msg.type == WSMsgType.TEXT:
if msg.data == 'close':
await ws.close()
else:
await ws.send_str(msg.data + '/answer')
elif msg.type == WSMsgType.ERROR:
print('ws connection closed with exception %s' %
ws.exception())
print('websocket connection closed')
return ws
async def stream(self, q, ws, running, fn):
while running.get():
try:
resp = await asyncio.wait_for(q.get(), 0.5)
await ws.send_json(fn(resp))
except TimeoutError:
continue
except Exception as e:
print(f"Exception: {str(e)}", flush=True)
async def stream_triples(self, request):
id = str(uuid.uuid4())
q = await self.triples_tap.subscribe_all(id)
running = Running()
ws = web.WebSocketResponse()
await ws.prepare(request)
tsk = asyncio.create_task(self.stream(
q,
ws,
running,
serialize_triples,
))
async for msg in ws:
if msg.type == WSMsgType.ERROR:
break
else:
# Ignore incoming messages
pass
running.stop()
await self.triples_tap.unsubscribe_all(id)
await tsk
return ws
async def stream_graph_embeddings(self, request):
id = str(uuid.uuid4())
q = await self.graph_embeddings_tap.subscribe_all(id)
running = Running()
ws = web.WebSocketResponse()
await ws.prepare(request)
tsk = asyncio.create_task(self.stream(
q,
ws,
running,
serialize_graph_embeddings,
))
async for msg in ws:
if msg.type == WSMsgType.ERROR:
break
else:
# Ignore incoming messages
pass
running.stop()
await self.graph_embeddings_tap.unsubscribe_all(id)
await tsk
return ws
async def load_triples(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
async for msg in ws:
try:
if msg.type == WSMsgType.TEXT:
data = msg.json()
elt = Triples(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
triples=to_subgraph(data["triples"]),
)
await self.triples_pub.send(None, elt)
elif msg.type == WSMsgType.ERROR:
break
except Exception as e:
print("Exception:", e)
return ws
async def load_graph_embeddings(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
async for msg in ws:
try:
if msg.type == WSMsgType.TEXT:
data = msg.json()
elt = GraphEmbeddings(
metadata=Metadata(
id=data["metadata"]["id"],
metadata=to_subgraph(data["metadata"]["metadata"]),
user=data["metadata"]["user"],
collection=data["metadata"]["collection"],
),
entity=to_value(data["entity"]),
vectors=data["vectors"],
)
await self.graph_embeddings_pub.send(None, elt)
elif msg.type == WSMsgType.ERROR:
break
except Exception as e:
print("Exception:", e)
return ws
async def app_factory(self): async def app_factory(self):
self.llm_pub_task = asyncio.create_task(self.llm_in.run()) for ep in self.endpoints:
self.llm_sub_task = asyncio.create_task(self.llm_out.run()) await ep.start()
self.prompt_pub_task = asyncio.create_task(self.prompt_in.run())
self.prompt_sub_task = asyncio.create_task(self.prompt_out.run())
self.graph_rag_pub_task = asyncio.create_task(self.graph_rag_in.run())
self.graph_rag_sub_task = asyncio.create_task(self.graph_rag_out.run())
self.triples_query_pub_task = asyncio.create_task(
self.triples_query_in.run()
)
self.triples_query_sub_task = asyncio.create_task(
self.triples_query_out.run()
)
self.agent_pub_task = asyncio.create_task(self.agent_in.run())
self.agent_sub_task = asyncio.create_task(self.agent_out.run())
self.embeddings_pub_task = asyncio.create_task(
self.embeddings_in.run()
)
self.embeddings_sub_task = asyncio.create_task(
self.embeddings_out.run()
)
self.triples_tap_task = asyncio.create_task(
self.triples_tap.run()
)
self.triples_pub_task = asyncio.create_task(
self.triples_pub.run()
)
self.graph_embeddings_tap_task = asyncio.create_task(
self.graph_embeddings_tap.run()
)
self.graph_embeddings_pub_task = asyncio.create_task(
self.graph_embeddings_pub.run()
)
self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run()) self.doc_ingest_pub_task = asyncio.create_task(self.document_out.run())
self.text_ingest_pub_task = asyncio.create_task(self.text_out.run()) self.text_ingest_pub_task = asyncio.create_task(self.text_out.run())
return self.app return self.app

View file

@ -0,0 +1,68 @@
import asyncio
from aiohttp import web, WSMsgType
import logging
from . running import Running
logger = logging.getLogger("socket")
logger.setLevel(logging.INFO)
class SocketEndpoint:

    """
    Base class for websocket endpoints: pairs a background sender
    (async_thread) with a reader (listener) that watches for the socket
    closing. Subclasses override one or both.
    """

    def __init__(self, endpoint_path="/api/v1/socket"):
        # Path the endpoint is served on
        self.path = endpoint_path

    async def listener(self, ws, running):
        """Consume client frames; an ERROR frame ends the session."""
        async for msg in ws:
            if msg.type == WSMsgType.ERROR:
                # On error, finish
                break
            # Otherwise ignore incoming messages
        running.stop()

    async def async_thread(self, ws, running):
        """Default background task: idle until the session is stopped."""
        while running.get():
            try:
                await asyncio.sleep(1)
            except TimeoutError:
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)

    async def handle(self, request):
        """Accept the websocket and run sender + listener to completion."""
        running = Running()

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        task = asyncio.create_task(self.async_thread(ws, running))
        await self.listener(ws, running)
        await task

        running.stop()
        return ws

    async def start(self):
        """Hook for subclasses needing background machinery; no-op here."""
        pass

    def add_routes(self, app):
        """Register this endpoint's GET route on the aiohttp app."""
        app.add_routes([
            web.get(self.path, self.handle),
        ])

View file

@ -0,0 +1,68 @@
import asyncio
import aiopulsar
class Subscriber:

    """
    Long-lived Pulsar consumer that fans messages out to in-process
    asyncio queues. Callers register either for messages carrying a
    specific "id" property (subscribe) or for every message
    (subscribe_all).
    """

    def __init__(self, pulsar_host, topic, subscription, consumer_name,
                 schema=None, max_size=10):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name
        self.schema = schema
        # NOTE(review): max_size is accepted but currently unused —
        # kept for interface compatibility.
        # Per-request queues, keyed by message "id" property
        self.q = {}
        # Firehose queues receiving every message, keyed by subscriber id
        self.full = {}

    async def run(self):
        """Consume forever, dispatching each message; reconnect on failure."""
        while True:
            try:
                async with aiopulsar.connect(self.pulsar_host) as client:
                    async with client.subscribe(
                        topic=self.topic,
                        subscription_name=self.subscription,
                        consumer_name=self.consumer_name,
                        schema=self.schema,
                    ) as consumer:
                        while True:
                            msg = await consumer.receive()

                            # Acknowledge successful reception of the message
                            await consumer.acknowledge(msg)

                            # Sender-assigned correlation id, if present.
                            # Narrowed from a bare except, and renamed so
                            # the id() builtin isn't shadowed.
                            try:
                                msg_id = msg.properties()["id"]
                            except (KeyError, TypeError):
                                msg_id = None

                            value = msg.value()

                            if msg_id in self.q:
                                await self.q[msg_id].put(value)

                            for q in self.full.values():
                                await q.put(value)

            except Exception as e:
                print("Exception:", e, flush=True)

            # If handler drops out, sleep a retry
            await asyncio.sleep(2)

    async def subscribe(self, id):
        """Register and return a queue for messages tagged with *id*."""
        q = asyncio.Queue()
        self.q[id] = q
        return q

    async def unsubscribe(self, id):
        """Drop the queue registered for *id*; no-op if absent."""
        if id in self.q:
            del self.q[id]

    async def subscribe_all(self, id):
        """Register and return a queue that receives every message."""
        q = asyncio.Queue()
        self.full[id] = q
        return q

    async def unsubscribe_all(self, id):
        """Drop the all-messages queue registered for *id*; no-op if absent."""
        if id in self.full:
            del self.full[id]

View file

@ -0,0 +1,28 @@
from ... schema import TextCompletionRequest, TextCompletionResponse
from ... schema import text_completion_request_queue
from ... schema import text_completion_response_queue
from . endpoint import ServiceEndpoint
class TextCompletionEndpoint(ServiceEndpoint):

    """Request/response endpoint bridging HTTP to the text-completion queues."""

    def __init__(self, pulsar_host, timeout):
        super(TextCompletionEndpoint, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=text_completion_request_queue,
            response_queue=text_completion_response_queue,
            request_schema=TextCompletionRequest,
            response_schema=TextCompletionResponse,
            endpoint_path="/api/v1/text-completion",
            timeout=timeout,
        )

    def to_request(self, body):
        """Build a TextCompletionRequest from the JSON request body."""
        return TextCompletionRequest(
            system=body["system"],
            prompt=body["prompt"],
        )

    def from_response(self, message):
        """Map the service response onto the JSON reply body."""
        return {"response": message.response}

View file

@ -0,0 +1,59 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import Triples
from ... schema import triples_store_queue
from . publisher import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph
class TriplesLoadEndpoint(SocketEndpoint):

    """
    Websocket loader: each JSON frame is decoded to a Triples message
    and published to the triples store queue.
    """

    def __init__(self, pulsar_host, path="/api/v1/load/triples"):
        super(TriplesLoadEndpoint, self).__init__(
            endpoint_path=path
        )
        self.pulsar_host = pulsar_host
        self.publisher = Publisher(
            self.pulsar_host, triples_store_queue,
            schema=JsonSchema(Triples)
        )

    async def start(self):
        """Launch the publisher's background connection task."""
        self.task = asyncio.create_task(
            self.publisher.run()
        )

    async def listener(self, ws, running):
        """Decode and publish incoming frames until the socket errors."""
        async for msg in ws:
            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break
            try:
                data = msg.json()
                elt = Triples(
                    metadata=Metadata(
                        id=data["metadata"]["id"],
                        metadata=to_subgraph(data["metadata"]["metadata"]),
                        user=data["metadata"]["user"],
                        collection=data["metadata"]["collection"],
                    ),
                    triples=to_subgraph(data["triples"]),
                )
                await self.publisher.send(None, elt)
            except Exception as e:
                # One malformed frame shouldn't kill the whole session
                print("Exception:", e, flush=True)

        running.stop()

View file

@ -0,0 +1,53 @@
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
from ... schema import triples_request_queue
from ... schema import triples_response_queue
from . endpoint import ServiceEndpoint
from . serialize import to_value, serialize_subgraph
class TriplesQueryEndpoint(ServiceEndpoint):

    """Request/response endpoint bridging HTTP to the triples query queues."""

    def __init__(self, pulsar_host, timeout):
        super(TriplesQueryEndpoint, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=triples_request_queue,
            response_queue=triples_response_queue,
            request_schema=TriplesQueryRequest,
            response_schema=TriplesQueryResponse,
            endpoint_path="/api/v1/triples-query",
            timeout=timeout,
        )

    def to_request(self, body):
        """Build a TriplesQueryRequest from the JSON body; absent s/p/o
        terms act as wildcards (None)."""

        s = to_value(body["s"]) if "s" in body else None
        p = to_value(body["p"]) if "p" in body else None
        o = to_value(body["o"]) if "o" in body else None

        return TriplesQueryRequest(
            s = s, p = p, o = o,
            limit = int(body.get("limit", 10000)),
            user = body.get("user", "trustgraph"),
            collection = body.get("collection", "default"),
        )

    def from_response(self, message):
        """Serialize the matched subgraph for the JSON reply."""
        # (dropped a leftover debug print of the raw response)
        return {
            "response": serialize_subgraph(message.triples)
        }

View file

@ -0,0 +1,56 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from ... schema import Triples
from ... schema import triples_store_queue
from . subscriber import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_triples
class TriplesStreamEndpoint(SocketEndpoint):

    """
    Websocket stream of the triples store: every stored Triples message
    is pushed to the connected client as JSON.
    """

    def __init__(self, pulsar_host, path="/api/v1/stream/triples"):
        super(TriplesStreamEndpoint, self).__init__(
            endpoint_path=path
        )
        self.pulsar_host = pulsar_host
        self.subscriber = Subscriber(
            self.pulsar_host, triples_store_queue,
            "api-gateway", "api-gateway",
            schema=JsonSchema(Triples)
        )

    async def start(self):
        """Launch the subscriber's background consumer task."""
        self.task = asyncio.create_task(
            self.subscriber.run()
        )

    async def async_thread(self, ws, running):
        """Pump messages from the firehose queue to the websocket."""
        sub_id = str(uuid.uuid4())
        q = await self.subscriber.subscribe_all(sub_id)

        while running.get():
            try:
                resp = await asyncio.wait_for(q.get(), 0.5)
                await ws.send_json(serialize_triples(resp))
            except asyncio.TimeoutError:
                # Quiet poll window; re-check running and wait again.
                # (asyncio.TimeoutError, not the builtin: they differ on
                # Python < 3.11, where the old handler fell through to
                # the generic except and dropped the stream.)
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break

        await self.subscriber.unsubscribe_all(sub_id)
        running.stop()

View file

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . service import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,102 @@
"""
Wikipedia lookup service. Fetches an extract from the Wikipedia page
using the API.
"""
from trustgraph.schema import LookupRequest, LookupResponse, Error
from trustgraph.schema import encyclopedia_lookup_request_queue
from trustgraph.schema import encyclopedia_lookup_response_queue
from trustgraph.log_level import LogLevel
from trustgraph.base import ConsumerProducer
import requests
# Module name with the package prefix and leaf stripped; doubles as the
# default Pulsar subscriber name.
module = ".".join(__name__.split(".")[1:-1])

# Default queue wiring; overridable via constructor params / CLI args.
default_input_queue = encyclopedia_lookup_request_queue
default_output_queue = encyclopedia_lookup_response_queue
default_subscriber = module

# Base URL of the Wikipedia instance to query (note the trailing slash).
default_url="https://en.wikipedia.org/"
class Processor(ConsumerProducer):

    """
    Lookup processor: answers LookupRequest messages with a text extract
    fetched from the Wikipedia REST API page-summary endpoint.
    """

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        url = params.get("url", default_url)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": LookupRequest,
                "output_schema": LookupResponse,
            }
        )

        # Base URL of the Wikipedia instance to query
        self.url = url

    def handle(self, msg):
        """Process one lookup request; replies with the page extract, or
        with a lookup-error on any failure."""

        from urllib.parse import quote

        v = msg.value()

        # Sender-produced ID
        id = msg.properties()["id"]

        print(f"Handling {v.kind} / {v.term}...", flush=True)

        try:

            # rstrip avoids a double slash with the default trailing-/
            # base URL; quote protects terms containing spaces/slashes.
            url = (
                f"{self.url.rstrip('/')}"
                f"/api/rest_v1/page/summary/{quote(v.term)}"
            )

            # (was: resp = Result = requests.get(...) — the stray
            # "Result =" binding created an accidental global)
            resp = requests.get(url).json()

            text = resp["extract"]

            r = LookupResponse(
                error=None,
                text=text
            )

            self.producer.send(r, properties={"id": id})

            self.consumer.acknowledge(msg)

            return

        except Exception as e:

            r = LookupResponse(
                error=Error(
                    type = "lookup-error",
                    message = str(e),
                ),
                text=None,
            )

            self.producer.send(r, properties={"id": id})

            self.consumer.acknowledge(msg)

            return

    @staticmethod
    def add_args(parser):
        """Register this processor's CLI arguments."""

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-u', '--url',
            default=default_url,
            # Fixed copy-paste help text: this option is the Wikipedia
            # base URL, not an LLM model.
            help=f'Wikipedia URL (default: {default_url})'
        )
def run():
    """Module entry point: start the lookup processor under this
    module's name, using the module docstring as its description."""
    Processor.start(module, __doc__)