diff --git a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py index bb1358fc..401266bc 100644 --- a/trustgraph-base/trustgraph/clients/graph_embeddings_client.py +++ b/trustgraph-base/trustgraph/clients/graph_embeddings_client.py @@ -38,8 +38,12 @@ class GraphEmbeddingsClient(BaseClient): output_schema=GraphEmbeddingsResponse, ) - def request(self, vectors, limit=10, timeout=300): + def request( + self, vectors, user="trustgraph", collection="default", + limit=10, timeout=300 + ): return self.call( + user=user, collection=collection, vectors=vectors, limit=limit, timeout=timeout ).entities diff --git a/trustgraph-base/trustgraph/clients/graph_rag_client.py b/trustgraph-base/trustgraph/clients/graph_rag_client.py index e546aaed..9f8eff62 100644 --- a/trustgraph-base/trustgraph/clients/graph_rag_client.py +++ b/trustgraph-base/trustgraph/clients/graph_rag_client.py @@ -38,9 +38,12 @@ class GraphRagClient(BaseClient): output_schema=GraphRagResponse, ) - def request(self, query, timeout=500): + def request( + self, query, user="trustgraph", collection="default", + timeout=500 + ): return self.call( - query=query, timeout=timeout + user=user, collection=collection, query=query, timeout=timeout ).response diff --git a/trustgraph-base/trustgraph/clients/triples_query_client.py b/trustgraph-base/trustgraph/clients/triples_query_client.py index 14b75151..6c18ac3f 100644 --- a/trustgraph-base/trustgraph/clients/triples_query_client.py +++ b/trustgraph-base/trustgraph/clients/triples_query_client.py @@ -48,11 +48,18 @@ class TriplesQueryClient(BaseClient): return Value(value=ent, is_uri=False) - def request(self, s, p, o, limit=10, timeout=60): + def request( + self, + s, p, o, + user="trustgraph", collection="default", + limit=10, timeout=60, + ): return self.call( s=self.create_value(s), p=self.create_value(p), o=self.create_value(o), + user=user, + collection=collection, limit=limit, 
timeout=timeout, ).triples diff --git a/trustgraph-base/trustgraph/schema/__init__.py b/trustgraph-base/trustgraph/schema/__init__.py index 0cd5a370..f3cc5b60 100644 --- a/trustgraph-base/trustgraph/schema/__init__.py +++ b/trustgraph-base/trustgraph/schema/__init__.py @@ -7,6 +7,7 @@ from . object import * from . topic import * from . graph import * from . retrieval import * +from . metadata import * diff --git a/trustgraph-base/trustgraph/schema/documents.py b/trustgraph-base/trustgraph/schema/documents.py index d80ff38f..59aba287 100644 --- a/trustgraph-base/trustgraph/schema/documents.py +++ b/trustgraph-base/trustgraph/schema/documents.py @@ -2,17 +2,13 @@ from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double from . topic import topic from . types import Error - -class Source(Record): - source = String() - id = String() - title = String() +from . metadata import Metadata ############################################################################ # PDF docs etc. 
class Document(Record): - source = Source() + metadata = Metadata() data = Bytes() document_ingest_queue = topic('document-load') @@ -22,7 +18,7 @@ document_ingest_queue = topic('document-load') # Text documents / text from PDF class TextDocument(Record): - source = Source() + metadata = Metadata() text = Bytes() text_ingest_queue = topic('text-document-load') @@ -32,7 +28,7 @@ text_ingest_queue = topic('text-document-load') # Chunks of text class Chunk(Record): - source = Source() + metadata = Metadata() chunk = Bytes() chunk_ingest_queue = topic('chunk-load') @@ -42,7 +38,7 @@ chunk_ingest_queue = topic('chunk-load') # Chunk embeddings are an embeddings associated with a text chunk class ChunkEmbeddings(Record): - source = Source() + metadata = Metadata() vectors = Array(Array(Double())) chunk = Bytes() diff --git a/trustgraph-base/trustgraph/schema/graph.py b/trustgraph-base/trustgraph/schema/graph.py index 234a0bed..107478a4 100644 --- a/trustgraph-base/trustgraph/schema/graph.py +++ b/trustgraph-base/trustgraph/schema/graph.py @@ -1,16 +1,16 @@ from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double -from . documents import Source from . types import Error, Value from . topic import topic +from . 
metadata import Metadata ############################################################################ # Graph embeddings are embeddings associated with a graph entity class GraphEmbeddings(Record): - source = Source() + metadata = Metadata() vectors = Array(Array(Double())) entity = Value() @@ -23,6 +23,8 @@ graph_embeddings_store_queue = topic('graph-embeddings-store') class GraphEmbeddingsRequest(Record): vectors = Array(Array(Double())) limit = Integer() + user = String() + collection = String() class GraphEmbeddingsResponse(Record): error = Error() @@ -40,7 +42,7 @@ graph_embeddings_response_queue = topic( # Graph triples class Triple(Record): - source = Source() + metadata = Metadata() s = Value() p = Value() o = Value() @@ -56,6 +58,8 @@ class TriplesQueryRequest(Record): p = Value() o = Value() limit = Integer() + user = String() + collection = String() class TriplesQueryResponse(Record): error = Error() diff --git a/trustgraph-base/trustgraph/schema/metadata.py b/trustgraph-base/trustgraph/schema/metadata.py new file mode 100644 index 00000000..c7dbbae6 --- /dev/null +++ b/trustgraph-base/trustgraph/schema/metadata.py @@ -0,0 +1,10 @@ + +from pulsar.schema import Record, String + +class Metadata(Record): + source = String() + id = String() + title = String() + user = String() + collection = String() + diff --git a/trustgraph-base/trustgraph/schema/object.py b/trustgraph-base/trustgraph/schema/object.py index 3377e2df..60c2bdc3 100644 --- a/trustgraph-base/trustgraph/schema/object.py +++ b/trustgraph-base/trustgraph/schema/object.py @@ -2,7 +2,7 @@ from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array from pulsar.schema import Double, Map -from . documents import Source +from . metadata import Metadata from . types import Value, RowSchema from . topic import topic @@ -12,7 +12,7 @@ from . 
topic import topic # object class ObjectEmbeddings(Record): - source = Source() + metadata = Metadata() vectors = Array(Array(Double())) name = String() key_name = String() @@ -25,7 +25,7 @@ object_embeddings_store_queue = topic('object-embeddings-store') # Stores rows of information class Rows(Record): - source = Source() + metadata = Metadata() row_schema = RowSchema() rows = Array(Map(String())) diff --git a/trustgraph-base/trustgraph/schema/retrieval.py b/trustgraph-base/trustgraph/schema/retrieval.py index fa0288fc..ad860c3c 100644 --- a/trustgraph-base/trustgraph/schema/retrieval.py +++ b/trustgraph-base/trustgraph/schema/retrieval.py @@ -9,6 +9,8 @@ from . types import Error, Value class GraphRagQuery(Record): query = String() + user = String() + collection = String() class GraphRagResponse(Record): error = Error() @@ -27,6 +29,8 @@ graph_rag_response_queue = topic( class DocumentRagQuery(Record): query = String() + user = String() + collection = String() class DocumentRagResponse(Record): error = Error() diff --git a/trustgraph-cli/scripts/tg-graph-show b/trustgraph-cli/scripts/tg-graph-show index a737c97b..ac5db93f 100755 --- a/trustgraph-cli/scripts/tg-graph-show +++ b/trustgraph-cli/scripts/tg-graph-show @@ -9,12 +9,17 @@ import os from trustgraph.clients.triples_query_client import TriplesQueryClient default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_user = 'trustgraph' +default_collection = 'default' -def show_graph(pulsar): +def show_graph(pulsar, user, collection): tq = TriplesQueryClient(pulsar_host=pulsar) - rows = tq.request(None, None, None, limit=10_000_000) + rows = tq.request( + user=user, collection=collection, + s=None, p=None, o=None, limit=10_000_000 + ) for row in rows: print(row.s.value, row.p.value, row.o.value) @@ -32,11 +37,26 @@ def main(): help=f'Pulsar host (default: {default_pulsar_host})', ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' 
+ ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + args = parser.parse_args() try: - show_graph(args.pulsar_host) + show_graph( + pulsar=args.pulsar_host, user=args.user, + collection=args.collection, + ) except Exception as e: diff --git a/trustgraph-cli/scripts/tg-load-pdf b/trustgraph-cli/scripts/tg-load-pdf index 5d54da93..460a2f06 100755 --- a/trustgraph-cli/scripts/tg-load-pdf +++ b/trustgraph-cli/scripts/tg-load-pdf @@ -6,7 +6,7 @@ Loads a PDF document into TrustGraph processing. import pulsar from pulsar.schema import JsonSchema -from trustgraph.schema import Document, Source, document_ingest_queue +from trustgraph.schema import Document, document_ingest_queue, Metadata import base64 import hashlib import argparse @@ -15,12 +15,17 @@ import time from trustgraph.log_level import LogLevel +default_user = 'trustgraph' +default_collection = 'default' + class Loader: def __init__( self, pulsar_host, output_queue, + user, + collection, log_level, ): @@ -35,6 +40,9 @@ class Loader: chunking_enabled=True, ) + self.user = user + self.collection = collection + def load(self, files): for file in files: @@ -50,10 +58,12 @@ class Loader: id = hashlib.sha256(path.encode("utf-8")).hexdigest()[0:8] r = Document( - source=Source( + metadata=Metadata( source=path, title=path, id=id, + user=self.user, + collection=self.collection, ), data=base64.b64encode(data), ) @@ -90,6 +100,18 @@ def main(): help=f'Output queue (default: {default_output_queue})' ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + parser.add_argument( '-l', '--log-level', type=LogLevel, @@ -112,6 +134,8 @@ def main(): p = Loader( pulsar_host=args.pulsar_host, output_queue=args.output_queue, + user=args.user, + 
collection=args.collection, log_level=args.log_level, ) diff --git a/trustgraph-cli/scripts/tg-load-text b/trustgraph-cli/scripts/tg-load-text index 8137006c..e22af5b1 100755 --- a/trustgraph-cli/scripts/tg-load-text +++ b/trustgraph-cli/scripts/tg-load-text @@ -6,7 +6,7 @@ Loads a text document into TrustGraph processing. import pulsar from pulsar.schema import JsonSchema -from trustgraph.schema import TextDocument, Source, text_ingest_queue +from trustgraph.schema import TextDocument, text_ingest_queue, Metadata import base64 import hashlib import argparse @@ -15,12 +15,17 @@ import time from trustgraph.log_level import LogLevel +default_user = 'trustgraph' +default_collection = 'default' + class Loader: def __init__( self, pulsar_host, output_queue, + user, + collection, log_level, ): @@ -35,6 +40,9 @@ class Loader: chunking_enabled=True, ) + self.user = user + self.collection = collection + def load(self, files): for file in files: @@ -50,10 +58,12 @@ class Loader: id = hashlib.sha256(path.encode("utf-8")).hexdigest()[0:8] r = TextDocument( - source=Source( + metadata=Metadata( source=path, title=path, id=id, + user=self.user, + collection=self.collection, ), text=data, ) @@ -90,6 +100,18 @@ def main(): help=f'Output queue (default: {default_output_queue})' ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + parser.add_argument( '-l', '--log-level', type=LogLevel, @@ -112,6 +134,8 @@ def main(): p = Loader( pulsar_host=args.pulsar_host, output_queue=args.output_queue, + user=args.user, + collection=args.collection, log_level=args.log_level, ) diff --git a/trustgraph-cli/scripts/tg-query-document-rag b/trustgraph-cli/scripts/tg-query-document-rag index 948dcd2f..8d800629 100755 --- a/trustgraph-cli/scripts/tg-query-document-rag +++ 
b/trustgraph-cli/scripts/tg-query-document-rag @@ -9,17 +9,19 @@ import os from trustgraph.clients.document_rag_client import DocumentRagClient default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_user = 'trustgraph' +default_collection = 'default' -def query(pulsar, query): +def query(pulsar_host, query, user, collection): - rag = DocumentRagClient(pulsar_host=pulsar) - resp = rag.request(query) + rag = DocumentRagClient(pulsar_host=pulsar_host) + resp = rag.request(user=user, collection=collection, query=query) print(resp) def main(): parser = argparse.ArgumentParser( - prog='graph-show', + prog='tg-query-document-rag', description=__doc__, ) @@ -35,11 +37,28 @@ def main(): help=f'Query to execute', ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + args = parser.parse_args() try: - query(args.pulsar_host, args.query) + query( + pulsar_host=args.pulsar_host, + query=args.query, + user=args.user, + collection=args.collection, + ) except Exception as e: diff --git a/trustgraph-cli/scripts/tg-query-graph-rag b/trustgraph-cli/scripts/tg-query-graph-rag index 5250bf15..8a865eea 100755 --- a/trustgraph-cli/scripts/tg-query-graph-rag +++ b/trustgraph-cli/scripts/tg-query-graph-rag @@ -9,17 +9,19 @@ import os from trustgraph.clients.graph_rag_client import GraphRagClient default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650') +default_user = 'trustgraph' +default_collection = 'default' -def query(pulsar, query): +def query(pulsar_host, query, user, collection): - rag = GraphRagClient(pulsar_host=pulsar) - resp = rag.request(query) + rag = GraphRagClient(pulsar_host=pulsar_host) + resp = rag.request(user=user, collection=collection, query=query) print(resp) def main(): parser = argparse.ArgumentParser( - prog='graph-show', + prog='tg-query-graph-rag', description=__doc__, )
@@ -35,11 +37,28 @@ def main(): help=f'Query to execute', ) + parser.add_argument( + '-u', '--user', + default=default_user, + help=f'User ID (default: {default_user})' + ) + + parser.add_argument( + '-c', '--collection', + default=default_collection, + help=f'Collection ID (default: {default_collection})' + ) + args = parser.parse_args() try: - query(args.pulsar_host, args.query) + query( + pulsar_host=args.pulsar_host, + query=args.query, + user=args.user, + collection=args.collection, + ) except Exception as e: diff --git a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py index fe1a0cee..bec854bb 100755 --- a/trustgraph-flow/trustgraph/chunking/recursive/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/recursive/chunker.py @@ -7,7 +7,7 @@ as text as separate output objects. from langchain_text_splitters import RecursiveCharacterTextSplitter from prometheus_client import Histogram -from ... schema import TextDocument, Chunk, Source +from ... schema import TextDocument, Chunk, Metadata from ... schema import text_ingest_queue, chunk_ingest_queue from ... log_level import LogLevel from ... 
base import ConsumerProducer @@ -55,7 +55,7 @@ class Processor(ConsumerProducer): def handle(self, msg): v = msg.value() - print(f"Chunking {v.source.id}...", flush=True) + print(f"Chunking {v.metadata.id}...", flush=True) texts = self.text_splitter.create_documents( [v.text.decode("utf-8")] @@ -63,13 +63,15 @@ class Processor(ConsumerProducer): for ix, chunk in enumerate(texts): - id = v.source.id + "-c" + str(ix) + id = v.metadata.id + "-c" + str(ix) r = Chunk( - source=Source( - source=v.source.source, + metadata=Metadata( + source=v.metadata.source, id=id, - title=v.source.title + title=v.metadata.title, + user=v.metadata.user, + collection=v.metadata.collection, ), chunk=chunk.page_content.encode("utf-8"), ) diff --git a/trustgraph-flow/trustgraph/chunking/token/chunker.py b/trustgraph-flow/trustgraph/chunking/token/chunker.py index c152b0fd..e7bc5667 100755 --- a/trustgraph-flow/trustgraph/chunking/token/chunker.py +++ b/trustgraph-flow/trustgraph/chunking/token/chunker.py @@ -7,7 +7,7 @@ as text as separate output objects. from langchain_text_splitters import TokenTextSplitter from prometheus_client import Histogram -from ... schema import TextDocument, Chunk, Source +from ... schema import TextDocument, Chunk, Metadata from ... schema import text_ingest_queue, chunk_ingest_queue from ... log_level import LogLevel from ... 
base import ConsumerProducer @@ -54,7 +54,7 @@ class Processor(ConsumerProducer): def handle(self, msg): v = msg.value() - print(f"Chunking {v.source.id}...", flush=True) + print(f"Chunking {v.metadata.id}...", flush=True) texts = self.text_splitter.create_documents( [v.text.decode("utf-8")] @@ -62,13 +62,15 @@ class Processor(ConsumerProducer): for ix, chunk in enumerate(texts): - id = v.source.id + "-c" + str(ix) + id = v.metadata.id + "-c" + str(ix) r = Chunk( - source=Source( - source=v.source.source, + metadata=Metadata( + source=v.metadata.source, id=id, - title=v.source.title + title=v.metadata.title, + user=v.metadata.user, + collection=v.metadata.collection, ), chunk=chunk.page_content.encode("utf-8"), ) diff --git a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py index fffcaee0..7e9aeafa 100755 --- a/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py +++ b/trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py @@ -8,7 +8,7 @@ import tempfile import base64 from langchain_community.document_loaders import PyPDFLoader -from ... schema import Document, TextDocument, Source +from ... schema import Document, TextDocument, Metadata from ... schema import document_ingest_queue, text_ingest_queue from ... log_level import LogLevel from ... 
base import ConsumerProducer @@ -45,7 +45,7 @@ class Processor(ConsumerProducer): v = msg.value() - print(f"Decoding {v.source.id}...", flush=True) + print(f"Decoding {v.metadata.id}...", flush=True) with tempfile.NamedTemporaryFile(delete_on_close=False) as fp: @@ -59,12 +59,14 @@ class Processor(ConsumerProducer): for ix, page in enumerate(pages): - id = v.source.id + "-p" + str(ix) + id = v.metadata.id + "-p" + str(ix) r = TextDocument( - source=Source( - source=v.source.source, - title=v.source.title, + metadata=Metadata( + source=v.metadata.source, + title=v.metadata.title, id=id, + user=v.metadata.user, + collection=v.metadata.collection, ), text=page.page_content.encode("utf-8"), ) diff --git a/trustgraph-flow/trustgraph/direct/cassandra.py b/trustgraph-flow/trustgraph/direct/cassandra.py index 1754e090..2b577df1 100644 --- a/trustgraph-flow/trustgraph/direct/cassandra.py +++ b/trustgraph-flow/trustgraph/direct/cassandra.py @@ -4,10 +4,16 @@ from cassandra.auth import PlainTextAuthProvider class TrustGraph: - def __init__(self, hosts=None): + def __init__( + self, hosts=None, + keyspace="trustgraph", table="default", + ): if hosts is None: hosts = ["localhost"] + + self.keyspace = keyspace + self.table = table self.cluster = Cluster(hosts) self.session = self.cluster.connect() @@ -16,26 +22,26 @@ class TrustGraph: def clear(self): - self.session.execute(""" - drop keyspace if exists trustgraph; + self.session.execute(f""" + drop keyspace if exists {self.keyspace}; """); self.init() def init(self): - self.session.execute(""" - create keyspace if not exists trustgraph - with replication = { + self.session.execute(f""" + create keyspace if not exists {self.keyspace} + with replication = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 - }; + }}; """); - self.session.set_keyspace('trustgraph') + self.session.set_keyspace(self.keyspace) - self.session.execute(""" - create table if not exists triples ( + self.session.execute(f""" + create table if not 
exists {self.table} ( s text, p text, o text, @@ -43,66 +49,66 @@ class TrustGraph: ); """); - self.session.execute(""" - create index if not exists triples_p - ON triples (p); + self.session.execute(f""" + create index if not exists {self.table}_p + ON {self.table} (p); """); - self.session.execute(""" - create index if not exists triples_o - ON triples (o); + self.session.execute(f""" + create index if not exists {self.table}_o + ON {self.table} (o); """); def insert(self, s, p, o): self.session.execute( - "insert into triples (s, p, o) values (%s, %s, %s)", + f"insert into {self.table} (s, p, o) values (%s, %s, %s)", (s, p, o) ) def get_all(self, limit=50): return self.session.execute( - f"select s, p, o from triples limit {limit}" + f"select s, p, o from {self.table} limit {limit}" ) def get_s(self, s, limit=10): return self.session.execute( - f"select p, o from triples where s = %s limit {limit}", + f"select p, o from {self.table} where s = %s limit {limit}", (s,) ) def get_p(self, p, limit=10): return self.session.execute( - f"select s, o from triples where p = %s limit {limit}", + f"select s, o from {self.table} where p = %s limit {limit}", (p,) ) def get_o(self, o, limit=10): return self.session.execute( - f"select s, p from triples where o = %s limit {limit}", + f"select s, p from {self.table} where o = %s limit {limit}", (o,) ) def get_sp(self, s, p, limit=10): return self.session.execute( - f"select o from triples where s = %s and p = %s limit {limit}", + f"select o from {self.table} where s = %s and p = %s limit {limit}", (s, p) ) def get_po(self, p, o, limit=10): return self.session.execute( - f"select s from triples where p = %s and o = %s allow filtering limit {limit}", + f"select s from {self.table} where p = %s and o = %s allow filtering limit {limit}", (p, o) ) def get_os(self, o, s, limit=10): return self.session.execute( - f"select p from triples where o = %s and s = %s limit {limit}", + f"select p from {self.table} where o = %s and s = %s limit 
{limit}", (o, s) ) def get_spo(self, s, p, o, limit=10): return self.session.execute( - f"""select s as x from triples where s = %s and p = %s and o = %s limit {limit}""", + f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""", (s, p, o) ) diff --git a/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py b/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py index 3770fee2..4cf2af05 100755 --- a/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py +++ b/trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py @@ -50,15 +50,15 @@ class Processor(ConsumerProducer): subscriber=module + "-emb", ) - def emit(self, source, chunk, vectors): + def emit(self, metadata, chunk, vectors): - r = ChunkEmbeddings(source=source, chunk=chunk, vectors=vectors) + r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors) self.producer.send(r) def handle(self, msg): v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) + print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") @@ -67,7 +67,7 @@ class Processor(ConsumerProducer): vectors = self.embeddings.request(chunk) self.emit( - source=v.source, + metadata=v.metadata, chunk=chunk.encode("utf-8"), vectors=vectors ) diff --git a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py index 06ba8e68..d528b74d 100755 --- a/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/definitions/extract.py @@ -7,7 +7,7 @@ get entity definitions which are output as graph edges. import urllib.parse import json -from .... schema import ChunkEmbeddings, Triple, Source, Value +from .... schema import ChunkEmbeddings, Triple, Metadata, Value from .... schema import chunk_embeddings_ingest_queue, triples_store_queue from .... schema import prompt_request_queue from .... 
schema import prompt_response_queue @@ -69,15 +69,15 @@ class Processor(ConsumerProducer): return self.prompt.request_definitions(chunk) - def emit_edge(self, s, p, o): + def emit_edge(self, metadata, s, p, o): - t = Triple(s=s, p=p, o=o) + t = Triple(metadata=metadata, s=s, p=p, o=o) self.producer.send(t) def handle(self, msg): v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) + print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") @@ -101,7 +101,7 @@ class Processor(ConsumerProducer): s_value = Value(value=str(s_uri), is_uri=True) o_value = Value(value=str(o), is_uri=False) - self.emit_edge(s_value, DEFINITION_VALUE, o_value) + self.emit_edge(v.metadata, s_value, DEFINITION_VALUE, o_value) except Exception as e: print("Exception: ", e, flush=True) diff --git a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py index 49ef9072..cb80e47f 100755 --- a/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py @@ -9,7 +9,8 @@ import urllib.parse import os from pulsar.schema import JsonSchema -from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value +from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings +from .... schema import Metadata, Value from .... schema import chunk_embeddings_ingest_queue, triples_store_queue from .... schema import graph_embeddings_store_queue from .... 
schema import prompt_request_queue @@ -91,20 +92,20 @@ class Processor(ConsumerProducer): return self.prompt.request_relationships(chunk) - def emit_edge(self, s, p, o): + def emit_edge(self, metadata, s, p, o): - t = Triple(s=s, p=p, o=o) + t = Triple(metadata=metadata, s=s, p=p, o=o) self.producer.send(t) - def emit_vec(self, ent, vec): + def emit_vec(self, metadata, ent, vec): - r = GraphEmbeddings(entity=ent, vectors=vec) + r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec) self.vec_prod.send(r) def handle(self, msg): v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) + print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") @@ -139,6 +140,7 @@ class Processor(ConsumerProducer): o_value = Value(value=str(o), is_uri=False) self.emit_edge( + v.metadata, s_value, p_value, o_value @@ -146,6 +148,7 @@ class Processor(ConsumerProducer): # Label for s self.emit_edge( + v.metadata, s_value, RDF_LABEL_VALUE, Value(value=str(s), is_uri=False) @@ -153,6 +156,7 @@ class Processor(ConsumerProducer): # Label for p self.emit_edge( + v.metadata, p_value, RDF_LABEL_VALUE, Value(value=str(p), is_uri=False) @@ -161,15 +165,16 @@ class Processor(ConsumerProducer): if rel.o_entity: # Label for o self.emit_edge( + v.metadata, o_value, RDF_LABEL_VALUE, Value(value=str(o), is_uri=False) ) - self.emit_vec(s_value, v.vectors) - self.emit_vec(p_value, v.vectors) + self.emit_vec(v.metadata, s_value, v.vectors) + self.emit_vec(v.metadata, p_value, v.vectors) if rel.o_entity: - self.emit_vec(o_value, v.vectors) + self.emit_vec(v.metadata, o_value, v.vectors) except Exception as e: print("Exception: ", e, flush=True) diff --git a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py index e2ebe5b0..81e52669 100755 --- a/trustgraph-flow/trustgraph/extract/kg/topics/extract.py +++ b/trustgraph-flow/trustgraph/extract/kg/topics/extract.py @@ -7,7 +7,7 @@ get entity definitions which 
are output as graph edges. import urllib.parse import json -from .... schema import ChunkEmbeddings, Triple, Source, Value +from .... schema import ChunkEmbeddings, Triple, Metadata, Value from .... schema import chunk_embeddings_ingest_queue, triples_store_queue from .... schema import prompt_request_queue from .... schema import prompt_response_queue @@ -69,15 +69,15 @@ class Processor(ConsumerProducer): return self.prompt.request_topics(chunk) - def emit_edge(self, s, p, o): + def emit_edge(self, metadata, s, p, o): - t = Triple(s=s, p=p, o=o) + t = Triple(metadata=metadata, s=s, p=p, o=o) self.producer.send(t) def handle(self, msg): v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) + print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") @@ -101,7 +101,7 @@ class Processor(ConsumerProducer): s_value = Value(value=str(s_uri), is_uri=True) o_value = Value(value=str(o), is_uri=False) - self.emit_edge(s_value, DEFINITION_VALUE, o_value) + self.emit_edge(v. metadata, s_value, DEFINITION_VALUE, o_value) except Exception as e: print("Exception: ", e, flush=True) diff --git a/trustgraph-flow/trustgraph/extract/object/row/extract.py b/trustgraph-flow/trustgraph/extract/object/row/extract.py index aa53f2a6..185a59c3 100755 --- a/trustgraph-flow/trustgraph/extract/object/row/extract.py +++ b/trustgraph-flow/trustgraph/extract/object/row/extract.py @@ -8,7 +8,7 @@ import urllib.parse import os from pulsar.schema import JsonSchema -from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source +from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Metadata from .... schema import RowSchema, Field from .... schema import chunk_embeddings_ingest_queue, rows_store_queue from .... 
schema import object_embeddings_store_queue @@ -124,24 +124,24 @@ class Processor(ConsumerProducer): def get_rows(self, chunk): return self.prompt.request_rows(self.schema, chunk) - def emit_rows(self, source, rows): + def emit_rows(self, metadata, rows): t = Rows( - source=source, row_schema=self.row_schema, rows=rows + metadata=metadata, row_schema=self.row_schema, rows=rows ) self.producer.send(t) - def emit_vec(self, source, name, vec, key_name, key): + def emit_vec(self, metadata, name, vec, key_name, key): r = ObjectEmbeddings( - source=source, vectors=vec, name=name, key_name=key_name, id=key + metadata=metadata, vectors=vec, name=name, key_name=key_name, id=key ) self.vec_prod.send(r) def handle(self, msg): v = msg.value() - print(f"Indexing {v.source.id}...", flush=True) + print(f"Indexing {v.metadata.id}...", flush=True) chunk = v.chunk.decode("utf-8") @@ -150,13 +150,13 @@ class Processor(ConsumerProducer): rows = self.get_rows(chunk) self.emit_rows( - source=v.source, + metadata=v.metadata, rows=rows ) for row in rows: self.emit_vec( - source=v.source, vec=v.vectors, + metadata=v.metadata, vec=v.vectors, name=self.schema.name, key_name=self.primary.name, key=row[self.primary.name] ) diff --git a/trustgraph-flow/trustgraph/graph_rag.py b/trustgraph-flow/trustgraph/graph_rag.py index 15acb609..f69ebeb7 100644 --- a/trustgraph-flow/trustgraph/graph_rag.py +++ b/trustgraph-flow/trustgraph/graph_rag.py @@ -18,6 +18,144 @@ from . 
schema import triples_response_queue LABEL="http://www.w3.org/2000/01/rdf-schema#label" DEFINITION="http://www.w3.org/2004/02/skos/core#definition" +class Query: + + def __init__(self, rag, user, collection, verbose): + self.rag = rag + self.user = user + self.collection = collection + self.verbose = verbose + + def get_vector(self, query): + + if self.verbose: + print("Compute embeddings...", flush=True) + + qembeds = self.rag.embeddings.request(query) + + if self.verbose: + print("Done.", flush=True) + + return qembeds + + def get_entities(self, query): + + vectors = self.get_vector(query) + + if self.verbose: + print("Get entities...", flush=True) + + entities = self.rag.ge_client.request( + user=self.user, collection=self.collection, + vectors=vectors, limit=self.rag.entity_limit, + ) + + entities = [ + e.value + for e in entities + ] + + if self.verbose: + print("Entities:", flush=True) + for ent in entities: + print(" ", ent, flush=True) + + return entities + + def maybe_label(self, e): + + if (self.user, self.collection, e) in self.rag.label_cache: + return self.rag.label_cache[(self.user, self.collection, e)] + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=e, p=LABEL, o=None, limit=1, + ) + + if len(res) == 0: + self.rag.label_cache[(self.user, self.collection, e)] = e + return e + + self.rag.label_cache[(self.user, self.collection, e)] = res[0].o.value + return self.rag.label_cache[(self.user, self.collection, e)] + + def get_subgraph(self, query): + + entities = self.get_entities(query) + + subgraph = set() + + if self.verbose: + print("Get subgraph...", flush=True) + + for e in entities: + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=e, p=None, o=None, + limit=self.rag.query_limit + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=None, p=e, o=None, + limit=self.rag.query_limit + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + )
+ + res = self.rag.triples_client.request( + user=self.user, collection=self.collection, + s=None, p=None, o=e, + limit=self.rag.query_limit, + ) + + for triple in res: + subgraph.add( + (triple.s.value, triple.p.value, triple.o.value) + ) + + subgraph = list(subgraph) + + subgraph = subgraph[0:self.rag.max_subgraph_size] + + if self.verbose: + print("Subgraph:", flush=True) + for edge in subgraph: + print(" ", str(edge), flush=True) + + if self.verbose: + print("Done.", flush=True) + + return subgraph + + def get_labelgraph(self, query): + + subgraph = self.get_subgraph(query) + + sg2 = [] + + for edge in subgraph: + + if edge[1] == LABEL: + continue + + s = self.maybe_label(edge[0]) + p = self.maybe_label(edge[1]) + o = self.maybe_label(edge[2]) + + sg2.append((s, p, o)) + + return sg2 + class GraphRag: def __init__( @@ -94,7 +232,7 @@ class GraphRag: self.label_cache = {} - self.lang = PromptClient( + self.prompt = PromptClient( pulsar_host=pulsar_host, input_queue=pr_request_queue, output_queue=pr_response_queue, @@ -104,144 +242,23 @@ class GraphRag: if self.verbose: print("Initialised", flush=True) - def get_vector(self, query): - - if self.verbose: - print("Compute embeddings...", flush=True) - - qembeds = self.embeddings.request(query) - - if self.verbose: - print("Done.", flush=True) - - return qembeds - - def get_entities(self, query): - - vectors = self.get_vector(query) - - if self.verbose: - print("Get entities...", flush=True) - - entities = self.ge_client.request( - vectors, self.entity_limit - ) - - entities = [ - e.value - for e in entities - ] - - if self.verbose: - print("Entities:", flush=True) - for ent in entities: - print(" ", ent, flush=True) - - return entities - - def maybe_label(self, e): - - if e in self.label_cache: - return self.label_cache[e] - - res = self.triples_client.request( - e, LABEL, None, limit=1 - ) - - if len(res) == 0: - self.label_cache[e] = e - return e - - self.label_cache[e] = res[0].o.value - return 
self.label_cache[e] - - def get_subgraph(self, query): - - entities = self.get_entities(query) - - subgraph = set() - - if self.verbose: - print("Get subgraph...", flush=True) - - for e in entities: - - res = self.triples_client.request( - e, None, None, - limit=self.query_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - res = self.triples_client.request( - None, e, None, - limit=self.query_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - res = self.triples_client.request( - None, None, e, - limit=self.query_limit - ) - - for triple in res: - subgraph.add( - (triple.s.value, triple.p.value, triple.o.value) - ) - - subgraph = list(subgraph) - - subgraph = subgraph[0:self.max_subgraph_size] - - if self.verbose: - print("Subgraph:", flush=True) - for edge in subgraph: - print(" ", str(edge), flush=True) - - if self.verbose: - print("Done.", flush=True) - - return subgraph - - def get_labelgraph(self, query): - - subgraph = self.get_subgraph(query) - - sg2 = [] - - for edge in subgraph: - - if edge[1] == LABEL: - continue - - s = self.maybe_label(edge[0]) - p = self.maybe_label(edge[1]) - o = self.maybe_label(edge[2]) - - sg2.append((s, p, o)) - - return sg2 - - def query(self, query): + def query(self, query, user="trustgraph", collection="default"): if self.verbose: print("Construct prompt...", flush=True) - kg = self.get_labelgraph(query) + q = Query( + rag=self, user=user, collection=collection, verbose=self.verbose + ) + + kg = q.get_labelgraph(query) if self.verbose: print("Invoke LLM...", flush=True) print(kg) print(query) - resp = self.lang.request_kg_prompt(query, kg) + resp = self.prompt.request_kg_prompt(query, kg) if self.verbose: print("Done", flush=True) diff --git a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py index 5cc41437..7bb5133a 100755 --- 
a/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py @@ -60,7 +60,10 @@ class Processor(ConsumerProducer): for vec in v.vectors: dim = len(vec) - collection = "doc_" + str(dim) + collection = ( + "d_" + v.user + "_" + v.collection + "_" + + str(dim) + ) search_result = self.client.query_points( collection_name=collection, diff --git a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py index e61a00a7..8991f9ea 100755 --- a/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py +++ b/trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py @@ -66,7 +66,10 @@ class Processor(ConsumerProducer): for vec in v.vectors: dim = len(vec) - collection = "triples_" + str(dim) + collection = ( + "t_" + v.user + "_" + v.collection + "_" + + str(dim) + ) search_result = self.client.query_points( collection_name=collection, diff --git a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py index 5e1e0e3e..4245784d 100755 --- a/trustgraph-flow/trustgraph/query/triples/cassandra/service.py +++ b/trustgraph-flow/trustgraph/query/triples/cassandra/service.py @@ -38,7 +38,8 @@ class Processor(ConsumerProducer): } ) - self.tg = TrustGraph([graph_host]) + self.graph_host = [graph_host] + self.table = None def create_value(self, ent): if ent.startswith("http://") or ent.startswith("https://"): @@ -52,6 +53,15 @@ class Processor(ConsumerProducer): v = msg.value() + table = (v.user, v.collection) + + if table != self.table: + self.tg = TrustGraph( + hosts=self.graph_host, + keyspace=v.user, table=v.collection, + ) + self.table = table + # Sender-produced ID id = msg.properties()["id"] diff --git a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py index 
87f8d24f..1219050e 100755 --- a/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py +++ b/trustgraph-flow/trustgraph/retrieval/graph_rag/rag.py @@ -108,7 +108,9 @@ class Processor(ConsumerProducer): print(f"Handling input {id}...", flush=True) - response = self.rag.query(v.query) + response = self.rag.query( + query=v.query, user=v.user, collection=v.collection + ) print("Send response...", flush=True) r = GraphRagResponse(response = response, error=None) diff --git a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py index f22ae74a..813c4f29 100644 --- a/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/doc_embeddings/qdrant/write.py @@ -37,7 +37,6 @@ class Processor(Consumer): ) self.last_collection = None - self.last_dim = None self.client = QdrantClient(url=store_uri) @@ -52,9 +51,12 @@ class Processor(Consumer): for vec in v.vectors: dim = len(vec) - collection = "doc_" + str(dim) + collection = ( + "d_" + v.metadata.user + "_" + v.metadata.collection + "_" + + str(dim) + ) - if dim != self.last_dim: + if collection != self.last_collection: if not self.client.collection_exists(collection): @@ -70,7 +72,6 @@ class Processor(Consumer): raise e self.last_collection = collection - self.last_dim = dim self.client.upsert( collection_name=collection, diff --git a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py index 95448750..e27c2516 100755 --- a/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py +++ b/trustgraph-flow/trustgraph/storage/graph_embeddings/qdrant/write.py @@ -37,7 +37,6 @@ class Processor(Consumer): ) self.last_collection = None - self.last_dim = None self.client = QdrantClient(url=store_uri) @@ -50,9 +49,12 @@ class Processor(Consumer): for vec in v.vectors: dim = len(vec) - collection = "triples_" + 
str(dim) + collection = ( + "t_" + v.metadata.user + "_" + v.metadata.collection + "_" + + str(dim) + ) - if dim != self.last_dim: + if collection != self.last_collection: if not self.client.collection_exists(collection): @@ -68,7 +70,6 @@ class Processor(Consumer): raise e self.last_collection = collection - self.last_dim = dim self.client.upsert( collection_name=collection, diff --git a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py index 84c002ff..e8074c80 100755 --- a/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py +++ b/trustgraph-flow/trustgraph/storage/triples/cassandra/write.py @@ -38,12 +38,31 @@ class Processor(Consumer): } ) - self.tg = TrustGraph([graph_host]) + self.graph_host = [graph_host] + self.table = None def handle(self, msg): v = msg.value() + table = (v.metadata.user, v.metadata.collection) + + if self.table is None or self.table != table: + + self.tg = None + + try: + self.tg = TrustGraph( + hosts=self.graph_host, + keyspace=v.metadata.user, table=v.metadata.collection, + ) + except Exception as e: + print("Exception", e, flush=True) + time.sleep(1) + raise e + + self.table = table + self.tg.insert( v.s.value, v.p.value,