mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 02:23:44 +02:00
Librarian (#304)
This commit is contained in:
parent
e99c0ac238
commit
a0bf2362f6
32 changed files with 922 additions and 66 deletions
|
|
@ -7,8 +7,8 @@ from aiohttp import WSMsgType
|
|||
from .. schema import Metadata
|
||||
from .. schema import DocumentEmbeddings, ChunkEmbeddings
|
||||
from .. schema import document_embeddings_store_queue
|
||||
from .. base import Publisher
|
||||
|
||||
from . publisher import Publisher
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import to_subgraph
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import uuid
|
|||
|
||||
from .. schema import DocumentEmbeddings
|
||||
from .. schema import document_embeddings_store_queue
|
||||
from .. base import Subscriber
|
||||
|
||||
from . subscriber import Subscriber
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import serialize_document_embeddings
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ from aiohttp import web
|
|||
import uuid
|
||||
import logging
|
||||
|
||||
from . publisher import Publisher
|
||||
from . subscriber import Subscriber
|
||||
from .. base import Publisher
|
||||
from .. base import Subscriber
|
||||
|
||||
logger = logging.getLogger("endpoint")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from aiohttp import WSMsgType
|
|||
from .. schema import Metadata
|
||||
from .. schema import GraphEmbeddings, EntityEmbeddings
|
||||
from .. schema import graph_embeddings_store_queue
|
||||
from .. base import Publisher
|
||||
|
||||
from . publisher import Publisher
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import to_subgraph, to_value
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import uuid
|
|||
|
||||
from .. schema import GraphEmbeddings
|
||||
from .. schema import graph_embeddings_store_queue
|
||||
from .. base import Subscriber
|
||||
|
||||
from . subscriber import Subscriber
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import serialize_graph_embeddings
|
||||
|
||||
|
|
|
|||
57
trustgraph-flow/trustgraph/gateway/librarian.py
Normal file
57
trustgraph-flow/trustgraph/gateway/librarian.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
|
||||
from .. schema import LibrarianRequest, LibrarianResponse, Triples
|
||||
from .. schema import librarian_request_queue
|
||||
from .. schema import librarian_response_queue
|
||||
|
||||
from . endpoint import ServiceEndpoint
|
||||
from . requestor import ServiceRequestor
|
||||
from . serialize import serialize_document_package, serialize_document_info
|
||||
from . serialize import to_document_package, to_document_info, to_criteria
|
||||
|
||||
class LibrarianRequestor(ServiceRequestor):
    """Gateway requestor bridging JSON request bodies to the librarian service.

    Converts incoming dicts to LibrarianRequest Pulsar messages and
    LibrarianResponse messages back to JSON-serializable dicts.
    """

    def __init__(self, pulsar_host, timeout, auth):
        # NOTE(review): `auth` is accepted but neither stored nor forwarded
        # to the base class — confirm auth is enforced at the endpoint layer.

        super(LibrarianRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=librarian_request_queue,
            response_queue=librarian_response_queue,
            request_schema=LibrarianRequest,
            response_schema=LibrarianResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        """Build a LibrarianRequest from a JSON request body.

        Optional "document" and "criteria" entries are converted with the
        serialize helpers; absent fields become None.
        """

        if "document" in body:
            dp = to_document_package(body["document"])
        else:
            dp = None

        if "criteria" in body:
            criteria = to_criteria(body["criteria"])
        else:
            criteria = None

        # NOTE(review): `limit` is computed but never passed to
        # LibrarianRequest below — looks unintentional; confirm whether the
        # schema carries a limit field.
        limit = int(body.get("limit", 10000))

        return LibrarianRequest(
            operation = body.get("operation", None),
            id = body.get("id", None),
            document = dp,
            user = body.get("user", None),
            collection = body.get("collection", None),
            criteria = criteria,
        )

    def from_response(self, message):
        """Convert a LibrarianResponse into (dict, final) for the gateway.

        Returns True as the second element: librarian responses are always
        a single, final message.
        """

        response = {}

        if message.document:
            response["document"] = serialize_document_package(message.document)

        if message.info:
            response["info"] = serialize_document_info(message.info)

        return response, True
|
||||
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
|
||||
import queue
|
||||
import time
|
||||
import pulsar
|
||||
import threading
|
||||
|
||||
class Publisher:
    """Background publisher relaying queued messages to a Pulsar topic.

    Messages are submitted with send() into a bounded queue and drained by
    a worker thread which owns the Pulsar client/producer and reconnects
    with a delay after any failure.
    """

    def __init__(self, pulsar_host, topic, schema=None, max_size=10,
                 chunking_enabled=True, listener=None):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.schema = schema
        # Bounded queue: send() blocks when the worker falls behind,
        # providing back-pressure instead of unbounded memory growth.
        self.q = queue.Queue(maxsize=max_size)
        self.chunking_enabled = chunking_enabled
        self.listener_name = listener

    def start(self):
        """Start the worker thread that publishes queued messages."""
        self.task = threading.Thread(target=self.run)
        self.task.start()

    def run(self):
        """Worker loop: (re)connect to Pulsar and publish forever."""

        while True:

            client = None

            try:

                client = pulsar.Client(
                    self.pulsar_host, listener_name=self.listener_name
                )

                producer = client.create_producer(
                    topic=self.topic,
                    schema=self.schema,
                    chunking_enabled=self.chunking_enabled,
                )

                while True:

                    # Renamed from `id` to avoid shadowing the builtin.
                    msg_id, item = self.q.get()

                    # Attach the correlation id as a message property so
                    # subscribers can route by it, when one was supplied.
                    if msg_id:
                        producer.send(item, { "id": msg_id })
                    else:
                        producer.send(item)

            except Exception as e:
                print("Exception:", e, flush=True)

            finally:
                # Fix: the original leaked the Pulsar client on failure;
                # release its connections before reconnecting.
                if client is not None:
                    try:
                        client.close()
                    except Exception:
                        pass

            # If handler drops out, sleep a retry
            time.sleep(2)

    def send(self, id, msg):
        """Queue (id, msg) for publication; blocks while the queue is full."""
        self.q.put((id, msg))
|
||||
|
|
@ -4,8 +4,8 @@ from pulsar.schema import JsonSchema
|
|||
import uuid
|
||||
import logging
|
||||
|
||||
from . publisher import Publisher
|
||||
from . subscriber import Subscriber
|
||||
from .. base import Publisher
|
||||
from .. base import Subscriber
|
||||
|
||||
logger = logging.getLogger("requestor")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
@ -68,7 +68,10 @@ class ServiceRequestor:
|
|||
raise RuntimeError("Timeout")
|
||||
|
||||
if resp.error:
|
||||
err = { "error": resp.error.message }
|
||||
err = { "error": {
|
||||
"type": resp.error.type,
|
||||
"message": resp.error.message,
|
||||
} }
|
||||
if responder:
|
||||
await responder(err, True)
|
||||
return err
|
||||
|
|
@ -87,7 +90,10 @@ class ServiceRequestor:
|
|||
|
||||
logging.error(f"Exception: {e}")
|
||||
|
||||
err = { "error": str(e) }
|
||||
err = { "error": {
|
||||
"type": "gateway-error",
|
||||
"message": str(e),
|
||||
} }
|
||||
if responder:
|
||||
await responder(err, True)
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from pulsar.schema import JsonSchema
|
|||
import uuid
|
||||
import logging
|
||||
|
||||
from . publisher import Publisher
|
||||
from .. base import Publisher
|
||||
|
||||
logger = logging.getLogger("sender")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
from .. schema import Value, Triple
|
||||
|
||||
import base64

from .. schema import Value, Triple, DocumentPackage, DocumentInfo
from .. schema import Criteria
|
||||
|
||||
def to_value(x):
    """Convert a JSON value dict to a schema Value.

    "v" carries the value text; "e" flags whether it is a URI.
    """
    return Value(value=x["v"], is_uri=x["e"])
|
||||
|
|
@ -77,3 +80,69 @@ def serialize_document_embeddings(message):
|
|||
],
|
||||
}
|
||||
|
||||
def serialize_document_package(message):
    """Convert a DocumentPackage message to a JSON-serializable dict.

    Only fields present on the message are emitted; the document payload
    is base64-encoded for transport.
    """

    ret = {}

    if message.metadata:
        # Fix: the original read message.metdata (typo), which raised
        # AttributeError whenever metadata was present.
        ret["metadata"] = serialize_subgraph(message.metadata)

    if message.document:
        blob = base64.b64encode(
            message.document.encode("utf-8")
        ).decode("utf-8")
        ret["document"] = blob

    if message.kind:
        ret["kind"] = message.kind

    if message.user:
        ret["user"] = message.user

    if message.collection:
        ret["collection"] = message.collection

    return ret
