Feature / collections (#96)

* Update schema defs for source -> metadata
* Migrate to use metadata part of schema, also add metadata to triples & vecs
* Add user/collection metadata to query
* Use user/collection in RAG
* Write and query working on triples
This commit is contained in:
cybermaggedon 2024-10-02 18:14:29 +01:00 committed by GitHub
parent 709221fa10
commit b0f4c58200
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 459 additions and 251 deletions

View file

@ -38,8 +38,12 @@ class GraphEmbeddingsClient(BaseClient):
output_schema=GraphEmbeddingsResponse,
)
def request(
        self, vectors, user="trustgraph", collection="default",
        limit=10, timeout=300
):
    """Send a graph-embeddings request and return the matching entities.

    All arguments are forwarded to ``self.call`` as keyword arguments;
    the ``entities`` attribute of whatever it returns is handed back.
    """
    reply = self.call(
        vectors=vectors,
        user=user,
        collection=collection,
        limit=limit,
        timeout=timeout,
    )
    return reply.entities

View file

@ -38,9 +38,12 @@ class GraphRagClient(BaseClient):
output_schema=GraphRagResponse,
)
def request(
        self, query, user="trustgraph", collection="default",
        timeout=500
):
    """Send a graph-RAG query and return the response text.

    All arguments are forwarded to ``self.call`` as keyword arguments;
    the ``response`` attribute of whatever it returns is handed back.
    """
    reply = self.call(
        query=query,
        user=user,
        collection=collection,
        timeout=timeout,
    )
    return reply.response

View file

@ -48,11 +48,18 @@ class TriplesQueryClient(BaseClient):
return Value(value=ent, is_uri=False)
def request(
        self,
        s, p, o,
        user="trustgraph", collection="default",
        limit=10, timeout=60,
):
    """Query triples matching the (s, p, o) pattern and return them.

    Each of ``s``, ``p`` and ``o`` is passed through ``self.create_value``
    before being sent; the remaining arguments are forwarded unchanged.
    Returns the ``triples`` attribute of the result of ``self.call``.
    """
    # Convert in s, p, o order, as the call arguments were built originally.
    subj, pred, obj = [self.create_value(term) for term in (s, p, o)]
    reply = self.call(
        s=subj,
        p=pred,
        o=obj,
        user=user,
        collection=collection,
        limit=limit,
        timeout=timeout,
    )
    return reply.triples

View file

@ -7,6 +7,7 @@ from . object import *
from . topic import *
from . graph import *
from . retrieval import *
from . metadata import *

View file

@ -2,17 +2,13 @@
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
from . topic import topic
from . types import Error
class Source(Record):
source = String()
id = String()
title = String()
from . metadata import Metadata
############################################################################
# PDF docs etc.
class Document(Record):
# Field as it stood before this change; the `metadata` line below supersedes it.
source = Source()
# Metadata record for the document (see the Metadata schema class).
metadata = Metadata()
# Document content as bytes; the PDF loader in this change set sends it
# base64-encoded.
data = Bytes()
document_ingest_queue = topic('document-load')
@ -22,7 +18,7 @@ document_ingest_queue = topic('document-load')
# Text documents / text from PDF
class TextDocument(Record):
# Field as it stood before this change; the `metadata` line below supersedes it.
source = Source()
# Metadata record for the text document (see the Metadata schema class).
metadata = Metadata()
# Text content as bytes; the chunkers in this change set decode it as UTF-8.
text = Bytes()
text_ingest_queue = topic('text-document-load')
@ -32,7 +28,7 @@ text_ingest_queue = topic('text-document-load')
# Chunks of text
class Chunk(Record):
# Field as it stood before this change; the `metadata` line below supersedes it.
source = Source()
# Metadata record for the chunk; the chunkers copy it from the parent
# document and substitute a per-chunk id.
metadata = Metadata()
# Chunk text as bytes; the chunkers in this change set encode it as UTF-8.
chunk = Bytes()
chunk_ingest_queue = topic('chunk-load')
@ -42,7 +38,7 @@ chunk_ingest_queue = topic('chunk-load')
# Chunk embeddings are an embeddings associated with a text chunk
class ChunkEmbeddings(Record):
source = Source()
metadata = Metadata()
vectors = Array(Array(Double()))
chunk = Bytes()

View file

@ -1,16 +1,16 @@
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
from . documents import Source
from . types import Error, Value
from . topic import topic
from . metadata import Metadata
############################################################################
# Graph embeddings are embeddings associated with a graph entity
class GraphEmbeddings(Record):
source = Source()
metadata = Metadata()
vectors = Array(Array(Double()))
entity = Value()
@ -23,6 +23,8 @@ graph_embeddings_store_queue = topic('graph-embeddings-store')
class GraphEmbeddingsRequest(Record):
# Query embedding vectors: a list of vectors, each a list of doubles.
vectors = Array(Array(Double()))
# Maximum number of results requested.
limit = Integer()
# User and collection the query is scoped to; the embeddings client in this
# change set defaults them to "trustgraph" / "default".
user = String()
collection = String()
class GraphEmbeddingsResponse(Record):
error = Error()
@ -40,7 +42,7 @@ graph_embeddings_response_queue = topic(
# Graph triples
class Triple(Record):
source = Source()
metadata = Metadata()
s = Value()
p = Value()
o = Value()
@ -56,6 +58,8 @@ class TriplesQueryRequest(Record):
p = Value()
o = Value()
limit = Integer()
user = String()
collection = String()
class TriplesQueryResponse(Record):
error = Error()

View file

@ -0,0 +1,10 @@
from pulsar.schema import Record, String
class Metadata(Record):
# Provenance and scoping information carried alongside documents, chunks,
# embeddings and triples.
# Origin of the content; the loaders in this change set use the file path.
source = String()
# Identifier; the loaders use the first 8 hex characters of the SHA-256 of
# the path, and downstream stages append "-p<N>" / "-c<N>" suffixes.
id = String()
# Title; the loaders in this change set use the file path.
title = String()
# User and collection the content belongs to; the command-line tools
# default these to 'trustgraph' and 'default'.
user = String()
collection = String()

View file

@ -2,7 +2,7 @@
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array
from pulsar.schema import Double, Map
from . documents import Source
from . metadata import Metadata
from . types import Value, RowSchema
from . topic import topic
@ -12,7 +12,7 @@ from . topic import topic
# object
class ObjectEmbeddings(Record):
source = Source()
metadata = Metadata()
vectors = Array(Array(Double()))
name = String()
key_name = String()
@ -25,7 +25,7 @@ object_embeddings_store_queue = topic('object-embeddings-store')
# Stores rows of information
class Rows(Record):
source = Source()
metadata = Metadata()
row_schema = RowSchema()
rows = Array(Map(String()))

View file

@ -9,6 +9,8 @@ from . types import Error, Value
class GraphRagQuery(Record):
# The query text to answer.
query = String()
# User and collection to run the query against; the graph-RAG processor
# passes both through to GraphRag.query.
user = String()
collection = String()
class GraphRagResponse(Record):
error = Error()
@ -27,6 +29,8 @@ graph_rag_response_queue = topic(
class DocumentRagQuery(Record):
# The query text to answer.
query = String()
# User and collection the query is scoped to.
# NOTE(review): the consumer of these fields is not visible here - confirm
# the document-RAG service reads them.
user = String()
collection = String()
class DocumentRagResponse(Record):
error = Error()

View file

@ -9,12 +9,17 @@ import os
from trustgraph.clients.triples_query_client import TriplesQueryClient
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
default_user = 'trustgraph'
default_collection = 'default'
def show_graph(pulsar):
def show_graph(pulsar, user, collection):
tq = TriplesQueryClient(pulsar_host=pulsar)
rows = tq.request(None, None, None, limit=10_000_000)
rows = tq.request(
user=user, collection=collection,
s=None, p=None, o=None, limit=10_000_000
)
for row in rows:
print(row.s.value, row.p.value, row.o.value)
@ -32,11 +37,26 @@ def main():
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'-u', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-c', '--collection',
default=default_collection,
help=f'Collection ID (default: {default_collection})'
)
args = parser.parse_args()
try:
show_graph(args.pulsar_host)
show_graph(
pulsar=args.pulsar_host, user=args.user,
collection=args.collection,
)
except Exception as e:

View file

@ -6,7 +6,7 @@ Loads a PDF document into TrustGraph processing.
import pulsar
from pulsar.schema import JsonSchema
from trustgraph.schema import Document, Source, document_ingest_queue
from trustgraph.schema import Document, document_ingest_queue, Metadata
import base64
import hashlib
import argparse
@ -15,12 +15,17 @@ import time
from trustgraph.log_level import LogLevel
default_user = 'trustgraph'
default_collection = 'default'
class Loader:
def __init__(
self,
pulsar_host,
output_queue,
user,
collection,
log_level,
):
@ -35,6 +40,9 @@ class Loader:
chunking_enabled=True,
)
self.user = user
self.collection = collection
def load(self, files):
for file in files:
@ -50,10 +58,12 @@ class Loader:
id = hashlib.sha256(path.encode("utf-8")).hexdigest()[0:8]
r = Document(
source=Source(
metadata=Metadata(
source=path,
title=path,
id=id,
user=self.user,
collection=self.collection,
),
data=base64.b64encode(data),
)
@ -90,6 +100,18 @@ def main():
help=f'Output queue (default: {default_output_queue})'
)
parser.add_argument(
'-u', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-c', '--collection',
default=default_collection,
help=f'Collection ID (default: {default_collection})'
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
@ -112,6 +134,8 @@ def main():
p = Loader(
pulsar_host=args.pulsar_host,
output_queue=args.output_queue,
user=args.user,
collection=args.collection,
log_level=args.log_level,
)

View file

@ -6,7 +6,7 @@ Loads a text document into TrustGraph processing.
import pulsar
from pulsar.schema import JsonSchema
from trustgraph.schema import TextDocument, Source, text_ingest_queue
from trustgraph.schema import TextDocument, text_ingest_queue, Metadata
import base64
import hashlib
import argparse
@ -15,12 +15,17 @@ import time
from trustgraph.log_level import LogLevel
default_user = 'trustgraph'
default_collection = 'default'
class Loader:
def __init__(
self,
pulsar_host,
output_queue,
user,
collection,
log_level,
):
@ -35,6 +40,9 @@ class Loader:
chunking_enabled=True,
)
self.user = user
self.collection = collection
def load(self, files):
for file in files:
@ -50,10 +58,12 @@ class Loader:
id = hashlib.sha256(path.encode("utf-8")).hexdigest()[0:8]
r = TextDocument(
source=Source(
metadata=Metadata(
source=path,
title=path,
id=id,
user=self.user,
collection=self.collection,
),
text=data,
)
@ -90,6 +100,18 @@ def main():
help=f'Output queue (default: {default_output_queue})'
)
parser.add_argument(
'-u', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-c', '--collection',
default=default_collection,
help=f'Collection ID (default: {default_collection})'
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
@ -112,6 +134,8 @@ def main():
p = Loader(
pulsar_host=args.pulsar_host,
output_queue=args.output_queue,
user=args.user,
collection=args.collection,
log_level=args.log_level,
)

View file

@ -9,17 +9,19 @@ import os
from trustgraph.clients.document_rag_client import DocumentRagClient
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
default_user = 'trustgraph'
default_collection = 'default'
def query(pulsar_host, query, user, collection):
    """Run a document-RAG query and print the response.

    Args:
        pulsar_host: Pulsar service URL handed to DocumentRagClient.
        query: Query text.
        user: User ID the query is scoped to.
        collection: Collection ID the query is scoped to.
    """
    # The parameter was renamed from `pulsar` to `pulsar_host`; the client
    # must be built from the new name or every call raises NameError.
    rag = DocumentRagClient(pulsar_host=pulsar_host)
    resp = rag.request(user=user, collection=collection, query=query)
    print(resp)
def main():
parser = argparse.ArgumentParser(
prog='graph-show',
prog='tg-query-document-rag',
description=__doc__,
)
@ -35,11 +37,28 @@ def main():
help=f'Query to execute',
)
parser.add_argument(
'-u', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-c', '--collection',
default=default_collection,
help=f'Collection ID (default: {default_collection})'
)
args = parser.parse_args()
try:
query(args.pulsar_host, args.query)
query(
pulsar_host=args.pulsar_host,
query=args.query,
user=args.user,
collection=args.collection,
)
except Exception as e:

