mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 02:23:44 +02:00
Fixed a problem with the packages, api/__init__.py appeared in both (#196)
trustgraph-flow and trustgraph-base, moved the gateway stuff into a different directory.
This commit is contained in:
parent
7df7843dad
commit
67d69b5285
25 changed files with 1 additions and 1 deletions
3
trustgraph-flow/trustgraph/gateway/__init__.py
Normal file
3
trustgraph-flow/trustgraph/gateway/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/gateway/__main__.py
Executable file
7
trustgraph-flow/trustgraph/gateway/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
31
trustgraph-flow/trustgraph/gateway/agent.py
Normal file
31
trustgraph-flow/trustgraph/gateway/agent.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
from ... schema import AgentRequest, AgentResponse
|
||||
from ... schema import agent_request_queue
|
||||
from ... schema import agent_response_queue
|
||||
|
||||
from . endpoint import MultiResponseServiceEndpoint
|
||||
|
||||
class AgentEndpoint(MultiResponseServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(AgentEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
request_queue=agent_request_queue,
|
||||
response_queue=agent_response_queue,
|
||||
request_schema=AgentRequest,
|
||||
response_schema=AgentResponse,
|
||||
endpoint_path="/api/v1/agent",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
return AgentRequest(
|
||||
question=body["question"]
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
if message.answer:
|
||||
return { "answer": message.answer }, True
|
||||
else:
|
||||
return {}, False
|
||||
22
trustgraph-flow/trustgraph/gateway/auth.py
Normal file
22
trustgraph-flow/trustgraph/gateway/auth.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
class Authenticator:
|
||||
|
||||
def __init__(self, token=None, allow_all=False):
|
||||
|
||||
if not allow_all and token is None:
|
||||
raise RuntimeError("Need a token")
|
||||
|
||||
if not allow_all and token == "":
|
||||
raise RuntimeError("Need a token")
|
||||
|
||||
self.token = token
|
||||
self.allow_all = allow_all
|
||||
|
||||
def permitted(self, token, roles):
|
||||
|
||||
if self.allow_all: return True
|
||||
|
||||
if self.token != token: return False
|
||||
|
||||
return True
|
||||
|
||||
30
trustgraph-flow/trustgraph/gateway/dbpedia.py
Normal file
30
trustgraph-flow/trustgraph/gateway/dbpedia.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
|
||||
from ... schema import LookupRequest, LookupResponse
|
||||
from ... schema import dbpedia_lookup_request_queue
|
||||
from ... schema import dbpedia_lookup_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class DbpediaEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(DbpediaEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
request_queue=dbpedia_lookup_request_queue,
|
||||
response_queue=dbpedia_lookup_response_queue,
|
||||
request_schema=LookupRequest,
|
||||
response_schema=LookupResponse,
|
||||
endpoint_path="/api/v1/dbpedia",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
return LookupRequest(
|
||||
term=body["term"],
|
||||
kind=body.get("kind", None),
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
return { "text": message.text }
|
||||
|
||||
28
trustgraph-flow/trustgraph/gateway/embeddings.py
Normal file
28
trustgraph-flow/trustgraph/gateway/embeddings.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
|
||||
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from ... schema import embeddings_request_queue
|
||||
from ... schema import embeddings_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class EmbeddingsEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(EmbeddingsEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
request_queue=embeddings_request_queue,
|
||||
response_queue=embeddings_response_queue,
|
||||
request_schema=EmbeddingsRequest,
|
||||
response_schema=EmbeddingsResponse,
|
||||
endpoint_path="/api/v1/embeddings",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
return EmbeddingsRequest(
|
||||
text=body["text"]
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
return { "vectors": message.vectors }
|
||||
30
trustgraph-flow/trustgraph/gateway/encyclopedia.py
Normal file
30
trustgraph-flow/trustgraph/gateway/encyclopedia.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
|
||||
from ... schema import LookupRequest, LookupResponse
|
||||
from ... schema import encyclopedia_lookup_request_queue
|
||||
from ... schema import encyclopedia_lookup_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class EncyclopediaEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(EncyclopediaEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
request_queue=encyclopedia_lookup_request_queue,
|
||||
response_queue=encyclopedia_lookup_response_queue,
|
||||
request_schema=LookupRequest,
|
||||
response_schema=LookupResponse,
|
||||
endpoint_path="/api/v1/encyclopedia",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
return LookupRequest(
|
||||
term=body["term"],
|
||||
kind=body.get("kind", None),
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
return { "text": message.text }
|
||||
|
||||
166
trustgraph-flow/trustgraph/gateway/endpoint.py
Normal file
166
trustgraph-flow/trustgraph/gateway/endpoint.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
from aiohttp import web
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from . publisher import Publisher
|
||||
from . subscriber import Subscriber
|
||||
|
||||
logger = logging.getLogger("endpoint")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class ServiceEndpoint:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host,
|
||||
request_queue, request_schema,
|
||||
response_queue, response_schema,
|
||||
endpoint_path,
|
||||
auth,
|
||||
subscription="api-gateway", consumer_name="api-gateway",
|
||||
timeout=600,
|
||||
):
|
||||
|
||||
self.pub = Publisher(
|
||||
pulsar_host, request_queue,
|
||||
schema=JsonSchema(request_schema)
|
||||
)
|
||||
|
||||
self.sub = Subscriber(
|
||||
pulsar_host, response_queue,
|
||||
subscription, consumer_name,
|
||||
JsonSchema(response_schema)
|
||||
)
|
||||
|
||||
self.path = endpoint_path
|
||||
self.timeout = timeout
|
||||
self.auth = auth
|
||||
|
||||
self.operation = "service"
|
||||
|
||||
async def start(self):
|
||||
|
||||
self.pub.start()
|
||||
self.sub.start()
|
||||
|
||||
def add_routes(self, app):
|
||||
|
||||
app.add_routes([
|
||||
web.post(self.path, self.handle),
|
||||
])
|
||||
|
||||
def to_request(self, request):
|
||||
raise RuntimeError("Not defined")
|
||||
|
||||
def from_response(self, response):
|
||||
raise RuntimeError("Not defined")
|
||||
|
||||
async def handle(self, request):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
print(request.path, "...")
|
||||
|
||||
try:
|
||||
ht = request.headers["Authorization"]
|
||||
tokens = ht.split(" ", 2)
|
||||
if tokens[0] != "Bearer":
|
||||
return web.HTTPUnauthorized()
|
||||
token = tokens[1]
|
||||
except:
|
||||
token = ""
|
||||
|
||||
if not self.auth.permitted(token, self.operation):
|
||||
return web.HTTPUnauthorized()
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
print(data)
|
||||
|
||||
q = self.sub.subscribe(id)
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.pub.send, id, self.to_request(data)
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Timeout")
|
||||
|
||||
print(resp)
|
||||
|
||||
if resp.error:
|
||||
print("Error")
|
||||
return web.json_response(
|
||||
{ "error": resp.error.message }
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
self.from_response(resp)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
finally:
|
||||
self.sub.unsubscribe(id)
|
||||
|
||||
|
||||
class MultiResponseServiceEndpoint(ServiceEndpoint):
|
||||
|
||||
async def handle(self, request):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
q = self.sub.subscribe(id)
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.pub.send, id, self.to_request(data)
|
||||
)
|
||||
|
||||
# Keeps looking at responses...
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
resp = await asyncio.to_thread(q.get, timeout=self.timeout)
|
||||
except Exception as e:
|
||||
raise RuntimeError("Timeout waiting for response")
|
||||
|
||||
if resp.error:
|
||||
return web.json_response(
|
||||
{ "error": resp.error.message }
|
||||
)
|
||||
|
||||
# Until from_response says we have a finished answer
|
||||
resp, fin = self.from_response(resp)
|
||||
|
||||
|
||||
if fin:
|
||||
return web.json_response(resp)
|
||||
|
||||
# Not finished, so loop round and continue
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
finally:
|
||||
self.sub.unsubscribe(id)
|
||||
60
trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py
Normal file
60
trustgraph-flow/trustgraph/gateway/graph_embeddings_load.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
from ... schema import GraphEmbeddings
|
||||
from ... schema import graph_embeddings_store_queue
|
||||
|
||||
from . publisher import Publisher
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import to_subgraph, to_value
|
||||
|
||||
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(
|
||||
self, pulsar_host, auth, path="/api/v1/load/graph-embeddings",
|
||||
):
|
||||
|
||||
super(GraphEmbeddingsLoadEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_host=pulsar_host
|
||||
|
||||
self.publisher = Publisher(
|
||||
self.pulsar_host, graph_embeddings_store_queue,
|
||||
schema=JsonSchema(GraphEmbeddings)
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
self.publisher.start()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
async for msg in ws:
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
break
|
||||
else:
|
||||
|
||||
data = msg.json()
|
||||
|
||||
elt = GraphEmbeddings(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
entity=to_value(data["entity"]),
|
||||
vectors=data["vectors"],
|
||||
)
|
||||
|
||||
self.publisher.send(None, elt)
|
||||
|
||||
|
||||
running.stop()
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
|
||||
from ... schema import GraphEmbeddings
|
||||
from ... schema import graph_embeddings_store_queue
|
||||
|
||||
from . subscriber import Subscriber
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import serialize_graph_embeddings
|
||||
|
||||
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(
|
||||
self, pulsar_host, auth, path="/api/v1/stream/graph-embeddings"
|
||||
):
|
||||
|
||||
super(GraphEmbeddingsStreamEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_host=pulsar_host
|
||||
|
||||
self.subscriber = Subscriber(
|
||||
self.pulsar_host, graph_embeddings_store_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
schema=JsonSchema(GraphEmbeddings)
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
self.subscriber.start()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
q = self.subscriber.subscribe_all(id)
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_graph_embeddings(resp))
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
self.subscriber.unsubscribe_all(id)
|
||||
|
||||
running.stop()
|
||||
|
||||
31
trustgraph-flow/trustgraph/gateway/graph_rag.py
Normal file
31
trustgraph-flow/trustgraph/gateway/graph_rag.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
|
||||
from ... schema import GraphRagQuery, GraphRagResponse
|
||||
from ... schema import graph_rag_request_queue
|
||||
from ... schema import graph_rag_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class GraphRagEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(GraphRagEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
request_queue=graph_rag_request_queue,
|
||||
response_queue=graph_rag_response_queue,
|
||||
request_schema=GraphRagQuery,
|
||||
response_schema=GraphRagResponse,
|
||||
endpoint_path="/api/v1/graph-rag",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
return GraphRagQuery(
|
||||
query=body["query"],
|
||||
user=body.get("user", "trustgraph"),
|
||||
collection=body.get("collection", "default"),
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
return { "response": message.response }
|
||||
|
||||
30
trustgraph-flow/trustgraph/gateway/internet_search.py
Normal file
30
trustgraph-flow/trustgraph/gateway/internet_search.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
|
||||
from ... schema import LookupRequest, LookupResponse
|
||||
from ... schema import internet_search_request_queue
|
||||
from ... schema import internet_search_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class InternetSearchEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(InternetSearchEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
request_queue=internet_search_request_queue,
|
||||
response_queue=internet_search_response_queue,
|
||||
request_schema=LookupRequest,
|
||||
response_schema=LookupResponse,
|
||||
endpoint_path="/api/v1/internet-search",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
return LookupRequest(
|
||||
term=body["term"],
|
||||
kind=body.get("kind", None),
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
return { "text": message.text }
|
||||
|
||||
42
trustgraph-flow/trustgraph/gateway/prompt.py
Normal file
42
trustgraph-flow/trustgraph/gateway/prompt.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
|
||||
import json
|
||||
|
||||
from ... schema import PromptRequest, PromptResponse
|
||||
from ... schema import prompt_request_queue
|
||||
from ... schema import prompt_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class PromptEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(PromptEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
request_queue=prompt_request_queue,
|
||||
response_queue=prompt_response_queue,
|
||||
request_schema=PromptRequest,
|
||||
response_schema=PromptResponse,
|
||||
endpoint_path="/api/v1/prompt",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
return PromptRequest(
|
||||
id=body["id"],
|
||||
terms={
|
||||
k: json.dumps(v)
|
||||
for k, v in body["variables"].items()
|
||||
}
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
if message.object:
|
||||
return {
|
||||
"object": message.object
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"text": message.text
|
||||
}
|
||||
|
||||
53
trustgraph-flow/trustgraph/gateway/publisher.py
Normal file
53
trustgraph-flow/trustgraph/gateway/publisher.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
|
||||
import queue
|
||||
import time
|
||||
import pulsar
|
||||
import threading
|
||||
|
||||
class Publisher:
|
||||
|
||||
def __init__(self, pulsar_host, topic, schema=None, max_size=10,
|
||||
chunking_enabled=False):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.schema = schema
|
||||
self.q = queue.Queue(maxsize=max_size)
|
||||
self.chunking_enabled = chunking_enabled
|
||||
|
||||
def start(self):
|
||||
self.task = threading.Thread(target=self.run)
|
||||
self.task.start()
|
||||
|
||||
def run(self):
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
|
||||
client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
)
|
||||
|
||||
producer = client.create_producer(
|
||||
topic=self.topic,
|
||||
schema=self.schema,
|
||||
chunking_enabled=self.chunking_enabled,
|
||||
)
|
||||
|
||||
while True:
|
||||
|
||||
id, item = self.q.get()
|
||||
|
||||
if id:
|
||||
producer.send(item, { "id": id })
|
||||
else:
|
||||
producer.send(item)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
# If handler drops out, sleep a retry
|
||||
time.sleep(2)
|
||||
|
||||
def send(self, id, msg):
|
||||
self.q.put((id, msg))
|
||||
5
trustgraph-flow/trustgraph/gateway/running.py
Normal file
5
trustgraph-flow/trustgraph/gateway/running.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
|
||||
class Running:
|
||||
def __init__(self): self.running = True
|
||||
def get(self): return self.running
|
||||
def stop(self): self.running = False
|
||||
57
trustgraph-flow/trustgraph/gateway/serialize.py
Normal file
57
trustgraph-flow/trustgraph/gateway/serialize.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from ... schema import Value, Triple
|
||||
|
||||
def to_value(x):
|
||||
return Value(value=x["v"], is_uri=x["e"])
|
||||
|
||||
def to_subgraph(x):
|
||||
return [
|
||||
Triple(
|
||||
s=to_value(t["s"]),
|
||||
p=to_value(t["p"]),
|
||||
o=to_value(t["o"])
|
||||
)
|
||||
for t in x
|
||||
]
|
||||
|
||||
def serialize_value(v):
|
||||
return {
|
||||
"v": v.value,
|
||||
"e": v.is_uri,
|
||||
}
|
||||
|
||||
def serialize_triple(t):
|
||||
return {
|
||||
"s": serialize_value(t.s),
|
||||
"p": serialize_value(t.p),
|
||||
"o": serialize_value(t.o)
|
||||
}
|
||||
|
||||
def serialize_subgraph(sg):
|
||||
return [
|
||||
serialize_triple(t)
|
||||
for t in sg
|
||||
]
|
||||
|
||||
def serialize_triples(message):
|
||||
return {
|
||||
"metadata": {
|
||||
"id": message.metadata.id,
|
||||
"metadata": serialize_subgraph(message.metadata.metadata),
|
||||
"user": message.metadata.user,
|
||||
"collection": message.metadata.collection,
|
||||
},
|
||||
"triples": serialize_subgraph(message.triples),
|
||||
}
|
||||
|
||||
def serialize_graph_embeddings(message):
|
||||
return {
|
||||
"metadata": {
|
||||
"id": message.metadata.id,
|
||||
"metadata": serialize_subgraph(message.metadata.metadata),
|
||||
"user": message.metadata.user,
|
||||
"collection": message.metadata.collection,
|
||||
},
|
||||
"vectors": message.vectors,
|
||||
"entity": serialize_value(message.entity),
|
||||
}
|
||||
|
||||
318
trustgraph-flow/trustgraph/gateway/service.py
Executable file
318
trustgraph-flow/trustgraph/gateway/service.py
Executable file
|
|
@ -0,0 +1,318 @@
|
|||
"""
|
||||
API gateway. Offers HTTP services which are translated to interaction on the
|
||||
Pulsar bus.
|
||||
"""
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
# FIXME: Subscribes to Pulsar unnecessarily, should only do it when there
|
||||
# are active listeners
|
||||
|
||||
# FIXME: Connection errors in publishers / subscribers cause those threads
|
||||
# to fail and are not failed or retried
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
from aiohttp import web
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
|
||||
import pulsar
|
||||
from pulsar.schema import JsonSchema
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from ... log_level import LogLevel
|
||||
|
||||
from ... schema import Metadata, Document, TextDocument
|
||||
from ... schema import document_ingest_queue, text_ingest_queue
|
||||
|
||||
from . serialize import to_subgraph
|
||||
from . running import Running
|
||||
from . publisher import Publisher
|
||||
from . subscriber import Subscriber
|
||||
from . endpoint import ServiceEndpoint, MultiResponseServiceEndpoint
|
||||
from . text_completion import TextCompletionEndpoint
|
||||
from . prompt import PromptEndpoint
|
||||
from . graph_rag import GraphRagEndpoint
|
||||
from . triples_query import TriplesQueryEndpoint
|
||||
from . embeddings import EmbeddingsEndpoint
|
||||
from . encyclopedia import EncyclopediaEndpoint
|
||||
from . agent import AgentEndpoint
|
||||
from . dbpedia import DbpediaEndpoint
|
||||
from . internet_search import InternetSearchEndpoint
|
||||
from . triples_stream import TriplesStreamEndpoint
|
||||
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
|
||||
from . triples_load import TriplesLoadEndpoint
|
||||
from . graph_embeddings_load import GraphEmbeddingsLoadEndpoint
|
||||
from . auth import Authenticator
|
||||
|
||||
logger = logging.getLogger("api")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
default_pulsar_host = os.getenv("PULSAR_HOST", "pulsar://pulsar:6650")
|
||||
default_timeout = 600
|
||||
default_port = 8088
|
||||
default_api_token = os.getenv("GATEWAY_SECRET", "")
|
||||
|
||||
class Api:
|
||||
|
||||
def __init__(self, **config):
|
||||
|
||||
self.app = web.Application(
|
||||
middlewares=[],
|
||||
client_max_size=256 * 1024 * 1024
|
||||
)
|
||||
|
||||
self.port = int(config.get("port", default_port))
|
||||
self.timeout = int(config.get("timeout", default_timeout))
|
||||
self.pulsar_host = config.get("pulsar_host", default_pulsar_host)
|
||||
|
||||
api_token = config.get("api_token", default_api_token)
|
||||
|
||||
# Token not set, or token equal empty string means no auth
|
||||
if api_token:
|
||||
self.auth = Authenticator(token=api_token)
|
||||
else:
|
||||
self.auth = Authenticator(allow_all=True)
|
||||
|
||||
self.endpoints = [
|
||||
TextCompletionEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
PromptEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
GraphRagEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
TriplesQueryEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
EmbeddingsEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
AgentEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
EncyclopediaEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
DbpediaEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
InternetSearchEndpoint(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
TriplesStreamEndpoint(
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
GraphEmbeddingsStreamEndpoint(
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
TriplesLoadEndpoint(
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
GraphEmbeddingsLoadEndpoint(
|
||||
pulsar_host=self.pulsar_host,
|
||||
auth = self.auth,
|
||||
),
|
||||
]
|
||||
|
||||
self.document_out = Publisher(
|
||||
self.pulsar_host, document_ingest_queue,
|
||||
schema=JsonSchema(Document),
|
||||
chunking_enabled=True,
|
||||
)
|
||||
|
||||
self.text_out = Publisher(
|
||||
self.pulsar_host, text_ingest_queue,
|
||||
schema=JsonSchema(TextDocument),
|
||||
chunking_enabled=True,
|
||||
)
|
||||
|
||||
for ep in self.endpoints:
|
||||
ep.add_routes(self.app)
|
||||
|
||||
self.app.add_routes([
|
||||
web.post("/api/v1/load/document", self.load_document),
|
||||
web.post("/api/v1/load/text", self.load_text),
|
||||
])
|
||||
|
||||
async def load_document(self, request):
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
if "metadata" in data:
|
||||
metadata = to_subgraph(data["metadata"])
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
# Doing a base64 decode/encode here to make sure the
|
||||
# content is valid base64
|
||||
doc = base64.b64decode(data["data"])
|
||||
|
||||
resp = await asyncio.to_thread(
|
||||
self.document_out.send,
|
||||
None,
|
||||
Document(
|
||||
metadata=Metadata(
|
||||
id=data.get("id"),
|
||||
metadata=metadata,
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
),
|
||||
data=base64.b64encode(doc).decode("utf-8")
|
||||
)
|
||||
)
|
||||
|
||||
print("Document loaded.")
|
||||
|
||||
return web.json_response(
|
||||
{ }
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
async def load_text(self, request):
|
||||
|
||||
try:
|
||||
|
||||
data = await request.json()
|
||||
|
||||
if "metadata" in data:
|
||||
metadata = to_subgraph(data["metadata"])
|
||||
else:
|
||||
metadata = []
|
||||
|
||||
if "charset" in data:
|
||||
charset = data["charset"]
|
||||
else:
|
||||
charset = "utf-8"
|
||||
|
||||
# Text is base64 encoded
|
||||
text = base64.b64decode(data["text"]).decode(charset)
|
||||
|
||||
resp = await asyncio.to_thread(
|
||||
self.text_out.send,
|
||||
None,
|
||||
TextDocument(
|
||||
metadata=Metadata(
|
||||
id=data.get("id"),
|
||||
metadata=metadata,
|
||||
user=data.get("user", "trustgraph"),
|
||||
collection=data.get("collection", "default"),
|
||||
),
|
||||
text=text,
|
||||
)
|
||||
)
|
||||
|
||||
print("Text document loaded.")
|
||||
|
||||
return web.json_response(
|
||||
{ }
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
return web.json_response(
|
||||
{ "error": str(e) }
|
||||
)
|
||||
|
||||
async def app_factory(self):
|
||||
|
||||
for ep in self.endpoints:
|
||||
await ep.start()
|
||||
|
||||
self.document_out.start()
|
||||
self.text_out.start()
|
||||
|
||||
return self.app
|
||||
|
||||
def run(self):
|
||||
web.run_app(self.app_factory(), port=self.port)
|
||||
|
||||
def run():
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="api-gateway",
|
||||
description=__doc__
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-p', '--pulsar-host',
|
||||
default=default_pulsar_host,
|
||||
help=f'Pulsar host (default: {default_pulsar_host})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--port',
|
||||
type=int,
|
||||
default=default_port,
|
||||
help=f'Port number to listen on (default: {default_port})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--timeout',
|
||||
type=int,
|
||||
default=default_timeout,
|
||||
help=f'API request timeout in seconds (default: {default_timeout})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--api-token',
|
||||
default=default_api_token,
|
||||
help=f'Secret API token (default: no auth)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-l', '--log-level',
|
||||
type=LogLevel,
|
||||
default=LogLevel.INFO,
|
||||
choices=list(LogLevel),
|
||||
help=f'Output queue (default: info)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--metrics',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help=f'Metrics enabled (default: true)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-P', '--metrics-port',
|
||||
type=int,
|
||||
default=8000,
|
||||
help=f'Prometheus metrics port (default: 8000)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = vars(args)
|
||||
|
||||
if args["metrics"]:
|
||||
start_http_server(args["metrics_port"])
|
||||
|
||||
a = Api(**args)
|
||||
a.run()
|
||||
|
||||
84
trustgraph-flow/trustgraph/gateway/socket.py
Normal file
84
trustgraph-flow/trustgraph/gateway/socket.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
|
||||
import asyncio
|
||||
from aiohttp import web, WSMsgType
|
||||
import logging
|
||||
|
||||
from . running import Running
|
||||
|
||||
logger = logging.getLogger("socket")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class SocketEndpoint:
|
||||
|
||||
def __init__(
|
||||
self, endpoint_path, auth,
|
||||
):
|
||||
|
||||
self.path = endpoint_path
|
||||
self.auth = auth
|
||||
self.operation = "socket"
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
async for msg in ws:
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
break
|
||||
else:
|
||||
# Ignore incoming messages
|
||||
pass
|
||||
|
||||
running.stop()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except TimeoutError:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
|
||||
async def handle(self, request):
|
||||
|
||||
try:
|
||||
token = request.query['token']
|
||||
except:
|
||||
token = ""
|
||||
|
||||
if not self.auth.permitted(token, self.operation):
|
||||
return web.HTTPUnauthorized()
|
||||
|
||||
running = Running()
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
task = asyncio.create_task(self.async_thread(ws, running))
|
||||
|
||||
try:
|
||||
|
||||
await self.listener(ws, running)
|
||||
|
||||
except Exception as e:
|
||||
print(e, flush=True)
|
||||
|
||||
running.stop()
|
||||
|
||||
await ws.close()
|
||||
|
||||
await task
|
||||
|
||||
return ws
|
||||
|
||||
async def start(self):
|
||||
pass
|
||||
|
||||
def add_routes(self, app):
|
||||
|
||||
app.add_routes([
|
||||
web.get(self.path, self.handle),
|
||||
])
|
||||
|
||||
109
trustgraph-flow/trustgraph/gateway/subscriber.py
Normal file
109
trustgraph-flow/trustgraph/gateway/subscriber.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
|
||||
import queue
|
||||
import pulsar
|
||||
import threading
|
||||
import time
|
||||
|
||||
class Subscriber:
|
||||
|
||||
def __init__(self, pulsar_host, topic, subscription, consumer_name,
|
||||
schema=None, max_size=100):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.topic = topic
|
||||
self.subscription = subscription
|
||||
self.consumer_name = consumer_name
|
||||
self.schema = schema
|
||||
self.q = {}
|
||||
self.full = {}
|
||||
self.max_size = max_size
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def start(self):
|
||||
self.task = threading.Thread(target=self.run)
|
||||
self.task.start()
|
||||
|
||||
def run(self):
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
|
||||
client = pulsar.Client(
|
||||
self.pulsar_host,
|
||||
)
|
||||
|
||||
consumer = client.subscribe(
|
||||
topic=self.topic,
|
||||
subscription_name=self.subscription,
|
||||
consumer_name=self.consumer_name,
|
||||
schema=self.schema,
|
||||
)
|
||||
|
||||
while True:
|
||||
|
||||
msg = consumer.receive()
|
||||
|
||||
# Acknowledge successful reception of the message
|
||||
consumer.acknowledge(msg)
|
||||
|
||||
try:
|
||||
id = msg.properties()["id"]
|
||||
except:
|
||||
id = None
|
||||
|
||||
value = msg.value()
|
||||
|
||||
with self.lock:
|
||||
|
||||
if id in self.q:
|
||||
try:
|
||||
self.q[id].put(value, timeout=0.5)
|
||||
except:
|
||||
pass
|
||||
|
||||
for q in self.full.values():
|
||||
try:
|
||||
q.put(value, timeout=0.5)
|
||||
except:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
# If handler drops out, sleep a retry
|
||||
time.sleep(2)
|
||||
|
||||
def subscribe(self, id):
|
||||
|
||||
with self.lock:
|
||||
|
||||
q = queue.Queue(maxsize=self.max_size)
|
||||
self.q[id] = q
|
||||
|
||||
return q
|
||||
|
||||
def unsubscribe(self, id):
|
||||
|
||||
with self.lock:
|
||||
|
||||
if id in self.q:
|
||||
# self.q[id].shutdown(immediate=True)
|
||||
del self.q[id]
|
||||
|
||||
def subscribe_all(self, id):
|
||||
|
||||
with self.lock:
|
||||
|
||||
q = queue.Queue(maxsize=self.max_size)
|
||||
self.full[id] = q
|
||||
|
||||
return q
|
||||
|
||||
def unsubscribe_all(self, id):
|
||||
|
||||
with self.lock:
|
||||
|
||||
if id in self.full:
|
||||
# self.full[id].shutdown(immediate=True)
|
||||
del self.full[id]
|
||||
|
||||
29
trustgraph-flow/trustgraph/gateway/text_completion.py
Normal file
29
trustgraph-flow/trustgraph/gateway/text_completion.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
|
||||
from ... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from ... schema import text_completion_request_queue
|
||||
from ... schema import text_completion_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
|
||||
class TextCompletionEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(TextCompletionEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
request_queue=text_completion_request_queue,
|
||||
response_queue=text_completion_response_queue,
|
||||
request_schema=TextCompletionRequest,
|
||||
response_schema=TextCompletionResponse,
|
||||
endpoint_path="/api/v1/text-completion",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
return TextCompletionRequest(
|
||||
system=body["system"],
|
||||
prompt=body["prompt"]
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
return { "response": message.response }
|
||||
57
trustgraph-flow/trustgraph/gateway/triples_load.py
Normal file
57
trustgraph-flow/trustgraph/gateway/triples_load.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
|
||||
import asyncio
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
from aiohttp import WSMsgType
|
||||
|
||||
from ... schema import Metadata
|
||||
from ... schema import Triples
|
||||
from ... schema import triples_store_queue
|
||||
|
||||
from . publisher import Publisher
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import to_subgraph
|
||||
|
||||
class TriplesLoadEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(self, pulsar_host, auth, path="/api/v1/load/triples"):
|
||||
|
||||
super(TriplesLoadEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_host=pulsar_host
|
||||
|
||||
self.publisher = Publisher(
|
||||
self.pulsar_host, triples_store_queue,
|
||||
schema=JsonSchema(Triples)
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
self.publisher.start()
|
||||
|
||||
async def listener(self, ws, running):
|
||||
|
||||
async for msg in ws:
|
||||
# On error, finish
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
break
|
||||
else:
|
||||
|
||||
data = msg.json()
|
||||
|
||||
elt = Triples(
|
||||
metadata=Metadata(
|
||||
id=data["metadata"]["id"],
|
||||
metadata=to_subgraph(data["metadata"]["metadata"]),
|
||||
user=data["metadata"]["user"],
|
||||
collection=data["metadata"]["collection"],
|
||||
),
|
||||
triples=to_subgraph(data["triples"]),
|
||||
)
|
||||
|
||||
self.publisher.send(None, elt)
|
||||
|
||||
|
||||
running.stop()
|
||||
54
trustgraph-flow/trustgraph/gateway/triples_query.py
Normal file
54
trustgraph-flow/trustgraph/gateway/triples_query.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
|
||||
from ... schema import TriplesQueryRequest, TriplesQueryResponse, Triples
|
||||
from ... schema import triples_request_queue
|
||||
from ... schema import triples_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . serialize import to_value, serialize_subgraph
|
||||
|
||||
class TriplesQueryEndpoint(ServiceEndpoint):
|
||||
def __init__(self, pulsar_host, timeout, auth):
|
||||
|
||||
super(TriplesQueryEndpoint, self).__init__(
|
||||
pulsar_host=pulsar_host,
|
||||
request_queue=triples_request_queue,
|
||||
response_queue=triples_response_queue,
|
||||
request_schema=TriplesQueryRequest,
|
||||
response_schema=TriplesQueryResponse,
|
||||
endpoint_path="/api/v1/triples-query",
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
def to_request(self, body):
|
||||
|
||||
if "s" in body:
|
||||
s = to_value(body["s"])
|
||||
else:
|
||||
s = None
|
||||
|
||||
if "p" in body:
|
||||
p = to_value(body["p"])
|
||||
else:
|
||||
p = None
|
||||
|
||||
if "o" in body:
|
||||
o = to_value(body["o"])
|
||||
else:
|
||||
o = None
|
||||
|
||||
limit = int(body.get("limit", 10000))
|
||||
|
||||
return TriplesQueryRequest(
|
||||
s = s, p = p, o = o,
|
||||
limit = limit,
|
||||
user = body.get("user", "trustgraph"),
|
||||
collection = body.get("collection", "default"),
|
||||
)
|
||||
|
||||
def from_response(self, message):
|
||||
print(message)
|
||||
return {
|
||||
"response": serialize_subgraph(message.triples)
|
||||
}
|
||||
|
||||
55
trustgraph-flow/trustgraph/gateway/triples_stream.py
Normal file
55
trustgraph-flow/trustgraph/gateway/triples_stream.py
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
|
||||
import asyncio
|
||||
import queue
|
||||
from pulsar.schema import JsonSchema
|
||||
import uuid
|
||||
|
||||
from ... schema import Triples
|
||||
from ... schema import triples_store_queue
|
||||
|
||||
from . subscriber import Subscriber
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import serialize_triples
|
||||
|
||||
class TriplesStreamEndpoint(SocketEndpoint):
|
||||
|
||||
def __init__(self, pulsar_host, auth, path="/api/v1/stream/triples"):
|
||||
|
||||
super(TriplesStreamEndpoint, self).__init__(
|
||||
endpoint_path=path, auth=auth,
|
||||
)
|
||||
|
||||
self.pulsar_host=pulsar_host
|
||||
|
||||
self.subscriber = Subscriber(
|
||||
self.pulsar_host, triples_store_queue,
|
||||
"api-gateway", "api-gateway",
|
||||
schema=JsonSchema(Triples)
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
|
||||
self.subscriber.start()
|
||||
|
||||
async def async_thread(self, ws, running):
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
q = self.subscriber.subscribe_all(id)
|
||||
|
||||
while running.get():
|
||||
try:
|
||||
resp = await asyncio.to_thread(q.get, timeout=0.5)
|
||||
await ws.send_json(serialize_triples(resp))
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"Exception: {str(e)}", flush=True)
|
||||
break
|
||||
|
||||
self.subscriber.unsubscribe_all(id)
|
||||
|
||||
running.stop()
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue