Fixed a problem with the packages, api/__init__.py appeared in both (#196)

trustgraph-flow and trustgraph-base; moved the gateway code into a
separate directory.
This commit is contained in:
cybermaggedon 2024-12-06 13:05:56 +00:00 committed by GitHub
parent 7df7843dad
commit 67d69b5285
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1 addition and 1 deletion

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3

# Script entry point: delegate to the package's run() function when
# executed directly (e.g. python -m ...).
from . service import run

if __name__ == '__main__':
    run()

View file

@ -0,0 +1,31 @@
from ... schema import AgentRequest, AgentResponse
from ... schema import agent_request_queue
from ... schema import agent_response_queue
from . endpoint import MultiResponseServiceEndpoint
class AgentEndpoint(MultiResponseServiceEndpoint):
    """Agent service endpoint; the agent may reply several times, and the
    exchange is complete once a response carries an answer."""

    def __init__(self, pulsar_host, timeout, auth):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=agent_request_queue,
            response_queue=agent_response_queue,
            request_schema=AgentRequest,
            response_schema=AgentResponse,
            endpoint_path="/api/v1/agent",
            timeout=timeout,
            auth=auth,
        )

    def to_request(self, body):
        """Build an AgentRequest from the JSON request body."""
        return AgentRequest(question=body["question"])

    def from_response(self, message):
        """Return (payload, finished): finished once an answer arrives."""
        if message.answer:
            return {"answer": message.answer}, True
        return {}, False

View file

@ -0,0 +1,22 @@
class Authenticator:
    """Simple bearer-token authenticator.

    Either a non-empty token must be supplied, or allow_all must be set,
    which disables authentication entirely.
    """

    def __init__(self, token=None, allow_all=False):
        # Reject a missing or empty token unless auth is disabled
        # (merges the previous separate None and "" checks).
        if not allow_all and not token:
            raise RuntimeError("Need a token")
        self.token = token
        self.allow_all = allow_all

    def permitted(self, token, roles):
        """Return True if `token` grants access.

        `roles` is accepted for interface compatibility but is not
        currently consulted.
        """
        if self.allow_all:
            return True
        # Fix: constant-time comparison so response timing cannot leak
        # how much of the token prefix matched.
        import hmac
        supplied = token.encode("utf-8") if isinstance(token, str) else b""
        return hmac.compare_digest(self.token.encode("utf-8"), supplied)

View file

@ -0,0 +1,30 @@
from ... schema import LookupRequest, LookupResponse
from ... schema import dbpedia_lookup_request_queue
from ... schema import dbpedia_lookup_response_queue
from . endpoint import ServiceEndpoint
class DbpediaEndpoint(ServiceEndpoint):
    """Lookup endpoint backed by the DBpedia lookup service."""

    def __init__(self, pulsar_host, timeout, auth):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=dbpedia_lookup_request_queue,
            response_queue=dbpedia_lookup_response_queue,
            request_schema=LookupRequest,
            response_schema=LookupResponse,
            endpoint_path="/api/v1/dbpedia",
            timeout=timeout,
            auth=auth,
        )

    def to_request(self, body):
        """Build a LookupRequest; `kind` is optional."""
        return LookupRequest(
            term=body["term"],
            kind=body.get("kind"),
        )

    def from_response(self, message):
        """Return the lookup text as a JSON payload."""
        return {"text": message.text}

View file

@ -0,0 +1,28 @@
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import embeddings_request_queue
from ... schema import embeddings_response_queue
from . endpoint import ServiceEndpoint
class EmbeddingsEndpoint(ServiceEndpoint):
    """Embeddings endpoint: returns embedding vectors for a text string."""

    def __init__(self, pulsar_host, timeout, auth):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=embeddings_request_queue,
            response_queue=embeddings_response_queue,
            request_schema=EmbeddingsRequest,
            response_schema=EmbeddingsResponse,
            endpoint_path="/api/v1/embeddings",
            timeout=timeout,
            auth=auth,
        )

    def to_request(self, body):
        """Build an EmbeddingsRequest from the request body."""
        return EmbeddingsRequest(text=body["text"])

    def from_response(self, message):
        """Return the embedding vectors as a JSON payload."""
        return {"vectors": message.vectors}

View file

@ -0,0 +1,30 @@
from ... schema import LookupRequest, LookupResponse
from ... schema import encyclopedia_lookup_request_queue
from ... schema import encyclopedia_lookup_response_queue
from . endpoint import ServiceEndpoint
class EncyclopediaEndpoint(ServiceEndpoint):
    """Lookup endpoint backed by the encyclopedia service."""

    def __init__(self, pulsar_host, timeout, auth):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=encyclopedia_lookup_request_queue,
            response_queue=encyclopedia_lookup_response_queue,
            request_schema=LookupRequest,
            response_schema=LookupResponse,
            endpoint_path="/api/v1/encyclopedia",
            timeout=timeout,
            auth=auth,
        )

    def to_request(self, body):
        """Build a LookupRequest; `kind` is optional."""
        return LookupRequest(
            term=body["term"],
            kind=body.get("kind"),
        )

    def from_response(self, message):
        """Return the lookup text as a JSON payload."""
        return {"text": message.text}

View file

@ -0,0 +1,166 @@
import asyncio
from pulsar.schema import JsonSchema
from aiohttp import web
import uuid
import logging
from . publisher import Publisher
from . subscriber import Subscriber
logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO)
class ServiceEndpoint:
    """Single request/response HTTP endpoint bridging to a Pulsar queue pair.

    A POST body is translated (to_request) into a message on request_queue;
    the response correlated by a UUID "id" property is read back from
    response_queue and translated (from_response) into the JSON reply.
    """

    def __init__(
            self,
            pulsar_host,
            request_queue, request_schema,
            response_queue, response_schema,
            endpoint_path,
            auth,
            subscription="api-gateway", consumer_name="api-gateway",
            timeout=600,
    ):
        self.pub = Publisher(
            pulsar_host, request_queue,
            schema=JsonSchema(request_schema)
        )
        self.sub = Subscriber(
            pulsar_host, response_queue,
            subscription, consumer_name,
            JsonSchema(response_schema)
        )
        self.path = endpoint_path
        self.timeout = timeout
        self.auth = auth
        # Operation label passed to the authenticator as the role set
        self.operation = "service"

    async def start(self):
        """Start the underlying publisher/subscriber threads."""
        self.pub.start()
        self.sub.start()

    def add_routes(self, app):
        """Register this endpoint's POST route on the aiohttp app."""
        app.add_routes([
            web.post(self.path, self.handle),
        ])

    def to_request(self, request):
        """Translate a JSON body to a request message; subclass must define."""
        raise RuntimeError("Not defined")

    def from_response(self, response):
        """Translate a response message to JSON; subclass must define."""
        raise RuntimeError("Not defined")

    async def handle(self, request):
        # Correlation id ties the queued request to its response
        request_id = str(uuid.uuid4())
        logger.info("%s ...", request.path)

        # Extract the bearer token; a missing or malformed header is
        # treated as an empty token and left for the authenticator.
        try:
            ht = request.headers["Authorization"]
            tokens = ht.split(" ", 2)
            if tokens[0] != "Bearer":
                return web.HTTPUnauthorized()
            token = tokens[1]
        except (KeyError, IndexError):
            # Fix: narrowed from a bare except, which also swallowed
            # KeyboardInterrupt/CancelledError.
            token = ""

        if not self.auth.permitted(token, self.operation):
            return web.HTTPUnauthorized()

        try:
            data = await request.json()

            # Subscribe before sending so the response can't be missed
            q = self.sub.subscribe(request_id)

            await asyncio.to_thread(
                self.pub.send, request_id, self.to_request(data)
            )

            try:
                resp = await asyncio.to_thread(q.get, timeout=self.timeout)
            except Exception as e:
                # Fix: chain the cause instead of discarding it
                raise RuntimeError("Timeout") from e

            if resp.error:
                logger.error("Error response: %s", resp.error.message)
                return web.json_response(
                    {"error": resp.error.message}
                )

            return web.json_response(
                self.from_response(resp)
            )

        except Exception as e:
            # Fix: use the module logger rather than the root logger
            logger.error(f"Exception: {e}")
            return web.json_response(
                {"error": str(e)}
            )

        finally:
            # Always drop the correlation queue, success or failure
            self.sub.unsubscribe(request_id)
