Librarian (#304)

This commit is contained in:
cybermaggedon 2025-02-11 16:01:03 +00:00 committed by GitHub
parent e99c0ac238
commit a0bf2362f6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 922 additions and 66 deletions

View file

@ -0,0 +1,6 @@
#!/usr/bin/env python3

"""Launcher script for the librarian service."""

from trustgraph.librarian import run

# Fix: guard the entry point so importing this module does not start the
# service as a side effect (matches the package __main__ script's style).
if __name__ == '__main__':
    run()

View file

@ -49,6 +49,7 @@ setuptools.setup(
"langchain-community",
"langchain-core",
"langchain-text-splitters",
"minio",
"neo4j",
"ollama",
"openai",
@ -78,8 +79,8 @@ setuptools.setup(
"scripts/de-write-qdrant",
"scripts/document-embeddings",
"scripts/document-rag",
"scripts/embeddings-ollama",
"scripts/embeddings-fastembed",
"scripts/embeddings-ollama",
"scripts/ge-query-milvus",
"scripts/ge-query-pinecone",
"scripts/ge-query-qdrant",
@ -91,6 +92,7 @@ setuptools.setup(
"scripts/kg-extract-definitions",
"scripts/kg-extract-relationships",
"scripts/kg-extract-topics",
"scripts/librarian",
"scripts/metering",
"scripts/object-extract-row",
"scripts/oe-write-milvus",

View file

@ -7,8 +7,8 @@ from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import DocumentEmbeddings, ChunkEmbeddings
from .. schema import document_embeddings_store_queue
from .. base import Publisher
from . publisher import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph

View file

@ -6,8 +6,8 @@ import uuid
from .. schema import DocumentEmbeddings
from .. schema import document_embeddings_store_queue
from .. base import Subscriber
from . subscriber import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_document_embeddings

View file

@ -5,8 +5,8 @@ from aiohttp import web
import uuid
import logging
from . publisher import Publisher
from . subscriber import Subscriber
from .. base import Publisher
from .. base import Subscriber
logger = logging.getLogger("endpoint")
logger.setLevel(logging.INFO)

View file

@ -7,8 +7,8 @@ from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import GraphEmbeddings, EntityEmbeddings
from .. schema import graph_embeddings_store_queue
from .. base import Publisher
from . publisher import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph, to_value

View file

@ -6,8 +6,8 @@ import uuid
from .. schema import GraphEmbeddings
from .. schema import graph_embeddings_store_queue
from .. base import Subscriber
from . subscriber import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_graph_embeddings

View file

@ -0,0 +1,57 @@
from .. schema import LibrarianRequest, LibrarianResponse, Triples
from .. schema import librarian_request_queue
from .. schema import librarian_response_queue
from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
from . serialize import serialize_document_package, serialize_document_info
from . serialize import to_document_package, to_document_info, to_criteria
class LibrarianRequestor(ServiceRequestor):
    """
    Gateway-side requestor for the librarian service: converts decoded
    JSON request bodies into LibrarianRequest messages and librarian
    responses back into JSON-serializable dicts.
    """

    def __init__(self, pulsar_host, timeout, auth):
        # NOTE(review): auth is accepted for signature parity with the other
        # requestors but is not forwarded to the base class — confirm.
        super(LibrarianRequestor, self).__init__(
            pulsar_host=pulsar_host,
            request_queue=librarian_request_queue,
            response_queue=librarian_response_queue,
            request_schema=LibrarianRequest,
            response_schema=LibrarianResponse,
            timeout=timeout,
        )

    def to_request(self, body):
        """Map a request body dict onto a LibrarianRequest message."""

        dp = to_document_package(body["document"]) if "document" in body else None

        criteria = to_criteria(body["criteria"]) if "criteria" in body else None

        # NOTE(review): limit is parsed but never placed on the request —
        # confirm whether LibrarianRequest has a limit field it should fill.
        limit = int(body.get("limit", 10000))

        return LibrarianRequest(
            operation = body.get("operation", None),
            id = body.get("id", None),
            document = dp,
            user = body.get("user", None),
            collection = body.get("collection", None),
            criteria = criteria,
        )

    def from_response(self, message):
        """Serialize a LibrarianResponse; second element flags completion."""

        response = {}

        if message.document:
            response["document"] = serialize_document_package(message.document)

        if message.info:
            response["info"] = serialize_document_info(message.info)

        return response, True

View file

@ -1,54 +0,0 @@
import queue
import time
import pulsar
import threading
class Publisher:
    """
    Threaded Pulsar publisher.  Messages are queued locally with send();
    a background thread drains the queue to the topic, reconnecting on
    any failure.
    """

    def __init__(self, pulsar_host, topic, schema=None, max_size=10,
                 chunking_enabled=True, listener=None):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.schema = schema
        # Bounded queue: send() blocks when the publisher falls behind
        self.q = queue.Queue(maxsize=max_size)
        self.chunking_enabled = chunking_enabled
        self.listener_name = listener

    def start(self):
        """Start the background publishing thread."""
        self.task = threading.Thread(target=self.run)
        self.task.start()

    def run(self):
        """Publishing loop: (re)connect, then drain the queue forever."""
        while True:
            client = None
            try:
                client = pulsar.Client(
                    self.pulsar_host, listener_name=self.listener_name
                )
                producer = client.create_producer(
                    topic=self.topic,
                    schema=self.schema,
                    chunking_enabled=self.chunking_enabled,
                )
                while True:
                    msg_id, item = self.q.get()
                    # Correlation ID travels as a message property
                    if msg_id:
                        producer.send(item, { "id": msg_id })
                    else:
                        producer.send(item)
            except Exception as e:
                print("Exception:", e, flush=True)
            finally:
                # Fix: close the client before retrying so repeated
                # reconnects don't leak connections/threads.
                if client is not None:
                    try:
                        client.close()
                    except Exception:
                        pass
            # If handler drops out, sleep a retry
            time.sleep(2)

    def send(self, id, msg):
        """Queue (id, msg) for publication; blocks if the queue is full."""
        self.q.put((id, msg))

View file

@ -4,8 +4,8 @@ from pulsar.schema import JsonSchema
import uuid
import logging
from . publisher import Publisher
from . subscriber import Subscriber
from .. base import Publisher
from .. base import Subscriber
logger = logging.getLogger("requestor")
logger.setLevel(logging.INFO)
@ -68,7 +68,10 @@ class ServiceRequestor:
raise RuntimeError("Timeout")
if resp.error:
err = { "error": resp.error.message }
err = { "error": {
"type": resp.error.type,
"message": resp.error.message,
} }
if responder:
await responder(err, True)
return err
@ -87,7 +90,10 @@ class ServiceRequestor:
logging.error(f"Exception: {e}")
err = { "error": str(e) }
err = { "error": {
"type": "gateway-error",
"message": str(e),
} }
if responder:
await responder(err, True)
return err

View file

@ -6,7 +6,7 @@ from pulsar.schema import JsonSchema
import uuid
import logging
from . publisher import Publisher
from .. base import Publisher
logger = logging.getLogger("sender")
logger.setLevel(logging.INFO)

View file

@ -1,4 +1,7 @@
from .. schema import Value, Triple
import base64
from .. schema import Value, Triple, DocumentPackage, DocumentInfo
def to_value(x):
    """Convert a wire dict {"v": value, "e": is_uri} into a Value record."""
    val, is_uri = x["v"], x["e"]
    return Value(value=val, is_uri=is_uri)
@ -77,3 +80,69 @@ def serialize_document_embeddings(message):
],
}
def serialize_document_package(message):
    """
    Convert a DocumentPackage message into a JSON-serializable dict.
    The document body is base64-encoded for transport; falsy fields are
    omitted from the result.
    """

    ret = {}

    if message.metadata:
        # Fix: was message.metdata (typo) — raised AttributeError whenever
        # metadata was present.
        ret["metadata"] = serialize_subgraph(message.metadata)

    if message.document:
        # NOTE(review): assumes document is a str; to_document_package
        # produces bytes — confirm which type the schema carries.
        blob = base64.b64encode(
            message.document.encode("utf-8")
        ).decode("utf-8")
        ret["document"] = blob

    if message.kind:
        ret["kind"] = message.kind

    if message.user:
        ret["user"] = message.user

    if message.collection:
        ret["collection"] = message.collection

    return ret
def serialize_document_info(message):
    """
    Convert a DocumentInfo message (metadata without the blob) into a
    JSON-serializable dict; falsy fields are omitted.
    """

    ret = {}

    if message.metadata:
        # Fix: was message.metdata (typo) — raised AttributeError whenever
        # metadata was present.
        ret["metadata"] = serialize_subgraph(message.metadata)

    if message.kind:
        ret["kind"] = message.kind

    if message.user:
        ret["user"] = message.user

    if message.collection:
        ret["collection"] = message.collection

    return ret
def to_document_package(x):
    """Build a DocumentPackage record from a request body dict."""
    # The document payload arrives base64-encoded; decode to raw bytes.
    blob = base64.b64decode(x["document"].encode("utf-8"))
    meta = to_subgraph(x["metadata"])
    return DocumentPackage(
        metadata=meta,
        document=blob,
        kind=x.get("kind"),
        user=x.get("user"),
        collection=x.get("collection"),
    )
def to_document_info(x):
    """Build a DocumentInfo record (no blob) from a request body dict."""
    meta = to_subgraph(x["metadata"])
    return DocumentInfo(
        metadata=meta,
        kind=x.get("kind"),
        user=x.get("user"),
        collection=x.get("collection"),
    )
def to_criteria(x):
    """
    Convert a list of criteria dicts ({"key", "value", "operator"}) into
    Criteria records.

    Fix: the original referenced the undefined name "Critera" (typo),
    raising NameError on every call; Criteria was also never imported.
    """
    # NOTE(review): Criteria is assumed to live in the schema package
    # alongside DocumentPackage/DocumentInfo — confirm the import source
    # and whether it takes positional (key, value, operator) arguments.
    from .. schema import Criteria
    return [
        Criteria(v["key"], v["value"], v["operator"])
        for v in x
    ]

View file

@ -26,8 +26,6 @@ from .. log_level import LogLevel
from . serialize import to_subgraph
from . running import Running
from . publisher import Publisher
from . subscriber import Subscriber
from . text_completion import TextCompletionRequestor
from . prompt import PromptRequestor
from . graph_rag import GraphRagRequestor
@ -39,6 +37,7 @@ from . encyclopedia import EncyclopediaRequestor
from . agent import AgentRequestor
from . dbpedia import DbpediaRequestor
from . internet_search import InternetSearchRequestor
from . librarian import LibrarianRequestor
from . triples_stream import TriplesStreamEndpoint
from . graph_embeddings_stream import GraphEmbeddingsStreamEndpoint
from . document_embeddings_stream import DocumentEmbeddingsStreamEndpoint
@ -123,6 +122,10 @@ class Api:
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"librarian": LibrarianRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
),
"encyclopedia": EncyclopediaRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
auth = self.auth,
@ -177,6 +180,10 @@ class Api:
endpoint_path = "/api/v1/agent", auth=self.auth,
requestor = self.services["agent"],
),
ServiceEndpoint(
endpoint_path = "/api/v1/librarian", auth=self.auth,
requestor = self.services["librarian"],
),
ServiceEndpoint(
endpoint_path = "/api/v1/encyclopedia", auth=self.auth,
requestor = self.services["encyclopedia"],

View file

@ -1,111 +0,0 @@
import queue
import pulsar
import threading
import time
class Subscriber:
    """
    Threaded Pulsar subscriber.  A background thread receives messages and
    fans them out to per-ID queues (subscribe) and broadcast queues
    (subscribe_all), reconnecting on any failure.
    """

    def __init__(self, pulsar_host, topic, subscription, consumer_name,
                 schema=None, max_size=100, listener=None):
        self.pulsar_host = pulsar_host
        self.topic = topic
        self.subscription = subscription
        self.consumer_name = consumer_name
        self.schema = schema
        # Per-correlation-ID queues and broadcast ("all") queues
        self.q = {}
        self.full = {}
        self.max_size = max_size
        # Guards q/full against concurrent (un)subscription
        self.lock = threading.Lock()
        self.listener_name = listener

    def start(self):
        """Start the background receive thread."""
        self.task = threading.Thread(target=self.run)
        self.task.start()

    def run(self):
        """Receive loop: (re)connect, then fan out messages forever."""
        while True:
            client = None
            try:
                client = pulsar.Client(
                    self.pulsar_host,
                    listener_name=self.listener_name,
                )
                consumer = client.subscribe(
                    topic=self.topic,
                    subscription_name=self.subscription,
                    consumer_name=self.consumer_name,
                    schema=self.schema,
                )
                while True:
                    msg = consumer.receive()

                    # Acknowledge successful reception of the message
                    consumer.acknowledge(msg)

                    # Fix: narrow the bare except — a missing property is
                    # just an absent key.
                    msg_id = msg.properties().get("id")

                    value = msg.value()

                    with self.lock:
                        if msg_id in self.q:
                            try:
                                self.q[msg_id].put(value, timeout=0.5)
                            except queue.Full:
                                # Consumer not keeping up; drop the message
                                pass
                        for q in self.full.values():
                            try:
                                q.put(value, timeout=0.5)
                            except queue.Full:
                                pass
            except Exception as e:
                print("Exception:", e, flush=True)
            finally:
                # Fix: close the client before retrying so repeated
                # reconnects don't leak connections/threads.
                if client is not None:
                    try:
                        client.close()
                    except Exception:
                        pass
            # If handler drops out, sleep a retry
            time.sleep(2)

    def subscribe(self, id):
        """Register a queue receiving messages with the given property ID."""
        with self.lock:
            q = queue.Queue(maxsize=self.max_size)
            self.q[id] = q
        return q

    def unsubscribe(self, id):
        """Remove a per-ID queue; no-op if it was never registered."""
        with self.lock:
            if id in self.q:
                del self.q[id]

    def subscribe_all(self, id):
        """Register a queue receiving every message on the topic."""
        with self.lock:
            q = queue.Queue(maxsize=self.max_size)
            self.full[id] = q
        return q

    def unsubscribe_all(self, id):
        """Remove a broadcast queue; no-op if it was never registered."""
        with self.lock:
            if id in self.full:
                del self.full[id]

View file

@ -7,8 +7,8 @@ from aiohttp import WSMsgType
from .. schema import Metadata
from .. schema import Triples
from .. schema import triples_store_queue
from .. base import Publisher
from . publisher import Publisher
from . socket import SocketEndpoint
from . serialize import to_subgraph

View file

@ -6,8 +6,8 @@ import uuid
from .. schema import Triples
from .. schema import triples_store_queue
from .. base import Subscriber
from . subscriber import Subscriber
from . socket import SocketEndpoint
from . serialize import serialize_triples

View file

@ -0,0 +1,3 @@
from . service import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3

"""Module entry point for running the librarian service directly."""

from . service import run

if __name__ == "__main__":
    run()

View file

@ -0,0 +1,51 @@
from .. schema import LibrarianRequest, LibrarianResponse, Error
from .. knowledge import hash
from .. exceptions import RequestError
from minio import Minio
import time
import io
class BlobStore:
    """Minio-backed blob storage for library document bodies."""

    def __init__(
            self,
            minio_host, minio_access_key, minio_secret_key, bucket_name,
    ):
        # NOTE(review): secure=False means plain HTTP to Minio — confirm
        # this only runs on a trusted internal network.
        self.minio = Minio(
            minio_host,
            access_key = minio_access_key,
            secret_key = minio_secret_key,
            secure = False,
        )
        self.bucket_name = bucket_name
        print("Connected to minio", flush=True)
        self.ensure_bucket()

    def ensure_bucket(self):
        # Make the bucket if it doesn't exist.
        if self.minio.bucket_exists(self.bucket_name):
            print("Bucket", self.bucket_name, "already exists", flush=True)
        else:
            self.minio.make_bucket(self.bucket_name)
            print("Created bucket", self.bucket_name, flush=True)

    def add(self, object_id, blob, kind):
        """Store a blob under doc/<object_id> with its MIME type."""
        # FIXME: Loop retry
        self.minio.put_object(
            bucket_name = self.bucket_name,
            object_name = "doc/" + str(object_id),
            length = len(blob),
            data = io.BytesIO(blob),
            content_type = kind,
        )
        print("Add blob complete", flush=True)

View file

@ -0,0 +1,55 @@
from .. schema import LibrarianRequest, LibrarianResponse, Error
from .. knowledge import hash
from .. exceptions import RequestError
from . table_store import TableStore
from . blob_store import BlobStore
import uuid
class Librarian:
    """
    Document library: stores blobs in Minio, rows in Cassandra, and pushes
    accepted documents into the processing pipeline via callbacks.
    """

    def __init__(
            self,
            cassandra_host, cassandra_user, cassandra_password,
            minio_host, minio_access_key, minio_secret_key,
            bucket_name, keyspace, load_document, load_text,
    ):
        self.blob_store = BlobStore(
            minio_host, minio_access_key, minio_secret_key, bucket_name
        )
        self.table_store = TableStore(
            cassandra_host, cassandra_user, cassandra_password, keyspace
        )
        # Pipeline callbacks, keyed by document kind in add()
        self.load_document = load_document
        self.load_text = load_text

    def add(self, id, document):
        """Validate, store and queue a document package for ingest."""

        if document.kind not in ("text/plain", "application/pdf"):
            raise RequestError("Invalid document kind: " + document.kind)

        # Content-addressed object ID: hashing the body means identical
        # blobs share one stored object.
        # NOTE(review): assumes knowledge.hash returns a UUID-parseable
        # string — confirm.
        object_id = uuid.UUID(hash(document.document))

        self.blob_store.add(object_id, document.document, document.kind)
        self.table_store.add(object_id, document)

        # Route to the appropriate ingest pipeline by MIME type
        if document.kind == "application/pdf":
            self.load_document(id, document)
        elif document.kind == "text/plain":
            self.load_text(id, document)

        print("Add complete", flush=True)

        return LibrarianResponse(
            error = None,
            document = None,
            info = None,
        )

View file

@ -0,0 +1,352 @@
"""
Librarian service, manages documents in collections
"""
from functools import partial
import asyncio
import threading
import queue
from pulsar.schema import JsonSchema
from .. schema import LibrarianRequest, LibrarianResponse, Error
from .. schema import librarian_request_queue, librarian_response_queue
from .. schema import GraphEmbeddings
from .. schema import graph_embeddings_store_queue
from .. schema import Triples
from .. schema import triples_store_queue
from .. schema import DocumentEmbeddings
from .. schema import document_embeddings_store_queue
from .. schema import Document, Metadata
from .. schema import document_ingest_queue
from .. schema import TextDocument, Metadata
from .. schema import text_ingest_queue
from .. base import Publisher
from .. base import Subscriber
from .. log_level import LogLevel
from .. base import ConsumerProducer
from .. exceptions import RequestError
from . librarian import Librarian
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = librarian_request_queue
default_output_queue = librarian_response_queue
default_subscriber = module
default_minio_host = "minio:9000"
default_minio_access_key = "minioadmin"
default_minio_secret_key = "minioadmin"
default_cassandra_host = "cassandra"
bucket_name = "library"
# FIXME: How to ensure this doesn't conflict with other usage?
keyspace = "librarian"
class Processor(ConsumerProducer):
    """
    Librarian request processor.  Consumes LibrarianRequest messages,
    stores/loads documents through the Librarian, and emits exactly one
    LibrarianResponse per request, keyed by the sender's correlation ID.
    """

    def __init__(self, **params):

        self.running = True

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)

        minio_host = params.get("minio_host", default_minio_host)
        minio_access_key = params.get(
            "minio_access_key",
            default_minio_access_key
        )
        minio_secret_key = params.get(
            "minio_secret_key",
            default_minio_secret_key
        )

        cassandra_host = params.get("cassandra_host", default_cassandra_host)
        cassandra_user = params.get("cassandra_user")
        cassandra_password = params.get("cassandra_password")

        # Fix: fall back to the module defaults so programmatic
        # construction without these params still gets valid topics.
        triples_queue = params.get("triples_queue", triples_store_queue)
        document_load_queue = params.get(
            "document_load_queue", document_ingest_queue
        )
        text_load_queue = params.get("text_load_queue", text_ingest_queue)

        # NOTE(review): these two are parsed (and exposed in add_args) but
        # never used yet — presumably for later embeddings handling.
        graph_embeddings_queue = params.get("graph_embeddings_queue")
        document_embeddings_queue = params.get("document_embeddings_queue")

        # Secrets (minio_secret_key, cassandra_password) are deliberately
        # not echoed into the base-class parameter set.
        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": LibrarianRequest,
                "output_schema": LibrarianResponse,
                "minio_host": minio_host,
                "minio_access_key": minio_access_key,
                "cassandra_host": cassandra_host,
                "cassandra_user": cassandra_user,
            }
        )

        # Publishers feeding the document/text ingest pipelines
        self.document_load = Publisher(
            self.pulsar_host, document_load_queue, JsonSchema(Document),
            listener=self.pulsar_listener,
        )

        self.text_load = Publisher(
            self.pulsar_host, text_load_queue, JsonSchema(TextDocument),
            listener=self.pulsar_listener,
        )

        # Fix: honour the configured triples queue; previously the
        # --triples-queue argument was parsed but the module constant was
        # hard-coded here.
        self.triples_load = Subscriber(
            self.pulsar_host, triples_queue,
            "librarian", "librarian",
            schema=JsonSchema(Triples),
            listener=self.pulsar_listener,
        )

        self.document_load.start()
        self.text_load.start()
        self.triples_load.start()

        # Broadcast subscription drained by a dedicated thread
        self.triples_sub = self.triples_load.subscribe_all("x")

        self.triples_reader = threading.Thread(target=self.receive_triples)
        self.triples_reader.start()

        self.librarian = Librarian(
            cassandra_host = cassandra_host.split(","),
            cassandra_user = cassandra_user,
            cassandra_password = cassandra_password,
            minio_host = minio_host,
            minio_access_key = minio_access_key,
            minio_secret_key = minio_secret_key,
            bucket_name = bucket_name,
            keyspace = keyspace,
            load_document = self.load_document,
            load_text = self.load_text,
        )

        print("Initialised.", flush=True)

    def receive_triples(self):
        """Drain stored-triples messages until the processor stops."""
        print("Receive triples!")
        while self.running:
            try:
                msg = self.triples_sub.get(timeout=1)
            except queue.Empty:
                print("Tick")
                continue
            print(msg)
        print("BYE")

    def __del__(self):
        self.running = False
        if hasattr(self, "triples_sub"):
            # Fix: unsubscribe on the Subscriber, not on the queue returned
            # by subscribe_all (queues have no unsubscribe_all method).
            self.triples_load.unsubscribe_all("x")
        # NOTE(review): assumes the base Publisher/Subscriber expose
        # stop()/join() — confirm.
        if hasattr(self, "document_load"):
            self.document_load.stop()
            self.document_load.join()
        if hasattr(self, "text_load"):
            self.text_load.stop()
            self.text_load.join()
        if hasattr(self, "triples_load"):
            self.triples_load.stop()
            self.triples_load.join()

    def load_document(self, id, document):
        """Queue a PDF document package onto the document ingest topic."""
        doc = Document(
            metadata = Metadata(
                id = id,
                metadata = document.metadata,
                user = document.user,
                collection = document.collection
            ),
            data = document.document
        )
        self.document_load.send(None, doc)

    def load_text(self, id, document):
        """Queue a plain-text document package onto the text ingest topic."""
        doc = TextDocument(
            metadata = Metadata(
                id = id,
                metadata = document.metadata,
                user = document.user,
                collection = document.collection
            ),
            text = document.document
        )
        self.text_load.send(None, doc)

    def parse_request(self, v):
        """Validate a request and return a zero-argument callable for it."""

        if v.operation is None:
            raise RequestError("Null operation")

        if v.operation == "add":
            print(v)
            if (
                v.id and v.document and v.document.metadata and
                v.document.document and v.document.kind
            ):
                return partial(
                    self.librarian.add,
                    id = v.id,
                    document = v.document,
                )
            else:
                raise RequestError("Invalid call")

        raise RequestError("Invalid operation: " + v.operation)

    def _error_response(self, type, message):
        # Build a response carrying only an error record
        return LibrarianResponse(
            error = Error(
                type = type,
                message = message,
            )
        )

    def handle(self, msg):
        """Process one request message; always emits exactly one response."""

        v = msg.value()

        # Sender-produced ID
        id = msg.properties()["id"]

        print(f"Handling input {id}...", flush=True)

        try:
            func = self.parse_request(v)
        except RequestError as e:
            self.producer.send(
                self._error_response("request-error", str(e)),
                properties={"id": id},
            )
            return

        try:
            resp = func()
        except RequestError as e:
            self.producer.send(
                self._error_response("request-error", str(e)),
                properties={"id": id},
            )
            return
        except Exception as e:
            print("Exception:", e, flush=True)
            self.producer.send(
                self._error_response(
                    "processing-error", "Unhandled error: " + str(e)
                ),
                properties={"id": id},
            )
            return

        print("Send response...", flush=True)
        self.producer.send(resp, properties={"id": id})
        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '--minio-host',
            default=default_minio_host,
            help=f'Minio hostname (default: {default_minio_host})',
        )

        parser.add_argument(
            '--minio-access-key',
            # Fix: use the named default rather than a duplicated literal
            default=default_minio_access_key,
            help='Minio access key / username '
            f'(default: {default_minio_access_key})',
        )

        parser.add_argument(
            '--minio-secret-key',
            default=default_minio_secret_key,
            # Fix: help text previously showed the access-key default
            help='Minio secret key / password '
            f'(default: {default_minio_secret_key})',
        )

        parser.add_argument(
            '--cassandra-host',
            default=default_cassandra_host,
            help=f'Graph host (default: {default_cassandra_host})'
        )

        parser.add_argument(
            '--cassandra-user',
            default=None,
            help='Cassandra user'
        )

        parser.add_argument(
            '--cassandra-password',
            default=None,
            help='Cassandra password'
        )

        parser.add_argument(
            '--triples-queue',
            default=triples_store_queue,
            help=f'Triples queue (default: {triples_store_queue})'
        )

        parser.add_argument(
            '--graph-embeddings-queue',
            default=graph_embeddings_store_queue,
            # Fix: help text previously showed the triples queue default
            help='Graph embeddings queue '
            f'(default: {graph_embeddings_store_queue})'
        )

        parser.add_argument(
            '--document-embeddings-queue',
            default=document_embeddings_store_queue,
            help='Document embeddings queue '
            f'(default: {document_embeddings_store_queue})'
        )

        parser.add_argument(
            '--document-load-queue',
            default=document_ingest_queue,
            help='Document load queue '
            f'(default: {document_ingest_queue})'
        )

        parser.add_argument(
            '--text-load-queue',
            default=text_ingest_queue,
            help='Text ingest queue '
            f'(default: {text_ingest_queue})'
        )