|
||||
|
||||
def serialize_document_info(message):
    """Convert a DocumentInfo message to a JSON-serializable dict.

    Only fields present on the message are emitted.
    """

    ret = {}

    if message.metadata:
        # Fix: the original read message.metdata (typo), which raised
        # AttributeError whenever metadata was present.
        ret["metadata"] = serialize_subgraph(message.metadata)

    if message.kind:
        ret["kind"] = message.kind

    if message.user:
        ret["user"] = message.user

    if message.collection:
        ret["collection"] = message.collection

    return ret
|
||||
|
||||
def to_document_package(x):
    """Build a DocumentPackage from its JSON dict form.

    The document payload arrives base64-encoded and is decoded to raw
    bytes; the optional descriptive fields default to None when absent.
    """

    raw = base64.b64decode(x["document"].encode("utf-8"))

    return DocumentPackage(
        metadata = to_subgraph(x["metadata"]),
        document = raw,
        kind = x.get("kind", None),
        user = x.get("user", None),
        collection = x.get("collection", None),
    )
|
||||
|
||||
def to_document_info(x):
    """Build a DocumentInfo from its JSON dict form.

    Optional descriptive fields default to None when absent.
    """

    optional = {
        key: x.get(key, None)
        for key in ("kind", "user", "collection")
    }

    return DocumentInfo(
        metadata = to_subgraph(x["metadata"]),
        **optional,
    )