class MultiResponseServiceEndpoint(ServiceEndpoint):
    """Endpoint whose service may reply several times per request.

    Responses are consumed until from_response() reports completion.
    """

    async def handle(self, request):
        request_id = str(uuid.uuid4())

        # Fix: enforce the same bearer-token authentication as
        # ServiceEndpoint.handle; previously this override performed no
        # Authorization check at all, so e.g. the agent endpoint was open.
        try:
            ht = request.headers["Authorization"]
            tokens = ht.split(" ", 2)
            if tokens[0] != "Bearer":
                return web.HTTPUnauthorized()
            token = tokens[1]
        except (KeyError, IndexError):
            token = ""

        if not self.auth.permitted(token, self.operation):
            return web.HTTPUnauthorized()

        try:
            data = await request.json()

            q = self.sub.subscribe(request_id)

            await asyncio.to_thread(
                self.pub.send, request_id, self.to_request(data)
            )

            # Keeps looking at responses...
            while True:

                try:
                    resp = await asyncio.to_thread(q.get, timeout=self.timeout)
                except Exception as e:
                    raise RuntimeError("Timeout waiting for response") from e

                if resp.error:
                    return web.json_response(
                        {"error": resp.error.message}
                    )

                # ...until from_response says we have a finished answer
                body, fin = self.from_response(resp)

                if fin:
                    return web.json_response(body)

                # Not finished, so loop round and continue

        except Exception as e:
            logger.error(f"Exception: {e}")
            return web.json_response(
                {"error": str(e)}
            )

        finally:
            self.sub.unsubscribe(request_id)

View file

@ -0,0 +1,60 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import GraphEmbeddings
from ... schema import graph_embeddings_store_queue
from . publisher import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph, to_value
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
    """Websocket endpoint loading graph-embeddings messages onto Pulsar."""

    def __init__(self, pulsar_host, auth,
                 path="/api/v1/load/graph-embeddings"):
        super().__init__(endpoint_path=path, auth=auth)
        self.pulsar_host = pulsar_host
        self.publisher = Publisher(
            self.pulsar_host, graph_embeddings_store_queue,
            schema=JsonSchema(GraphEmbeddings)
        )

    async def start(self):
        self.publisher.start()

    async def listener(self, ws, running):
        async for msg in ws:
            # Stop on a socket error
            if msg.type == WSMsgType.ERROR:
                break
            payload = msg.json()
            meta = payload["metadata"]
            elt = GraphEmbeddings(
                metadata=Metadata(
                    id=meta["id"],
                    metadata=to_subgraph(meta["metadata"]),
                    user=meta["user"],
                    collection=meta["collection"],
                ),
                entity=to_value(payload["entity"]),
                vectors=payload["vectors"],
            )
            self.publisher.send(None, elt)
        running.stop()

View file

@ -0,0 +1,57 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from ... schema import GraphEmbeddings
from ... schema import graph_embeddings_store_queue
from . subscriber import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_graph_embeddings
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
    """Websocket endpoint streaming stored graph embeddings to clients."""

    def __init__(self, pulsar_host, auth,
                 path="/api/v1/stream/graph-embeddings"):
        super().__init__(endpoint_path=path, auth=auth)
        self.pulsar_host = pulsar_host
        self.subscriber = Subscriber(
            self.pulsar_host, graph_embeddings_store_queue,
            "api-gateway", "api-gateway",
            schema=JsonSchema(GraphEmbeddings)
        )

    async def start(self):
        self.subscriber.start()

    async def async_thread(self, ws, running):
        stream_id = str(uuid.uuid4())
        q = self.subscriber.subscribe_all(stream_id)
        while running.get():
            try:
                message = await asyncio.to_thread(q.get, timeout=0.5)
                await ws.send_json(serialize_graph_embeddings(message))
            except queue.Empty:
                # Nothing yet; poll again so shutdown is noticed
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break
        self.subscriber.unsubscribe_all(stream_id)
        running.stop()

View file

@ -0,0 +1,31 @@
from ... schema import GraphRagQuery, GraphRagResponse
from ... schema import graph_rag_request_queue
from ... schema import graph_rag_response_queue
from . endpoint import ServiceEndpoint
class GraphRagEndpoint(ServiceEndpoint):
    """Graph RAG query endpoint."""

    def __init__(self, pulsar_host, timeout, auth):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=graph_rag_request_queue,
            response_queue=graph_rag_response_queue,
            request_schema=GraphRagQuery,
            response_schema=GraphRagResponse,
            endpoint_path="/api/v1/graph-rag",
            timeout=timeout,
            auth=auth,
        )

    def to_request(self, body):
        """Build a GraphRagQuery; user/collection have defaults."""
        return GraphRagQuery(
            query=body["query"],
            user=body.get("user", "trustgraph"),
            collection=body.get("collection", "default"),
        )

    def from_response(self, message):
        """Return the RAG answer as a JSON payload."""
        return {"response": message.response}

View file