def run():
    """Entry point: launch the librarian processor service."""
    Processor.start(module, __doc__)

View file

@ -0,0 +1,131 @@
from .. schema import LibrarianRequest, LibrarianResponse, Error
from .. knowledge import hash
from .. exceptions import RequestError
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from cassandra.query import BatchStatement
import uuid
import time
class TableStore:
def __init__(
        self,
        cassandra_host, cassandra_user, cassandra_password, keyspace,
):
    """Connect to Cassandra, ensure the schema, prepare statements."""

    self.keyspace = keyspace

    print("Connecting to Cassandra...", flush=True)

    # Authenticate only when both credentials were supplied
    if cassandra_user and cassandra_password:
        self.cluster = Cluster(
            cassandra_host,
            auth_provider=PlainTextAuthProvider(
                username=cassandra_user, password=cassandra_password
            ),
        )
    else:
        self.cluster = Cluster(cassandra_host)

    self.cassandra = self.cluster.connect()
    print("Connected.", flush=True)

    self.ensure_cassandra_schema()

    # Prepared once; reused for every add()
    self.insert_document_stmt = self.cassandra.prepare("""
        insert into document
        (id, user, collection, kind, object_id, metadata)
        values (?, ?, ?, ?, ?, ?)
    """)
def ensure_cassandra_schema(self):
    """Create the keyspace, document table and object index if missing."""

    print("Ensure Cassandra schema...", flush=True)

    print("Keyspace...", flush=True)

    # FIXME: Replication factor should be configurable
    self.cassandra.execute(f"""
        create keyspace if not exists {self.keyspace}
        with replication = {{
            'class' : 'SimpleStrategy',
            'replication_factor' : 1
        }};
    """)

    self.cassandra.set_keyspace(self.keyspace)

    print("document table...", flush=True)

    # metadata holds flattened triples: (s, s_is_uri, p, p_is_uri, o, o_is_uri)
    self.cassandra.execute("""
        create table if not exists document (
            user text,
            collection text,
            id uuid,
            kind text,
            object_id uuid,
            metadata list<tuple<
                text, boolean, text, boolean, text, boolean
            >>,
            PRIMARY KEY (user, collection, id)
        );
    """)

    print("object index...", flush=True)

    self.cassandra.execute("""
        create index if not exists document_object
        on document ( object_id)
    """)

    print("Cassandra schema OK.", flush=True)
def add(self, object_id, document):
    """Insert a document row, retrying on Cassandra errors."""

    if document.kind not in ("text/plain", "application/pdf"):
        raise RequestError("Invalid document kind: " + document.kind)

    # Create random doc ID
    doc_id = uuid.uuid4()

    print("Adding", object_id, doc_id)

    # Flatten metadata triples into tuples for the list<tuple<...>> column
    metadata = [
        (
            v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
            v.o.value, v.o.is_uri
        )
        for v in document.metadata
    ]

    # NOTE(review): retries forever on any exception, including
    # non-transient programming errors — confirm this is intended.
    while True:
        try:
            self.cassandra.execute(
                self.insert_document_stmt,
                (
                    doc_id, document.user, document.collection,
                    document.kind, object_id, metadata
                )
            )
            break
        except Exception as e:
            print("Exception:", type(e))
            print(f"{e}, retry...", flush=True)
            time.sleep(1)

    print("Add complete", flush=True)