View file

@ -9,17 +9,19 @@ import os
from trustgraph.clients.graph_rag_client import GraphRagClient
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
default_user = 'trustgraph'
default_collection = 'default'
def query(pulsar_host, query, user, collection):
    """Run a graph-RAG query and print the response.

    Builds a GraphRagClient for ``pulsar_host``, sends ``query`` scoped to
    ``user`` / ``collection``, and prints whatever the client returns.
    """
    client = GraphRagClient(pulsar_host=pulsar_host)
    answer = client.request(
        query=query,
        user=user,
        collection=collection,
    )
    print(answer)
def main():
parser = argparse.ArgumentParser(
prog='graph-show',
prog='tg-graph-query-rag',
description=__doc__,
)
@ -35,11 +37,28 @@ def main():
help=f'Query to execute',
)
parser.add_argument(
'-u', '--user',
default=default_user,
help=f'User ID (default: {default_user})'
)
parser.add_argument(
'-c', '--collection',
default=default_collection,
help=f'Collection ID (default: {default_collection})'
)
args = parser.parse_args()
try:
query(args.pulsar_host, args.query)
query(
pulsar_host=args.pulsar_host,
query=args.query,
user=args.user,
collection=args.collection,
)
except Exception as e:

View file

@ -7,7 +7,7 @@ as text as separate output objects.
from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Source
from ... schema import TextDocument, Chunk, Metadata
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
@ -55,7 +55,7 @@ class Processor(ConsumerProducer):
def handle(self, msg):
v = msg.value()
print(f"Chunking {v.source.id}...", flush=True)
print(f"Chunking {v.metadata.id}...", flush=True)
texts = self.text_splitter.create_documents(
[v.text.decode("utf-8")]
@ -63,13 +63,15 @@ class Processor(ConsumerProducer):
for ix, chunk in enumerate(texts):
id = v.source.id + "-c" + str(ix)
id = v.metadata.id + "-c" + str(ix)
r = Chunk(
source=Source(
source=v.source.source,
metadata=Metadata(
source=v.metadata.source,
id=id,
title=v.source.title
title=v.metadata.title,
user=v.metadata.user,
collection=v.metadata.collection,
),
chunk=chunk.page_content.encode("utf-8"),
)

View file

@ -7,7 +7,7 @@ as text as separate output objects.
from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Source
from ... schema import TextDocument, Chunk, Metadata
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
@ -54,7 +54,7 @@ class Processor(ConsumerProducer):
def handle(self, msg):
v = msg.value()
print(f"Chunking {v.source.id}...", flush=True)
print(f"Chunking {v.metadata.id}...", flush=True)
texts = self.text_splitter.create_documents(
[v.text.decode("utf-8")]
@ -62,13 +62,15 @@ class Processor(ConsumerProducer):
for ix, chunk in enumerate(texts):
id = v.source.id + "-c" + str(ix)
id = v.metadata.id + "-c" + str(ix)
r = Chunk(
source=Source(
source=v.source.source,
metadata=Metadata(
source=v.metadata.source,
id=id,
title=v.source.title
title=v.metadata.title,
user=v.metadata.user,
collection=v.metadata.collection,
),
chunk=chunk.page_content.encode("utf-8"),
)

View file

@ -8,7 +8,7 @@ import tempfile
import base64
from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Source
from ... schema import Document, TextDocument, Metadata
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
@ -45,7 +45,7 @@ class Processor(ConsumerProducer):
v = msg.value()
print(f"Decoding {v.source.id}...", flush=True)
print(f"Decoding {v.metadata.id}...", flush=True)
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
@ -59,12 +59,14 @@ class Processor(ConsumerProducer):
for ix, page in enumerate(pages):
id = v.source.id + "-p" + str(ix)
id = v.metadata.id + "-p" + str(ix)
r = TextDocument(
source=Source(
source=v.source.source,
title=v.source.title,
metadata=Metadata(
source=v.metadata.source,
title=v.metadata.title,
id=id,
user=v.metadata.user,
collection=v.metadata.collection,
),
text=page.page_content.encode("utf-8"),
)

View file

@ -4,10 +4,16 @@ from cassandra.auth import PlainTextAuthProvider
class TrustGraph:
def __init__(self, hosts=None):
def __init__(
self, hosts=None,
keyspace="trustgraph", table="default",
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.table = table
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
@ -16,26 +22,26 @@ class TrustGraph:
def clear(self):
    """Drop this instance's keyspace if it exists, then call ``init()``."""
    drop_statement = f"""
        drop keyspace if exists {self.keyspace};
    """
    self.session.execute(drop_statement)
    self.init()
def init(self):
self.session.execute("""
create keyspace if not exists trustgraph
with replication = {
self.session.execute(f"""
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
};
}};
""");
self.session.set_keyspace('trustgraph')
self.session.set_keyspace(self.keyspace)
self.session.execute("""
create table if not exists triples (
self.session.execute(f"""
create table if not exists {self.table} (
s text,
p text,
o text,
@ -43,66 +49,66 @@ class TrustGraph:
);
""");
self.session.execute("""
create index if not exists triples_p
ON triples (p);
self.session.execute(f"""
create index if not exists {self.table}_p
ON {self.table} (p);
""");
self.session.execute("""
create index if not exists triples_o
ON triples (o);
self.session.execute(f"""
create index if not exists {self.table}_o
ON {self.table} (o);
""");
def insert(self, s, p, o):
    """Insert one (s, p, o) row into this instance's table.

    The three values are passed as bound parameters; only the table name
    is interpolated into the statement text.
    """
    statement = f"insert into {self.table} (s, p, o) values (%s, %s, %s)"
    row = (s, p, o)
    self.session.execute(statement, row)
def get_all(self, limit=50):
    """Return up to ``limit`` (s, p, o) rows from this instance's table.

    Returns whatever ``session.execute`` returns for the select statement.
    """
    statement = f"select s, p, o from {self.table} limit {limit}"
    return self.session.execute(statement)
def get_s(self, s, limit=10):
    """Return up to ``limit`` (p, o) rows whose subject equals ``s``.

    ``s`` is passed as a bound parameter; returns the result of
    ``session.execute``.
    """
    statement = f"select p, o from {self.table} where s = %s limit {limit}"
    return self.session.execute(statement, (s,))
def get_p(self, p, limit=10):
    """Return up to ``limit`` (s, o) rows whose predicate equals ``p``.

    ``p`` is passed as a bound parameter; returns the result of
    ``session.execute``.
    """
    statement = f"select s, o from {self.table} where p = %s limit {limit}"
    return self.session.execute(statement, (p,))
def get_o(self, o, limit=10):
    """Return up to ``limit`` (s, p) rows whose object equals ``o``.

    ``o`` is passed as a bound parameter; returns the result of
    ``session.execute``.
    """
    statement = f"select s, p from {self.table} where o = %s limit {limit}"
    return self.session.execute(statement, (o,))
def get_sp(self, s, p, limit=10):
    """Return up to ``limit`` rows of ``o`` for the given subject and predicate.

    ``s`` and ``p`` are passed as bound parameters; returns the result of
    ``session.execute``.
    """
    statement = (
        f"select o from {self.table} where s = %s and p = %s limit {limit}"
    )
    return self.session.execute(statement, (s, p))
def get_po(self, p, o, limit=10):
    """Return up to ``limit`` rows of ``s`` for the given predicate and object.

    ``p`` and ``o`` are passed as bound parameters; returns the result of
    ``session.execute``. The query uses ALLOW FILTERING.
    """
    # CQL's SELECT grammar puts LIMIT before ALLOW FILTERING; with the
    # clauses the other way round the statement is rejected as a syntax
    # error.
    statement = (
        f"select s from {self.table} where p = %s and o = %s "
        f"limit {limit} allow filtering"
    )
    return self.session.execute(statement, (p, o))
def get_os(self, o, s, limit=10):
    """Return up to ``limit`` rows of ``p`` for the given object and subject.

    ``o`` and ``s`` are passed as bound parameters, in that order; returns
    the result of ``session.execute``.
    """
    statement = (
        f"select p from {self.table} where o = %s and s = %s limit {limit}"
    )
    return self.session.execute(statement, (o, s))
def get_spo(self, s, p, o, limit=10):
    """Look up one exact (s, p, o) triple.

    Selects ``s`` (aliased as ``x``) for rows matching all three values,
    which are passed as bound parameters; returns the result of
    ``session.execute``.
    """
    statement = f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}"""
    return self.session.execute(statement, (s, p, o))

View file

@ -50,15 +50,15 @@ class Processor(ConsumerProducer):
subscriber=module + "-emb",
)
def emit(self, metadata, chunk, vectors):
    """Publish a ChunkEmbeddings record for one chunk.

    Wraps ``metadata``, ``chunk`` and ``vectors`` in a ChunkEmbeddings
    record and sends it on ``self.producer``.
    """
    record = ChunkEmbeddings(
        metadata=metadata,
        chunk=chunk,
        vectors=vectors,
    )
    self.producer.send(record)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
@ -67,7 +67,7 @@ class Processor(ConsumerProducer):
vectors = self.embeddings.request(chunk)
self.emit(
source=v.source,
metadata=v.metadata,
chunk=chunk.encode("utf-8"),
vectors=vectors
)

View file

@ -7,7 +7,7 @@ get entity definitions which are output as graph edges.
import urllib.parse
import json
from .... schema import ChunkEmbeddings, Triple, Source, Value
from .... schema import ChunkEmbeddings, Triple, Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
@ -69,15 +69,15 @@ class Processor(ConsumerProducer):
return self.prompt.request_definitions(chunk)
def emit_edge(self, metadata, s, p, o):
    """Publish one triple, carrying ``metadata``, on ``self.producer``."""
    triple = Triple(
        metadata=metadata,
        s=s,
        p=p,
        o=o,
    )
    self.producer.send(triple)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
@ -101,7 +101,7 @@ class Processor(ConsumerProducer):
s_value = Value(value=str(s_uri), is_uri=True)
o_value = Value(value=str(o), is_uri=False)
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
self.emit_edge(v.metadata, s_value, DEFINITION_VALUE, o_value)
except Exception as e:
print("Exception: ", e, flush=True)

View file

@ -9,7 +9,8 @@ import urllib.parse
import os
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings
from .... schema import Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import graph_embeddings_store_queue
from .... schema import prompt_request_queue
@ -91,20 +92,20 @@ class Processor(ConsumerProducer):
return self.prompt.request_relationships(chunk)
def emit_edge(self, metadata, s, p, o):
    """Publish one triple, carrying ``metadata``, on ``self.producer``."""
    triple = Triple(
        metadata=metadata,
        s=s,
        p=p,
        o=o,
    )
    self.producer.send(triple)
def emit_vec(self, metadata, ent, vec):
    """Publish a GraphEmbeddings record for one entity.

    Wraps the entity ``ent`` and its vectors ``vec``, together with
    ``metadata``, and sends the record on ``self.vec_prod``.
    """
    record = GraphEmbeddings(
        metadata=metadata,
        entity=ent,
        vectors=vec,
    )
    self.vec_prod.send(record)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
@ -139,6 +140,7 @@ class Processor(ConsumerProducer):
o_value = Value(value=str(o), is_uri=False)
self.emit_edge(
v.metadata,
s_value,
p_value,
o_value
@ -146,6 +148,7 @@ class Processor(ConsumerProducer):
# Label for s
self.emit_edge(
v.metadata,
s_value,
RDF_LABEL_VALUE,
Value(value=str(s), is_uri=False)
@ -153,6 +156,7 @@ class Processor(ConsumerProducer):
# Label for p
self.emit_edge(
v.metadata,
p_value,
RDF_LABEL_VALUE,
Value(value=str(p), is_uri=False)
@ -161,15 +165,16 @@ class Processor(ConsumerProducer):
if rel.o_entity:
# Label for o
self.emit_edge(
v.metadata,
o_value,
RDF_LABEL_VALUE,
Value(value=str(o), is_uri=False)
)
self.emit_vec(s_value, v.vectors)
self.emit_vec(p_value, v.vectors)
self.emit_vec(v.metadata, s_value, v.vectors)
self.emit_vec(v.metadata, p_value, v.vectors)
if rel.o_entity:
self.emit_vec(o_value, v.vectors)
self.emit_vec(v.metadata, o_value, v.vectors)
except Exception as e:
print("Exception: ", e, flush=True)

View file

@ -7,7 +7,7 @@ get entity definitions which are output as graph edges.
import urllib.parse
import json
from .... schema import ChunkEmbeddings, Triple, Source, Value
from .... schema import ChunkEmbeddings, Triple, Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
@ -69,15 +69,15 @@ class Processor(ConsumerProducer):
return self.prompt.request_topics(chunk)
def emit_edge(self, metadata, s, p, o):
    """Publish one triple, carrying ``metadata``, on ``self.producer``."""
    triple = Triple(
        metadata=metadata,
        s=s,
        p=p,
        o=o,
    )
    self.producer.send(triple)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
@ -101,7 +101,7 @@ class Processor(ConsumerProducer):
s_value = Value(value=str(s_uri), is_uri=True)
o_value = Value(value=str(o), is_uri=False)
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
self.emit_edge(v. metadata, s_value, DEFINITION_VALUE, o_value)
except Exception as e:
print("Exception: ", e, flush=True)

View file

@ -8,7 +8,7 @@ import urllib.parse
import os
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Metadata
from .... schema import RowSchema, Field
from .... schema import chunk_embeddings_ingest_queue, rows_store_queue
from .... schema import object_embeddings_store_queue
@ -124,24 +124,24 @@ class Processor(ConsumerProducer):
def get_rows(self, chunk):
    """Ask the prompt client to extract rows for ``self.schema`` from ``chunk``."""
    extracted = self.prompt.request_rows(self.schema, chunk)
    return extracted
def emit_rows(self, metadata, rows):
    """Publish a Rows record on ``self.producer``.

    The record carries ``metadata``, this processor's ``row_schema`` and
    the extracted ``rows``.
    """
    record = Rows(
        metadata=metadata,
        row_schema=self.row_schema,
        rows=rows,
    )
    self.producer.send(record)
def emit_vec(self, metadata, name, vec, key_name, key):
    """Publish an ObjectEmbeddings record on ``self.vec_prod``.

    ``key`` is sent as the record's ``id`` field; the other arguments map
    to the record fields of the same name (``vec`` to ``vectors``).
    """
    record = ObjectEmbeddings(
        metadata=metadata,
        vectors=vec,
        name=name,
        key_name=key_name,
        id=key,
    )
    self.vec_prod.send(record)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
@ -150,13 +150,13 @@ class Processor(ConsumerProducer):
rows = self.get_rows(chunk)
self.emit_rows(
source=v.source,
metadata=v.metadata,
rows=rows
)
for row in rows:
self.emit_vec(
source=v.source, vec=v.vectors,
metadata=v.metadata, vec=v.vectors,
name=self.schema.name, key_name=self.primary.name,
key=row[self.primary.name]
)

View file

@ -18,6 +18,144 @@ from . schema import triples_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
    """One graph-RAG query, scoped to a single user and collection.

    The shared GraphRag object (``rag``) supplies the clients and limits;
    this object adds the ``user`` / ``collection`` scope that every
    embeddings and triples request is made under.
    """

    def __init__(self, rag, user, collection, verbose):
        """Store the GraphRag object, the query scope and the verbosity flag."""
        self.rag = rag
        self.user = user
        self.collection = collection
        self.verbose = verbose

    def get_vector(self, query):
        """Return the embeddings computed for the query text."""

        if self.verbose:
            print("Compute embeddings...", flush=True)

        qembeds = self.rag.embeddings.request(query)

        if self.verbose:
            print("Done.", flush=True)

        return qembeds

    def get_entities(self, query):
        """Return the values of the graph entities closest to the query."""

        vectors = self.get_vector(query)

        if self.verbose:
            print("Get entities...", flush=True)

        found = self.rag.ge_client.request(
            user=self.user, collection=self.collection,
            vectors=vectors, limit=self.rag.entity_limit,
        )

        entities = [e.value for e in found]

        if self.verbose:
            print("Entities:", flush=True)
            for ent in entities:
                print(" ", ent, flush=True)

        return entities

    def maybe_label(self, e):
        """Return the label recorded for ``e``, or ``e`` itself if none.

        Results are memoised in ``rag.label_cache``.
        """

        # The cache lives on the GraphRag object, which is shared by every
        # query. The key therefore includes the user and collection, so a
        # label found in one collection is never served to another.
        key = (self.user, self.collection, e)
        cache = self.rag.label_cache

        if key in cache:
            return cache[key]

        res = self.rag.triples_client.request(
            user=self.user, collection=self.collection,
            s=e, p=LABEL, o=None, limit=1,
        )

        # No label triple: fall back to the raw value (and remember that).
        cache[key] = res[0].o.value if res else e
        return cache[key]

    def _match(self, s=None, p=None, o=None):
        """Return (s, p, o) value tuples for triples matching the pattern."""
        res = self.rag.triples_client.request(
            user=self.user, collection=self.collection,
            s=s, p=p, o=o,
            limit=self.rag.query_limit,
        )
        return [
            (triple.s.value, triple.p.value, triple.o.value)
            for triple in res
        ]

    def get_subgraph(self, query):
        """Return the edges touching any entity related to the query.

        Each entity is looked up as subject, predicate and object; the
        de-duplicated edges are truncated to ``rag.max_subgraph_size``.
        """

        entities = self.get_entities(query)

        subgraph = set()

        if self.verbose:
            print("Get subgraph...", flush=True)

        for e in entities:
            subgraph.update(self._match(s=e))
            subgraph.update(self._match(p=e))
            subgraph.update(self._match(o=e))

        subgraph = list(subgraph)[0:self.rag.max_subgraph_size]

        if self.verbose:
            print("Subgraph:", flush=True)
            for edge in subgraph:
                print(" ", str(edge), flush=True)
            print("Done.", flush=True)

        return subgraph

    def get_labelgraph(self, query):
        """Return the query's subgraph with each term replaced by its label.

        Edges whose predicate is the label predicate itself are dropped.
        """

        subgraph = self.get_subgraph(query)

        return [
            (
                self.maybe_label(edge[0]),
                self.maybe_label(edge[1]),
                self.maybe_label(edge[2]),
            )
            for edge in subgraph
            if edge[1] != LABEL
        ]
class GraphRag:
def __init__(
@ -94,7 +232,7 @@ class GraphRag:
self.label_cache = {}
self.lang = PromptClient(
self.prompt = PromptClient(
pulsar_host=pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
@ -104,144 +242,23 @@ class GraphRag:
if self.verbose:
print("Initialised", flush=True)
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_entities(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
entities = self.ge_client.request(
vectors, self.entity_limit
)
entities = [
e.value
for e in entities
]
if self.verbose:
print("Entities:", flush=True)
for ent in entities:
print(" ", ent, flush=True)
return entities
def maybe_label(self, e):
if e in self.label_cache:
return self.label_cache[e]
res = self.triples_client.request(
e, LABEL, None, limit=1
)
if len(res) == 0:
self.label_cache[e] = e
return e
self.label_cache[e] = res[0].o.value
return self.label_cache[e]
def get_subgraph(self, query):
entities = self.get_entities(query)
subgraph = set()
if self.verbose:
print("Get subgraph...", flush=True)
for e in entities:
res = self.triples_client.request(
e, None, None,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.triples_client.request(
None, e, None,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.triples_client.request(
None, None, e,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
subgraph = list(subgraph)
subgraph = subgraph[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in subgraph:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return subgraph
def get_labelgraph(self, query):
subgraph = self.get_subgraph(query)
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = self.maybe_label(edge[0])
p = self.maybe_label(edge[1])
o = self.maybe_label(edge[2])
sg2.append((s, p, o))
return sg2
def query(self, query):
def query(self, query, user="trustgraph", collection="default"):
if self.verbose:
print("Construct prompt...", flush=True)
kg = self.get_labelgraph(query)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose
)
kg = q.get_labelgraph(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(kg)
print(query)
resp = self.lang.request_kg_prompt(query, kg)
resp = self.prompt.request_kg_prompt(query, kg)
if self.verbose:
print("Done", flush=True)

View file

@ -60,7 +60,10 @@ class Processor(ConsumerProducer):
for vec in v.vectors:
dim = len(vec)
collection = "doc_" + str(dim)
collection = (
"d_" + v.user + "_" + v.collection + "_" +
str(dim)
)
search_result = self.client.query_points(
collection_name=collection,

View file

@ -66,7 +66,10 @@ class Processor(ConsumerProducer):
for vec in v.vectors:
dim = len(vec)
collection = "triples_" + str(dim)
collection = (
"t_" + v.user + "_" + v.collection + "_" +
str(dim)
)
search_result = self.client.query_points(
collection_name=collection,

View file

@ -38,7 +38,8 @@ class Processor(ConsumerProducer):
}
)
self.tg = TrustGraph([graph_host])
self.graph_host = [graph_host]
self.table = None
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@ -52,6 +53,15 @@ class Processor(ConsumerProducer):
v = msg.value()
table = (v.user, v.collection)
if table != self.table:
self.tg = TrustGraph(
hosts=self.graph_host,
keyspace=v.user, table=v.collection,
)
self.table = table
# Sender-produced ID
id = msg.properties()["id"]

View file

@ -108,7 +108,9 @@ class Processor(ConsumerProducer):
print(f"Handling input {id}...", flush=True)
response = self.rag.query(v.query)
response = self.rag.query(
query=v.query, user=v.user, collection=v.collection
)
print("Send response...", flush=True)
r = GraphRagResponse(response = response, error=None)

View file

@ -37,7 +37,6 @@ class Processor(Consumer):
)
self.last_collection = None
self.last_dim = None
self.client = QdrantClient(url=store_uri)
@ -52,9 +51,12 @@ class Processor(Consumer):
for vec in v.vectors:
dim = len(vec)
collection = "doc_" + str(dim)
collection = (
"d_" + v.metadata.user + "_" + v.metadata.collection + "_" +
str(dim)
)
if dim != self.last_dim:
if collection != self.last_collection:
if not self.client.collection_exists(collection):
@ -70,7 +72,6 @@ class Processor(Consumer):
raise e
self.last_collection = collection
self.last_dim = dim
self.client.upsert(
collection_name=collection,

View file

@ -37,7 +37,6 @@ class Processor(Consumer):
)
self.last_collection = None
self.last_dim = None
self.client = QdrantClient(url=store_uri)
@ -50,9 +49,12 @@ class Processor(Consumer):
for vec in v.vectors:
dim = len(vec)
collection = "triples_" + str(dim)
collection = (
"t_" + v.metadata.user + "_" + v.metadata.collection + "_" +
str(dim)
)
if dim != self.last_dim:
if collection != self.last_collection:
if not self.client.collection_exists(collection):
@ -68,7 +70,6 @@ class Processor(Consumer):
raise e
self.last_collection = collection
self.last_dim = dim
self.client.upsert(
collection_name=collection,

View file

@ -38,12 +38,31 @@ class Processor(Consumer):
}
)
self.tg = TrustGraph([graph_host])
self.graph_host = [graph_host]
self.table = None
def handle(self, msg):
v = msg.value()
table = (v.metadata.user, v.metadata.collection)
if self.table is None or self.table != table:
self.tg = None
try:
self.tg = TrustGraph(
hosts=self.graph_host,
keyspace=v.metadata.user, table=v.metadata.collection,
)
except Exception as e:
print("Exception", e, flush=True)
time.sleep(1)
raise e
self.table = table
self.tg.insert(
v.s.value,
v.p.value,