@ -0,0 +1,30 @@
from ... schema import LookupRequest, LookupResponse
from ... schema import internet_search_request_queue
from ... schema import internet_search_response_queue
from . endpoint import ServiceEndpoint
class InternetSearchEndpoint(ServiceEndpoint):
    """Lookup endpoint backed by the internet-search service."""

    def __init__(self, pulsar_host, timeout, auth):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=internet_search_request_queue,
            response_queue=internet_search_response_queue,
            request_schema=LookupRequest,
            response_schema=LookupResponse,
            endpoint_path="/api/v1/internet-search",
            timeout=timeout,
            auth=auth,
        )

    def to_request(self, body):
        """Build a LookupRequest; `kind` is optional."""
        return LookupRequest(
            term=body["term"],
            kind=body.get("kind"),
        )

    def from_response(self, message):
        """Return the search text as a JSON payload."""
        return {"text": message.text}

View file

@ -0,0 +1,42 @@
import json
from ... schema import PromptRequest, PromptResponse
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from . endpoint import ServiceEndpoint
class PromptEndpoint(ServiceEndpoint):
    """Prompt endpoint: renders a template id with JSON-encoded variables."""

    def __init__(self, pulsar_host, timeout, auth):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=prompt_request_queue,
            response_queue=prompt_response_queue,
            request_schema=PromptRequest,
            response_schema=PromptResponse,
            endpoint_path="/api/v1/prompt",
            timeout=timeout,
            auth=auth,
        )

    def to_request(self, body):
        """Build a PromptRequest; each variable value is JSON-encoded."""
        terms = {
            name: json.dumps(value)
            for name, value in body["variables"].items()
        }
        return PromptRequest(id=body["id"], terms=terms)

    def from_response(self, message):
        """Prefer the structured object when present, else plain text."""
        if message.object:
            return {"object": message.object}
        return {"text": message.text}

View file

@ -0,0 +1,53 @@
import queue
import time
import pulsar
import threading
class Publisher:
    """Background thread publishing queued messages to a Pulsar topic.

    Messages are submitted with send() and drained by a worker thread
    which (re)connects to Pulsar as needed.
    """

    def __init__(self, pulsar_host, topic, schema=None, max_size=10,
                 chunking_enabled=False):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.schema = schema
        # Bounded queue applies back-pressure to callers of send()
        self.q = queue.Queue(maxsize=max_size)
        self.chunking_enabled = chunking_enabled

    def start(self):
        """Start the background publishing thread."""
        self.task = threading.Thread(target=self.run)
        self.task.start()

    def run(self):
        """Worker loop: connect, publish queued items, retry on failure."""
        while True:
            client = None
            try:
                client = pulsar.Client(
                    self.pulsar_host,
                )
                producer = client.create_producer(
                    topic=self.topic,
                    schema=self.schema,
                    chunking_enabled=self.chunking_enabled,
                )
                while True:
                    msg_id, item = self.q.get()
                    if msg_id:
                        # Correlation id travels as a message property
                        producer.send(item, {"id": msg_id})
                    else:
                        producer.send(item)
            except Exception as e:
                print("Exception:", e, flush=True)
                # If handler drops out, sleep a retry
                time.sleep(2)
            finally:
                # Fix: close the client so each reconnect attempt does
                # not leak the previous connection's resources.
                if client is not None:
                    try:
                        client.close()
                    except Exception:
                        pass

    def send(self, id, msg):
        """Queue (id, msg) for publication; blocks while the queue is full."""
        self.q.put((id, msg))

View file

@ -0,0 +1,5 @@
class Running:
    """Mutable boolean flag shared between tasks to signal shutdown."""

    def __init__(self):
        self.running = True

    def get(self):
        """True while processing should continue."""
        return self.running

    def stop(self):
        """Signal that processing should stop."""
        self.running = False

View file

@ -0,0 +1,57 @@
from ... schema import Value, Triple
def to_value(x):
    """Convert a {"v": ..., "e": ...} JSON term into a Value."""
    return Value(value=x["v"], is_uri=x["e"])


def to_subgraph(x):
    """Convert a JSON list of {s, p, o} terms into Triple objects."""
    triples = []
    for t in x:
        triples.append(
            Triple(
                s=to_value(t["s"]),
                p=to_value(t["p"]),
                o=to_value(t["o"]),
            )
        )
    return triples
def serialize_value(v):
    """Serialize a Value to its {"v", "e"} JSON form."""
    return {"v": v.value, "e": v.is_uri}


def serialize_triple(t):
    """Serialize a Triple to JSON, one entry per term."""
    return {
        "s": serialize_value(t.s),
        "p": serialize_value(t.p),
        "o": serialize_value(t.o),
    }


def serialize_subgraph(sg):
    """Serialize an iterable of triples to a JSON list."""
    return [serialize_triple(t) for t in sg]


def _serialize_metadata(metadata):
    """Shared helper: serialize a Metadata object to JSON."""
    return {
        "id": metadata.id,
        "metadata": serialize_subgraph(metadata.metadata),
        "user": metadata.user,
        "collection": metadata.collection,
    }


def serialize_triples(message):
    """Serialize a Triples message (metadata plus triples)."""
    return {
        "metadata": _serialize_metadata(message.metadata),
        "triples": serialize_subgraph(message.triples),
    }


def serialize_graph_embeddings(message):
    """Serialize a GraphEmbeddings message (metadata, vectors, entity)."""
    return {
        "metadata": _serialize_metadata(message.metadata),
        "vectors": message.vectors,
        "entity": serialize_value(message.entity),
    }

View file

@ -0,0 +1,318 @@
"""
API gateway. Offers HTTP services which are translated to interaction on the
Pulsar bus.
"""
module = ".".join(__name__.split(".")[1:-1])
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
# are active listeners
# FIXME: Connection errors in publishers / subscribers cause those threads
# to fail and are not failed or retried
import asyncio
import argparse
from aiohttp import web
import logging
import os
import base64
import pulsar
from pulsar.schema import JsonSchema
from prometheus_client import start_http_server
from ... log_level import LogLevel
from ... schema import Metadata, Document, TextDocument
from ... schema import document_ingest_queue, text_ingest_queue
from . serialize import to_subgraph
from . running import Running
from . publisher import Publisher
from . subscriber import Subscriber
from . endpoint import ServiceEndpoint, MultiResponseServiceEndpoint
from . text_completion import TextCompletionEndpoint
from . prompt import PromptEndpoint
from . graph_rag import GraphRagEndpoint
from . triples_query import TriplesQueryEndpoint
from . embeddings import EmbeddingsEndpoint
from . encyclopedia import EncyclopediaEndpoint
from . agent import AgentEndpoint
from . dbpedia import DbpediaEndpoint
from . internet_search import InternetSearchEndpoint
from . triples_stream import TriplesStreamEndpoint
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
from . triples_load import TriplesLoadEndpoint
from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
from . auth import Authenticator
# Module logger for the gateway service
logger = logging.getLogger("api")
logger.setLevel(logging.INFO)

# Defaults, overridable via CLI arguments / environment variables
default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
default_timeout = 600
default_port = 8088
# Empty string (the default) disables authentication entirely
default_api_token = os.getenv("GATEWAY_SECRET", "")
class Api:
    """The API gateway web application.

    Hosts one HTTP/websocket route per bus service, plus document and
    text ingest routes that publish directly to Pulsar.
    """

    def __init__(self, **config):

        # Large client_max_size so sizeable documents can be uploaded
        self.app = web.Application(
            middlewares=[],
            client_max_size=256 * 1024 * 1024
        )

        self.port = int(config.get("port", default_port))
        self.timeout = int(config.get("timeout", default_timeout))
        self.pulsar_host = config.get("pulsar_host", default_pulsar_host)

        api_token = config.get("api_token", default_api_token)

        # Token not set, or token equal empty string means no auth
        if api_token:
            self.auth = Authenticator(token=api_token)
        else:
            self.auth = Authenticator(allow_all=True)

        # Request/response and streaming endpoints, one per bus service
        self.endpoints = [
            TextCompletionEndpoint(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth=self.auth,
            ),
            PromptEndpoint(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth=self.auth,
            ),
            GraphRagEndpoint(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth=self.auth,
            ),
            TriplesQueryEndpoint(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth=self.auth,
            ),
            EmbeddingsEndpoint(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth=self.auth,
            ),
            AgentEndpoint(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth=self.auth,
            ),
            EncyclopediaEndpoint(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth=self.auth,
            ),
            DbpediaEndpoint(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth=self.auth,
            ),
            InternetSearchEndpoint(
                pulsar_host=self.pulsar_host, timeout=self.timeout,
                auth=self.auth,
            ),
            # Websocket stream/load endpoints take no request timeout
            TriplesStreamEndpoint(
                pulsar_host=self.pulsar_host,
                auth=self.auth,
            ),
            GraphEmbeddingsStreamEndpoint(
                pulsar_host=self.pulsar_host,
                auth=self.auth,
            ),
            TriplesLoadEndpoint(
                pulsar_host=self.pulsar_host,
                auth=self.auth,
            ),
            GraphEmbeddingsLoadEndpoint(
                pulsar_host=self.pulsar_host,
                auth=self.auth,
            ),
        ]

        # Publishers for the ingest routes; chunking enabled since
        # documents may exceed the broker's message size
        self.document_out = Publisher(
            self.pulsar_host, document_ingest_queue,
            schema=JsonSchema(Document),
            chunking_enabled=True,
        )

        self.text_out = Publisher(
            self.pulsar_host, text_ingest_queue,
            schema=JsonSchema(TextDocument),
            chunking_enabled=True,
        )

        for ep in self.endpoints:
            ep.add_routes(self.app)

        self.app.add_routes([
            web.post("/api/v1/load/document", self.load_document),
            web.post("/api/v1/load/text", self.load_text),
        ])

    async def load_document(self, request):
        """Accept a base64-encoded document and publish it for ingest."""
        try:

            data = await request.json()

            if "metadata" in data:
                metadata = to_subgraph(data["metadata"])
            else:
                metadata = []

            # Doing a base64 decode/encode here to make sure the
            # content is valid base64
            doc = base64.b64decode(data["data"])

            resp = await asyncio.to_thread(
                self.document_out.send,
                None,
                Document(
                    metadata=Metadata(
                        id=data.get("id"),
                        metadata=metadata,
                        user=data.get("user", "trustgraph"),
                        collection=data.get("collection", "default"),
                    ),
                    data=base64.b64encode(doc).decode("utf-8")
                )
            )

            print("Document loaded.")

            # Empty JSON object signals success
            return web.json_response(
                { }
            )

        except Exception as e:

            logging.error(f"Exception: {e}")

            return web.json_response(
                { "error": str(e) }
            )

    async def load_text(self, request):
        """Accept base64-encoded text (optional charset) and publish it."""
        try:

            data = await request.json()

            if "metadata" in data:
                metadata = to_subgraph(data["metadata"])
            else:
                metadata = []

            if "charset" in data:
                charset = data["charset"]
            else:
                charset = "utf-8"

            # Text is base64 encoded
            text = base64.b64decode(data["text"]).decode(charset)

            resp = await asyncio.to_thread(
                self.text_out.send,
                None,
                TextDocument(
                    metadata=Metadata(
                        id=data.get("id"),
                        metadata=metadata,
                        user=data.get("user", "trustgraph"),
                        collection=data.get("collection", "default"),
                    ),
                    text=text,
                )
            )

            print("Text document loaded.")

            return web.json_response(
                { }
            )

        except Exception as e:

            logging.error(f"Exception: {e}")

            return web.json_response(
                { "error": str(e) }
            )

    async def app_factory(self):
        """Start endpoint pub/sub threads, then return the aiohttp app."""
        for ep in self.endpoints:
            await ep.start()
        self.document_out.start()
        self.text_out.start()
        return self.app

    def run(self):
        """Run the web application (blocks until shutdown)."""
        web.run_app(self.app_factory(), port=self.port)
def run():
    """Parse command-line arguments and launch the API gateway."""

    parser = argparse.ArgumentParser(
        prog="api-gateway",
        description=__doc__
    )

    parser.add_argument(
        '-p', '--pulsar-host',
        default=default_pulsar_host,
        help=f'Pulsar host (default: {default_pulsar_host})',
    )

    parser.add_argument(
        '--port',
        type=int,
        default=default_port,
        help=f'Port number to listen on (default: {default_port})',
    )

    parser.add_argument(
        '--timeout',
        type=int,
        default=default_timeout,
        help=f'API request timeout in seconds (default: {default_timeout})',
    )

    parser.add_argument(
        '--api-token',
        default=default_api_token,
        help='Secret API token (default: no auth)',
    )

    parser.add_argument(
        '-l', '--log-level',
        type=LogLevel,
        default=LogLevel.INFO,
        choices=list(LogLevel),
        # Fix: help previously read "Output queue", a copy/paste error
        help='Log level (default: info)'
    )

    parser.add_argument(
        '--metrics',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='Metrics enabled (default: true)',
    )

    parser.add_argument(
        '-P', '--metrics-port',
        type=int,
        default=8000,
        help='Prometheus metrics port (default: 8000)',
    )

    args = vars(parser.parse_args())

    # Expose Prometheus metrics unless explicitly disabled
    if args["metrics"]:
        start_http_server(args["metrics_port"])

    a = Api(**args)
    a.run()

View file

@ -0,0 +1,84 @@
import asyncio
from aiohttp import web, WSMsgType
import logging
from . running import Running
logger = logging.getLogger("socket")
logger.setLevel(logging.INFO)
class SocketEndpoint:
    """Base class for websocket endpoints.

    Subclasses override listener() to consume inbound messages and/or
    async_thread() to stream outbound messages.
    """

    def __init__(self, endpoint_path, auth):
        self.path = endpoint_path
        self.auth = auth
        # Operation label passed to the authenticator as the role set
        self.operation = "socket"

    async def listener(self, ws, running):
        """Default listener: discard messages until error/close."""
        async for msg in ws:
            # On error, finish
            if msg.type == WSMsgType.ERROR:
                break
            else:
                # Ignore incoming messages
                pass
        running.stop()

    async def async_thread(self, ws, running):
        """Default outbound task: idle until told to stop."""
        while running.get():
            try:
                await asyncio.sleep(1)
            except TimeoutError:
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)

    async def handle(self, request):
        # Websockets cannot set request headers, so the token arrives
        # as a query parameter instead of an Authorization header.
        try:
            token = request.query['token']
        except KeyError:
            # Fix: narrowed from a bare except; only a missing parameter
            # should fall through to an empty token.
            token = ""

        if not self.auth.permitted(token, self.operation):
            return web.HTTPUnauthorized()

        running = Running()

        ws = web.WebSocketResponse()
        await ws.prepare(request)

        # Outbound stream runs concurrently with the inbound listener
        task = asyncio.create_task(self.async_thread(ws, running))

        try:
            await self.listener(ws, running)
        except Exception as e:
            print(e, flush=True)

        running.stop()
        await ws.close()
        await task

        return ws

    async def start(self):
        """Nothing to start by default; subclasses may override."""
        pass

    def add_routes(self, app):
        """Register this endpoint's websocket GET route."""
        app.add_routes([
            web.get(self.path, self.handle),
        ])

View file

@ -0,0 +1,109 @@
import queue
import pulsar
import threading
import time
class Subscriber:
    """Background Pulsar consumer fanning messages out to local queues.

    Callers register either per-correlation-id queues (subscribe), keyed
    on the message "id" property, or firehose queues (subscribe_all)
    which receive every message.
    """

    def __init__(self, pulsar_host, topic, subscription, consumer_name,
                 schema=None, max_size=100):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name
        self.schema = schema
        self.q = {}        # per-correlation-id queues
        self.full = {}     # firehose queues receiving all messages
        self.max_size = max_size
        self.lock = threading.Lock()

    def start(self):
        """Start the background consuming thread."""
        self.task = threading.Thread(target=self.run)
        self.task.start()

    def run(self):
        """Worker loop: consume messages and dispatch; retry on failure."""
        while True:
            client = None
            try:
                client = pulsar.Client(
                    self.pulsar_host,
                )
                consumer = client.subscribe(
                    topic=self.topic,
                    subscription_name=self.subscription,
                    consumer_name=self.consumer_name,
                    schema=self.schema,
                )
                while True:
                    msg = consumer.receive()

                    # Acknowledge successful reception of the message
                    consumer.acknowledge(msg)

                    try:
                        id = msg.properties()["id"]
                    except KeyError:
                        # No correlation id on this message
                        id = None

                    value = msg.value()

                    with self.lock:
                        if id in self.q:
                            try:
                                self.q[id].put(value, timeout=0.5)
                            except queue.Full:
                                # Slow consumer: drop rather than block
                                pass
                        for q in self.full.values():
                            try:
                                q.put(value, timeout=0.5)
                            except queue.Full:
                                pass
            except Exception as e:
                print("Exception:", e, flush=True)
                # If handler drops out, sleep a retry
                time.sleep(2)
            finally:
                # Fix: close the client so each reconnect attempt does
                # not leak the previous connection's resources.
                if client is not None:
                    try:
                        client.close()
                    except Exception:
                        pass

    def subscribe(self, id):
        """Register and return a queue for messages carrying property id."""
        with self.lock:
            q = queue.Queue(maxsize=self.max_size)
            self.q[id] = q
        return q

    def unsubscribe(self, id):
        """Drop the per-id queue, if registered."""
        with self.lock:
            if id in self.q:
                del self.q[id]

    def subscribe_all(self, id):
        """Register and return a queue receiving every message."""
        with self.lock:
            q = queue.Queue(maxsize=self.max_size)
            self.full[id] = q
        return q

    def unsubscribe_all(self, id):
        """Drop the firehose queue, if registered."""
        with self.lock:
            if id in self.full:
                del self.full[id]

View file