|
||||
|
||||
def to_criteria(x):
    """Convert a list of JSON criterion dicts to schema Criteria objects.

    Each entry must carry "key", "value" and "operator"; a missing field
    raises KeyError.
    """
    # Fix: the original referenced `Critera` (typo), a name defined
    # nowhere in this module, so any request carrying criteria raised
    # NameError. NOTE(review): confirm the schema class is named Criteria
    # and accepts positional (key, value, operator) arguments.
    return [
        Criteria(v["key"], v["value"], v["operator"])
        for v in x
    ]
|
||||
|
|
|
|||
|
|
@ -26,8 +26,6 @@ from .. log_level import LogLevel
|
|||
|
||||
from . serialize import to_subgraph
|
||||
from . running import Running
|
||||
from . publisher import Publisher
|
||||
from . subscriber import Subscriber
|
||||
from . text_completion import TextCompletionRequestor
|
||||
from . prompt import PromptRequestor
|
||||
from . graph_rag import GraphRagRequestor
|
||||
|
|
@ -39,6 +37,7 @@ from . encyclopedia import EncyclopediaRequestor
|
|||
from . agent import AgentRequestor
|
||||
from . dbpedia import DbpediaRequestor
|
||||
from . internet_search import InternetSearchRequestor
|
||||
from . librarian import LibrarianRequestor
|
||||
from . triples_stream import TriplesStreamEndpoint
|
||||
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
|
||||
from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint
|
||||
|
|
@ -123,6 +122,10 @@ class Api:
|
|||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
"librarian": LibrarianRequestor(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
),
|
||||
"encyclopedia": EncyclopediaRequestor(
|
||||
pulsar_host=self.pulsar_host, timeout=self.timeout,
|
||||
auth = self.auth,
|
||||
|
|
@ -177,6 +180,10 @@ class Api:
|
|||
endpoint_path = "/api/v1/agent", auth=self.auth,
|
||||
requestor = self.services["agent"],
|
||||
),
|
||||
ServiceEndpoint(
|
||||
endpoint_path = "/api/v1/librarian", auth=self.auth,
|
||||
requestor = self.services["librarian"],
|
||||
),
|
||||
ServiceEndpoint(
|
||||
endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
|
||||
requestor = self.services["encyclopedia"],
|
||||
|
|
|
|||
|
|
@ -1,111 +0,0 @@
|
|||
|
||||
import queue
|
||||
import pulsar
|
||||
import threading
|
||||
import time
|
||||
|
||||
class Subscriber:
    """Background consumer fanning Pulsar messages out to local queues.

    Callers register bounded queues via subscribe(id) (messages whose "id"
    property matches) or subscribe_all(id) (every message); a worker
    thread owns the Pulsar client and feeds those queues, reconnecting
    with a delay after any failure.
    """

    def __init__(self, pulsar_host, topic, subscription, consumer_name,
                 schema=None, max_size=100, listener=None):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name
        self.schema = schema
        # id -> queue for id-targeted delivery
        self.q = {}
        # id -> queue receiving every message
        self.full = {}
        self.max_size = max_size
        # Guards q/full against concurrent (un)subscribe and delivery.
        self.lock = threading.Lock()
        self.listener_name = listener

    def start(self):
        """Start the worker thread that consumes from Pulsar."""
        self.task = threading.Thread(target=self.run)
        self.task.start()

    def run(self):
        """Worker loop: (re)connect, receive and dispatch forever."""

        while True:

            client = None

            try:

                client = pulsar.Client(
                    self.pulsar_host,
                    listener_name=self.listener_name,
                )

                consumer = client.subscribe(
                    topic=self.topic,
                    subscription_name=self.subscription,
                    consumer_name=self.consumer_name,
                    schema=self.schema,
                )

                while True:

                    msg = consumer.receive()

                    # Acknowledge successful reception of the message
                    consumer.acknowledge(msg)

                    # Fix: narrowed the bare except — only a missing "id"
                    # property is expected here.
                    try:
                        id = msg.properties()["id"]
                    except KeyError:
                        id = None

                    value = msg.value()

                    with self.lock:

                        if id in self.q:
                            try:
                                self.q[id].put(value, timeout=0.5)
                            except queue.Full:
                                # Fix: drop only when the queue is full
                                # rather than swallowing every exception.
                                pass

                        for q in self.full.values():
                            try:
                                q.put(value, timeout=0.5)
                            except queue.Full:
                                pass

            except Exception as e:
                print("Exception:", e, flush=True)

            finally:
                # Fix: the original leaked the Pulsar client on failure;
                # release its connections before reconnecting.
                if client is not None:
                    try:
                        client.close()
                    except Exception:
                        pass

            # If handler drops out, sleep a retry
            time.sleep(2)

    def subscribe(self, id):
        """Register and return a queue receiving messages tagged with id."""

        with self.lock:
            q = queue.Queue(maxsize=self.max_size)
            self.q[id] = q

        return q

    def unsubscribe(self, id):
        """Remove the id-targeted queue registered under id, if any."""

        with self.lock:
            if id in self.q:
                # self.q[id].shutdown(immediate=True)
                del self.q[id]

    def subscribe_all(self, id):
        """Register and return a queue receiving every message."""

        with self.lock:
            q = queue.Queue(maxsize=self.max_size)
            self.full[id] = q

        return q

    def unsubscribe_all(self, id):
        """Remove the receive-all queue registered under id, if any."""

        with self.lock:
            if id in self.full:
                # self.full[id].shutdown(immediate=True)
                del self.full[id]
|
||||
|
||||
|
|
@ -7,8 +7,8 @@ from aiohttp import WSMsgType
|
|||
from .. schema import Metadata
|
||||
from .. schema import Triples
|
||||
from .. schema import triples_store_queue
|
||||
from .. base import Publisher
|
||||
|
||||
from . publisher import Publisher
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import to_subgraph
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import uuid
|
|||
|
||||
from .. schema import Triples
|
||||
from .. schema import triples_store_queue
|
||||
from .. base import Subscriber
|
||||
|
||||
from . subscriber import Subscriber
|
||||
from . socket import SocketEndpoint
|
||||
from . serialize import serialize_triples
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue