Feature/librarian (#307)

* Bring QDrant up-to-date

* Tables for data from queue outputs

- Pass single Pulsar client to everything in gateway & librarian
- Pulsar listener-name support in gateway
- PDF and text load working in librarian

* Complete Cassandra schema

* Add librarian support to templates
This commit is contained in:
cybermaggedon 2025-02-12 23:39:24 +00:00 committed by GitHub
parent f350abb415
commit f7df2df266
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 500 additions and 145 deletions

View file

@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class AgentRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(AgentRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=agent_request_queue,
response_queue=agent_response_queue,
request_schema=AgentRequest,

View file

@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class DbpediaRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(DbpediaRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=dbpedia_lookup_request_queue,
response_queue=dbpedia_lookup_response_queue,
request_schema=LookupRequest,

View file

@ -15,17 +15,17 @@ from . serialize import to_subgraph
class DocumentEmbeddingsLoadEndpoint(SocketEndpoint):
def __init__(
self, pulsar_host, auth, path="/api/v1/load/document-embeddings",
self, pulsar_client, auth, path="/api/v1/load/document-embeddings",
):
super(DocumentEmbeddingsLoadEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host
self.pulsar_client=pulsar_client
self.publisher = Publisher(
self.pulsar_host, document_embeddings_store_queue,
self.pulsar_client, document_embeddings_store_queue,
schema=JsonSchema(DocumentEmbeddings)
)

View file

@ -14,17 +14,18 @@ from . serialize import serialize_document_embeddings
class DocumentEmbeddingsStreamEndpoint(SocketEndpoint):
def __init__(
self, pulsar_host, auth, path="/api/v1/stream/document-embeddings"
self, pulsar_client, auth,
path="/api/v1/stream/document-embeddings"
):
super(DocumentEmbeddingsStreamEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host
self.pulsar_client=pulsar_client
self.subscriber = Subscriber(
self.pulsar_host, document_embeddings_store_queue,
self.pulsar_client, document_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=JsonSchema(DocumentEmbeddings)
)

View file

@ -8,10 +8,10 @@ from . sender import ServiceSender
from . serialize import to_subgraph
class DocumentLoadSender(ServiceSender):
def __init__(self, pulsar_host):
def __init__(self, pulsar_client):
super(DocumentLoadSender, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=document_ingest_queue,
request_schema=Document,
)

View file

@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class DocumentRagRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(DocumentRagRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=document_rag_request_queue,
response_queue=document_rag_response_queue,
request_schema=DocumentRagQuery,

View file

@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class EmbeddingsRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(EmbeddingsRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=embeddings_request_queue,
response_queue=embeddings_response_queue,
request_schema=EmbeddingsRequest,

View file

@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class EncyclopediaRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(EncyclopediaRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=encyclopedia_lookup_request_queue,
response_queue=encyclopedia_lookup_response_queue,
request_schema=LookupRequest,

View file

@ -15,17 +15,17 @@ from . serialize import to_subgraph, to_value
class GraphEmbeddingsLoadEndpoint(SocketEndpoint):
def __init__(
self, pulsar_host, auth, path="/api/v1/load/graph-embeddings",
self, pulsar_client, auth, path="/api/v1/load/graph-embeddings",
):
super(GraphEmbeddingsLoadEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host
self.pulsar_client=pulsar_client
self.publisher = Publisher(
self.pulsar_host, graph_embeddings_store_queue,
self.pulsar_client, graph_embeddings_store_queue,
schema=JsonSchema(GraphEmbeddings)
)

View file

@ -8,10 +8,10 @@ from . requestor import ServiceRequestor
from . serialize import serialize_value
class GraphEmbeddingsQueryRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(GraphEmbeddingsQueryRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=graph_embeddings_request_queue,
response_queue=graph_embeddings_response_queue,
request_schema=GraphEmbeddingsRequest,

View file

@ -14,17 +14,17 @@ from . serialize import serialize_graph_embeddings
class GraphEmbeddingsStreamEndpoint(SocketEndpoint):
def __init__(
self, pulsar_host, auth, path="/api/v1/stream/graph-embeddings"
self, pulsar_client, auth, path="/api/v1/stream/graph-embeddings"
):
super(GraphEmbeddingsStreamEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host
self.pulsar_client=pulsar_client
self.subscriber = Subscriber(
self.pulsar_host, graph_embeddings_store_queue,
self.pulsar_client, graph_embeddings_store_queue,
"api-gateway", "api-gateway",
schema=JsonSchema(GraphEmbeddings)
)

View file

@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class GraphRagRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(GraphRagRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=graph_rag_request_queue,
response_queue=graph_rag_response_queue,
request_schema=GraphRagQuery,

View file

@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class InternetSearchRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(InternetSearchRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=internet_search_request_queue,
response_queue=internet_search_response_queue,
request_schema=LookupRequest,

View file

@ -9,10 +9,10 @@ from . serialize import serialize_document_package, serialize_document_info
from . serialize import to_document_package, to_document_info, to_criteria
class LibrarianRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(LibrarianRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=librarian_request_queue,
response_queue=librarian_response_queue,
request_schema=LibrarianRequest,
@ -22,17 +22,19 @@ class LibrarianRequestor(ServiceRequestor):
def to_request(self, body):
print("TRR")
if "document" in body:
dp = to_document_package(body["document"])
else:
dp = None
print("GOT")
if "criteria" in body:
criteria = to_criteria(body["criteria"])
else:
criteria = None
limit = int(body.get("limit", 10000))
print("ASLDKJ")
return LibrarianRequest(
operation = body.get("operation", None),

View file

@ -18,7 +18,7 @@ MAX_QUEUE_SIZE = 10
class MuxEndpoint(SocketEndpoint):
def __init__(
self, pulsar_host, auth,
self, pulsar_client, auth,
services,
path="/api/v1/socket",
):

View file

@ -9,10 +9,10 @@ from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class PromptRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(PromptRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=prompt_request_queue,
response_queue=prompt_response_queue,
request_schema=PromptRequest,

View file

@ -14,7 +14,7 @@ class ServiceRequestor:
def __init__(
self,
pulsar_host,
pulsar_client,
request_queue, request_schema,
response_queue, response_schema,
subscription="api-gateway", consumer_name="api-gateway",
@ -22,12 +22,12 @@ class ServiceRequestor:
):
self.pub = Publisher(
pulsar_host, request_queue,
schema=JsonSchema(request_schema)
pulsar_client, request_queue,
schema=JsonSchema(request_schema),
)
self.sub = Subscriber(
pulsar_host, response_queue,
pulsar_client, response_queue,
subscription, consumer_name,
JsonSchema(response_schema)
)
@ -53,9 +53,11 @@ class ServiceRequestor:
q = self.sub.subscribe(id)
print("BOUT TO SEDN")
await asyncio.to_thread(
self.pub.send, id, self.to_request(request)
)
print("SENT")
while True:

View file

@ -15,13 +15,13 @@ class ServiceSender:
def __init__(
self,
pulsar_host,
pulsar_client,
request_queue, request_schema,
):
self.pub = Publisher(
pulsar_host, request_queue,
schema=JsonSchema(request_schema)
pulsar_client, request_queue,
schema=JsonSchema(request_schema),
)
async def start(self):
@ -53,4 +53,3 @@ class ServiceSender:
return err

View file

@ -126,7 +126,7 @@ def to_document_package(x):
return DocumentPackage(
metadata = to_subgraph(x["metadata"]),
document = base64.b64decode(x["document"].encode("utf-8")),
document = x.get("document", None),
kind = x.get("kind", None),
user = x.get("user", None),
collection = x.get("collection", None),

View file

@ -73,6 +73,11 @@ class Api:
self.port = int(config.get("port", default_port))
self.timeout = int(config.get("timeout", default_timeout))
self.pulsar_host = config.get("pulsar_host", default_pulsar_host)
self.pulsar_listener = config.get("pulsar_listener", None)
self.pulsar_client = pulsar.Client(
self.pulsar_host, listener_name=self.pulsar_listener
)
self.prometheus_url = config.get(
"prometheus_url", default_prometheus_url,
@ -91,58 +96,58 @@ class Api:
self.services = {
"text-completion": TextCompletionRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"prompt": PromptRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"graph-rag": GraphRagRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"document-rag": DocumentRagRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"triples-query": TriplesQueryRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"graph-embeddings-query": GraphEmbeddingsQueryRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"embeddings": EmbeddingsRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"agent": AgentRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"librarian": LibrarianRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"encyclopedia": EncyclopediaRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"dbpedia": DbpediaRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"internet-search": InternetSearchRequestor(
pulsar_host=self.pulsar_host, timeout=self.timeout,
pulsar_client=self.pulsar_client, timeout=self.timeout,
auth = self.auth,
),
"document-load": DocumentLoadSender(
pulsar_host=self.pulsar_host,
pulsar_client=self.pulsar_client,
),
"text-load": TextLoadSender(
pulsar_host=self.pulsar_host,
pulsar_client=self.pulsar_client,
),
}
@ -205,31 +210,31 @@ class Api:
requestor = self.services["text-load"],
),
TriplesStreamEndpoint(
pulsar_host=self.pulsar_host,
pulsar_client=self.pulsar_client,
auth = self.auth,
),
GraphEmbeddingsStreamEndpoint(
pulsar_host=self.pulsar_host,
pulsar_client=self.pulsar_client,
auth = self.auth,
),
DocumentEmbeddingsStreamEndpoint(
pulsar_host=self.pulsar_host,
pulsar_client=self.pulsar_client,
auth = self.auth,
),
TriplesLoadEndpoint(
pulsar_host=self.pulsar_host,
pulsar_client=self.pulsar_client,
auth = self.auth,
),
GraphEmbeddingsLoadEndpoint(
pulsar_host=self.pulsar_host,
pulsar_client=self.pulsar_client,
auth = self.auth,
),
DocumentEmbeddingsLoadEndpoint(
pulsar_host=self.pulsar_host,
pulsar_client=self.pulsar_client,
auth = self.auth,
),
MuxEndpoint(
pulsar_host=self.pulsar_host,
pulsar_client=self.pulsar_client,
auth = self.auth,
services = self.services,
),
@ -266,6 +271,11 @@ def run():
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'--pulsar-listener',
help=f'Pulsar listener (default: none)',
)
parser.add_argument(
'-m', '--prometheus-url',
default=default_prometheus_url,

View file

@ -7,10 +7,10 @@ from . endpoint import ServiceEndpoint
from . requestor import ServiceRequestor
class TextCompletionRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(TextCompletionRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=text_completion_request_queue,
response_queue=text_completion_response_queue,
request_schema=TextCompletionRequest,

View file

@ -8,10 +8,10 @@ from . sender import ServiceSender
from . serialize import to_subgraph
class TextLoadSender(ServiceSender):
def __init__(self, pulsar_host):
def __init__(self, pulsar_client):
super(TextLoadSender, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=text_ingest_queue,
request_schema=TextDocument,
)

View file

@ -14,16 +14,16 @@ from . serialize import to_subgraph
class TriplesLoadEndpoint(SocketEndpoint):
def __init__(self, pulsar_host, auth, path="/api/v1/load/triples"):
def __init__(self, pulsar_client, auth, path="/api/v1/load/triples"):
super(TriplesLoadEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host
self.pulsar_client=pulsar_client
self.publisher = Publisher(
self.pulsar_host, triples_store_queue,
self.pulsar_client, triples_store_queue,
schema=JsonSchema(Triples)
)

View file

@ -8,10 +8,10 @@ from . requestor import ServiceRequestor
from . serialize import to_value, serialize_subgraph
class TriplesQueryRequestor(ServiceRequestor):
def __init__(self, pulsar_host, timeout, auth):
def __init__(self, pulsar_client, timeout, auth):
super(TriplesQueryRequestor, self).__init__(
pulsar_host=pulsar_host,
pulsar_client=pulsar_client,
request_queue=triples_request_queue,
response_queue=triples_response_queue,
request_schema=TriplesQueryRequest,

View file

@ -13,16 +13,16 @@ from . serialize import serialize_triples
class TriplesStreamEndpoint(SocketEndpoint):
def __init__(self, pulsar_host, auth, path="/api/v1/stream/triples"):
def __init__(self, pulsar_client, auth, path="/api/v1/stream/triples"):
super(TriplesStreamEndpoint, self).__init__(
endpoint_path=path, auth=auth,
)
self.pulsar_host=pulsar_host
self.pulsar_client=pulsar_client
self.subscriber = Subscriber(
self.pulsar_host, triples_store_queue,
self.pulsar_client, triples_store_queue,
"api-gateway", "api-gateway",
schema=JsonSchema(Triples)
)

View file

@ -53,4 +53,22 @@ class Librarian:
info = None,
)
def handle_triples(self, m):
self.table_store.add_triples(m)
def handle_graph_embeddings(self, m):
self.table_store.add_graph_embeddings(m)
def handle_document_embeddings(self, m):
self.table_store.add_document_embeddings(m)
def handle_triples(self, m):
self.table_store.add_triples(m)
def handle_graph_embeddings(self, m):
self.table_store.add_graph_embeddings(m)
def handle_document_embeddings(self, m):
self.table_store.add_document_embeddings(m)

View file

@ -7,6 +7,7 @@ from functools import partial
import asyncio
import threading
import queue
import base64
from pulsar.schema import JsonSchema
@ -94,23 +95,38 @@ class Processor(ConsumerProducer):
)
self.document_load = Publisher(
self.pulsar_host, document_load_queue, JsonSchema(Document),
listener=self.pulsar_listener,
self.client, document_load_queue, JsonSchema(Document),
)
self.text_load = Publisher(
self.pulsar_host, text_load_queue, JsonSchema(TextDocument),
listener=self.pulsar_listener,
self.client, text_load_queue, JsonSchema(TextDocument),
)
self.triples_load = Subscriber(
self.pulsar_host, triples_store_queue,
self.triples_brk = Subscriber(
self.client, triples_store_queue,
"librarian", "librarian",
schema=JsonSchema(Triples),
listener=self.pulsar_listener,
)
self.graph_embeddings_brk = Subscriber(
self.client, graph_embeddings_store_queue,
"librarian", "librarian",
schema=JsonSchema(GraphEmbeddings),
)
self.document_embeddings_brk = Subscriber(
self.client, document_embeddings_store_queue,
"librarian", "librarian",
schema=JsonSchema(DocumentEmbeddings),
)
self.triples_reader = threading.Thread(target=self.receive_triples)
self.triples_reader = threading.Thread(
target=self.receive_triples
)
self.graph_embeddings_reader = threading.Thread(
target=self.receive_graph_embeddings
)
self.document_embeddings_reader = threading.Thread(
target=self.receive_document_embeddings
)
self.librarian = Librarian(
cassandra_host = cassandra_host.split(","),
@ -131,34 +147,23 @@ class Processor(ConsumerProducer):
self.document_load.start()
self.text_load.start()
self.triples_load.start()
self.triples_sub = self.triples_load.subscribe_all("x")
self.triples_brk.start()
self.graph_embeddings_brk.start()
self.document_embeddings_brk.start()
self.triples_sub = self.triples_brk.subscribe_all("x")
self.graph_embeddings_sub = self.graph_embeddings_brk.subscribe_all("x")
self.document_embeddings_sub = self.document_embeddings_brk.subscribe_all("x")
self.triples_reader.start()
def receive_triples(self):
print("Receive triples!")
while self.running:
try:
msg = self.triples_sub.get(timeout=1)
except queue.Empty:
print("Tick")
continue
print(msg)
print("BYE")
self.graph_embeddings_reader.start()
self.document_embeddings_reader.start()
def __del__(self):
self.running = False
if hasattr(self, "triples_sub"):
self.triples_sub.unsubscribe_all("x")
if hasattr(self, "document_load"):
self.document_load.stop()
self.document_load.join()
@ -167,9 +172,56 @@ class Processor(ConsumerProducer):
self.text_load.stop()
self.text_load.join()
if hasattr(self, "triples_load"):
self.triples_load.stop()
self.triples_load.join()
if hasattr(self, "triples_sub"):
self.triples_sub.unsubscribe_all("x")
if hasattr(self, "graph_embeddings_sub"):
self.graph_embeddings_sub.unsubscribe_all("x")
if hasattr(self, "document_embeddings_sub"):
self.document_embeddings_sub.unsubscribe_all("x")
if hasattr(self, "triples_brk"):
self.triples_brk.stop()
self.triples_brk.join()
if hasattr(self, "graph_embeddings_brk"):
self.graph_embeddings_brk.stop()
self.graph_embeddings_brk.join()
if hasattr(self, "document_embeddings_brk"):
self.document_embeddings_brk.stop()
self.document_embeddings_brk.join()
def receive_triples(self):
while self.running:
try:
msg = self.triples_sub.get(timeout=1)
except queue.Empty:
continue
self.librarian.handle_triples(msg)
def receive_graph_embeddings(self):
while self.running:
try:
msg = self.graph_embeddings_sub.get(timeout=1)
except queue.Empty:
continue
self.librarian.handle_graph_embeddings(msg)
def receive_document_embeddings(self):
while self.running:
try:
msg = self.document_embeddings_sub.get(timeout=1)
except queue.Empty:
continue
self.librarian.handle_document_embeddings(msg)
async def load_document(self, id, document):
@ -187,6 +239,9 @@ class Processor(ConsumerProducer):
async def load_text(self, id, document):
text = base64.b64decode(document.document)
text = text.decode("utf-8")
doc = TextDocument(
metadata = Metadata(
id = id,
@ -194,7 +249,7 @@ class Processor(ConsumerProducer):
user = document.user,
collection = document.collection
),
text = document.document
text = text,
)
self.text_load.send(None, doc)

View file

@ -36,11 +36,7 @@ class TableStore:
self.ensure_cassandra_schema()
self.insert_document_stmt = self.cassandra.prepare("""
insert into document
(id, user, collection, kind, object_id, metadata)
values (?, ?, ?, ?, ?, ?)
""")
self.prepare_statements()
def ensure_cassandra_schema(self):
@ -62,10 +58,13 @@ class TableStore:
print("document table...", flush=True)
self.cassandra.execute("""
create table if not exists document (
CREATE TABLE IF NOT EXISTS document (
user text,
collection text,
id uuid,
time timestamp,
title text,
comments text,
kind text,
object_id uuid,
metadata list<tuple<
@ -78,12 +77,113 @@ class TableStore:
print("object index...", flush=True)
self.cassandra.execute("""
create index if not exists document_object
on document ( object_id)
CREATE INDEX IF NOT EXISTS document_object
ON document (object_id)
""");
print("triples table...", flush=True)
self.cassandra.execute("""
CREATE TABLE IF NOT EXISTS triples (
user text,
collection text,
document_id text,
id uuid,
time timestamp,
metadata list<tuple<
text, boolean, text, boolean, text, boolean
>>,
triples list<tuple<
text, boolean, text, boolean, text, boolean
>>,
PRIMARY KEY (user, collection, document_id, id)
);
""");
print("graph_embeddings table...", flush=True)
self.cassandra.execute("""
create table if not exists graph_embeddings (
user text,
collection text,
document_id text,
id uuid,
time timestamp,
metadata list<tuple<
text, boolean, text, boolean, text, boolean
>>,
entity_embeddings list<
tuple<
tuple<text, boolean>,
list<list<double>>
>
>,
PRIMARY KEY (user, collection, document_id, id)
);
""");
print("document_embeddings table...", flush=True)
self.cassandra.execute("""
create table if not exists document_embeddings (
user text,
collection text,
document_id text,
id uuid,
time timestamp,
metadata list<tuple<
text, boolean, text, boolean, text, boolean
>>,
chunks list<
tuple<
blob,
list<list<double>>
>
>,
PRIMARY KEY (user, collection, document_id, id)
);
""");
print("Cassandra schema OK.", flush=True)
def prepare_statements(self):
self.insert_document_stmt = self.cassandra.prepare("""
INSERT INTO document
(
id, user, collection, kind, object_id, time, title, comments,
metadata
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""")
self.insert_triples_stmt = self.cassandra.prepare("""
INSERT INTO triples
(
id, user, collection, document_id, time,
metadata, triples
)
VALUES (?, ?, ?, ?, ?, ?, ?)
""")
self.insert_graph_embeddings_stmt = self.cassandra.prepare("""
INSERT INTO graph_embeddings
(
id, user, collection, document_id, time,
metadata, entity_embeddings
)
VALUES (?, ?, ?, ?, ?, ?, ?)
""")
self.insert_document_embeddings_stmt = self.cassandra.prepare("""
INSERT INTO document_embeddings
(
id, user, collection, document_id, time,
metadata, chunks
)
VALUES (?, ?, ?, ?, ?, ?, ?)
""")
def add(self, object_id, document):
if document.kind not in (
@ -93,6 +193,7 @@ class TableStore:
# Create random doc ID
doc_id = uuid.uuid4()
when = int(time.time() * 1000)
print("Adding", object_id, doc_id)
@ -104,6 +205,8 @@ class TableStore:
for v in document.metadata
]
# FIXME: doc_id should be the user-supplied ID???
while True:
try:
@ -111,8 +214,10 @@ class TableStore:
resp = self.cassandra.execute(
self.insert_document_stmt,
(
doc_id, document.user, document.collection,
document.kind, object_id, metadata
doc_id, document.user, document.collection,
document.kind, object_id, when,
document.title, document.comments,
metadata
)
)
@ -126,6 +231,136 @@ class TableStore:
print("Add complete", flush=True)
def add_triples(self, m):
when = int(time.time() * 1000)
if m.metadata.metadata:
metadata = [
(
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
v.o.value, v.o.is_uri
)
for v in m.metadata.metadata
]
else:
metadata = []
triples = [
(
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
v.o.value, v.o.is_uri
)
for v in m.triples
]
while True:
try:
resp = self.cassandra.execute(
self.insert_triples_stmt,
(
uuid.uuid4(), m.metadata.user,
m.metadata.collection, m.metadata.id, when,
metadata, triples,
)
)
break
except Exception as e:
print("Exception:", type(e))
print(f"{e}, retry...", flush=True)
time.sleep(1)
def add_graph_embeddings(self, m):
when = int(time.time() * 1000)
if m.metadata.metadata:
metadata = [
(
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
v.o.value, v.o.is_uri
)
for v in m.metadata.metadata
]
else:
metadata = []
entities = [
(
(v.entity.value, v.entity.is_uri),
v.vectors
)
for v in m.entities
]
while True:
try:
resp = self.cassandra.execute(
self.insert_graph_embeddings_stmt,
(
uuid.uuid4(), m.metadata.user,
m.metadata.collection, m.metadata.id, when,
metadata, entities,
)
)
break
except Exception as e:
print("Exception:", type(e))
print(f"{e}, retry...", flush=True)
time.sleep(1)
def add_document_embeddings(self, m):
when = int(time.time() * 1000)
if m.metadata.metadata:
metadata = [
(
v.s.value, v.s.is_uri, v.p.value, v.p.is_uri,
v.o.value, v.o.is_uri
)
for v in m.metadata.metadata
]
else:
metadata = []
chunks = [
(
v.chunk,
v.vectors,
)
for v in m.chunks
]
while True:
try:
resp = self.cassandra.execute(
self.insert_document_embeddings_stmt,
(
uuid.uuid4(), m.metadata.user,
m.metadata.collection, m.metadata.id, when,
metadata, chunks,
)
)
break
except Exception as e:
print("Exception:", type(e))
print(f"{e}, retry...", flush=True)
time.sleep(1)