@ -0,0 +1,29 @@
from ... schema import TextCompletionRequest, TextCompletionResponse
from ... schema import text_completion_request_queue
from ... schema import text_completion_response_queue
from . endpoint import ServiceEndpoint
class TextCompletionEndpoint(ServiceEndpoint):
    """Text completion endpoint: system prompt plus user prompt."""

    def __init__(self, pulsar_host, timeout, auth):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=text_completion_request_queue,
            response_queue=text_completion_response_queue,
            request_schema=TextCompletionRequest,
            response_schema=TextCompletionResponse,
            endpoint_path="/api/v1/text-completion",
            timeout=timeout,
            auth=auth,
        )

    def to_request(self, body):
        """Build a TextCompletionRequest from the request body."""
        return TextCompletionRequest(
            system=body["system"],
            prompt=body["prompt"],
        )

    def from_response(self, message):
        """Return the completion as a JSON payload."""
        return {"response": message.response}

View file

@ -0,0 +1,57 @@
import asyncio
from pulsar.schema import JsonSchema
import uuid
from aiohttp import WSMsgType
from ... schema import Metadata
from ... schema import Triples
from ... schema import triples_store_queue
from . publisher import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph
class TriplesLoadEndpoint(SocketEndpoint):
    """Websocket endpoint loading triples messages onto Pulsar."""

    def __init__(self, pulsar_host, auth, path="/api/v1/load/triples"):
        super().__init__(endpoint_path=path, auth=auth)
        self.pulsar_host = pulsar_host
        self.publisher = Publisher(
            self.pulsar_host, triples_store_queue,
            schema=JsonSchema(Triples)
        )

    async def start(self):
        self.publisher.start()

    async def listener(self, ws, running):
        async for msg in ws:
            # Stop on a socket error
            if msg.type == WSMsgType.ERROR:
                break
            payload = msg.json()
            meta = payload["metadata"]
            elt = Triples(
                metadata=Metadata(
                    id=meta["id"],
                    metadata=to_subgraph(meta["metadata"]),
                    user=meta["user"],
                    collection=meta["collection"],
                ),
                triples=to_subgraph(payload["triples"]),
            )
            self.publisher.send(None, elt)
        running.stop()

View file

@ -0,0 +1,54 @@
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
from ... schema import triples_request_queue
from ... schema import triples_response_queue
from . endpoint import ServiceEndpoint
from . serialize import to_value, serialize_subgraph
class TriplesQueryEndpoint(ServiceEndpoint):
    """Triples query endpoint: SPO pattern match against the store."""

    def __init__(self, pulsar_host, timeout, auth):
        super().__init__(
            pulsar_host=pulsar_host,
            request_queue=triples_request_queue,
            response_queue=triples_response_queue,
            request_schema=TriplesQueryRequest,
            response_schema=TriplesQueryResponse,
            endpoint_path="/api/v1/triples-query",
            timeout=timeout,
            auth=auth,
        )

    def to_request(self, body):
        """Build a TriplesQueryRequest; absent s/p/o terms are wildcards."""
        def term(key):
            return to_value(body[key]) if key in body else None
        return TriplesQueryRequest(
            s=term("s"),
            p=term("p"),
            o=term("o"),
            limit=int(body.get("limit", 10000)),
            user=body.get("user", "trustgraph"),
            collection=body.get("collection", "default"),
        )

    def from_response(self, message):
        """Return the matched triples as a JSON payload."""
        print(message)
        return {
            "response": serialize_subgraph(message.triples)
        }

View file

@ -0,0 +1,55 @@
import asyncio
import queue
from pulsar.schema import JsonSchema
import uuid
from ... schema import Triples
from ... schema import triples_store_queue
from . subscriber import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_triples
class TriplesStreamEndpoint(SocketEndpoint):
    """Websocket endpoint streaming stored triples to clients."""

    def __init__(self, pulsar_host, auth, path="/api/v1/stream/triples"):
        super().__init__(endpoint_path=path, auth=auth)
        self.pulsar_host = pulsar_host
        self.subscriber = Subscriber(
            self.pulsar_host, triples_store_queue,
            "api-gateway", "api-gateway",
            schema=JsonSchema(Triples)
        )

    async def start(self):
        self.subscriber.start()

    async def async_thread(self, ws, running):
        stream_id = str(uuid.uuid4())
        q = self.subscriber.subscribe_all(stream_id)
        while running.get():
            try:
                message = await asyncio.to_thread(q.get, timeout=0.5)
                await ws.send_json(serialize_triples(message))
            except queue.Empty:
                # Nothing yet; poll again so shutdown is noticed
                continue
            except Exception as e:
                print(f"Exception: {str(e)}", flush=True)
                break
        self.subscriber.unsubscribe_all(stream_id)
        running.stop()