Feature / collections (#96)

* Update schema defs for source -> metadata
* Migrate to use metadata part of schema, also add metadata to triples & vecs
* Add user/collection metadata to query
* Use user/collection in RAG
* Write and query working on triples
This commit is contained in:
cybermaggedon 2024-10-02 18:14:29 +01:00 committed by GitHub
parent 709221fa10
commit b0f4c58200
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 459 additions and 251 deletions

View file

@ -7,7 +7,7 @@ as text as separate output objects.
from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Source
from ... schema import TextDocument, Chunk, Metadata
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
@ -55,7 +55,7 @@ class Processor(ConsumerProducer):
def handle(self, msg):
v = msg.value()
print(f"Chunking {v.source.id}...", flush=True)
print(f"Chunking {v.metadata.id}...", flush=True)
texts = self.text_splitter.create_documents(
[v.text.decode("utf-8")]
@ -63,13 +63,15 @@ class Processor(ConsumerProducer):
for ix, chunk in enumerate(texts):
id = v.source.id + "-c" + str(ix)
id = v.metadata.id + "-c" + str(ix)
r = Chunk(
source=Source(
source=v.source.source,
metadata=Metadata(
source=v.metadata.source,
id=id,
title=v.source.title
title=v.metadata.title,
user=v.metadata.user,
collection=v.metadata.collection,
),
chunk=chunk.page_content.encode("utf-8"),
)

View file

@ -7,7 +7,7 @@ as text as separate output objects.
from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Source
from ... schema import TextDocument, Chunk, Metadata
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
@ -54,7 +54,7 @@ class Processor(ConsumerProducer):
def handle(self, msg):
v = msg.value()
print(f"Chunking {v.source.id}...", flush=True)
print(f"Chunking {v.metadata.id}...", flush=True)
texts = self.text_splitter.create_documents(
[v.text.decode("utf-8")]
@ -62,13 +62,15 @@ class Processor(ConsumerProducer):
for ix, chunk in enumerate(texts):
id = v.source.id + "-c" + str(ix)
id = v.metadata.id + "-c" + str(ix)
r = Chunk(
source=Source(
source=v.source.source,
metadata=Metadata(
source=v.metadata.source,
id=id,
title=v.source.title
title=v.metadata.title,
user=v.metadata.user,
collection=v.metadata.collection,
),
chunk=chunk.page_content.encode("utf-8"),
)

View file

@ -8,7 +8,7 @@ import tempfile
import base64
from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Source
from ... schema import Document, TextDocument, Metadata
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
@ -45,7 +45,7 @@ class Processor(ConsumerProducer):
v = msg.value()
print(f"Decoding {v.source.id}...", flush=True)
print(f"Decoding {v.metadata.id}...", flush=True)
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
@ -59,12 +59,14 @@ class Processor(ConsumerProducer):
for ix, page in enumerate(pages):
id = v.source.id + "-p" + str(ix)
id = v.metadata.id + "-p" + str(ix)
r = TextDocument(
source=Source(
source=v.source.source,
title=v.source.title,
metadata=Metadata(
source=v.metadata.source,
title=v.metadata.title,
id=id,
user=v.metadata.user,
collection=v.metadata.collection,
),
text=page.page_content.encode("utf-8"),
)

View file

@ -4,10 +4,16 @@ from cassandra.auth import PlainTextAuthProvider
class TrustGraph:
def __init__(self, hosts=None):
def __init__(
self, hosts=None,
keyspace="trustgraph", table="default",
):
if hosts is None:
hosts = ["localhost"]
self.keyspace = keyspace
self.table = table
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
@ -16,26 +22,26 @@ class TrustGraph:
def clear(self):
self.session.execute("""
drop keyspace if exists trustgraph;
self.session.execute(f"""
drop keyspace if exists {self.keyspace};
""");
self.init()
def init(self):
self.session.execute("""
create keyspace if not exists trustgraph
with replication = {
self.session.execute(f"""
create keyspace if not exists {self.keyspace}
with replication = {{
'class' : 'SimpleStrategy',
'replication_factor' : 1
};
}};
""");
self.session.set_keyspace('trustgraph')
self.session.set_keyspace(self.keyspace)
self.session.execute("""
create table if not exists triples (
self.session.execute(f"""
create table if not exists {self.table} (
s text,
p text,
o text,
@ -43,66 +49,66 @@ class TrustGraph:
);
""");
self.session.execute("""
create index if not exists triples_p
ON triples (p);
self.session.execute(f"""
create index if not exists {self.table}_p
ON {self.table} (p);
""");
self.session.execute("""
create index if not exists triples_o
ON triples (o);
self.session.execute(f"""
create index if not exists {self.table}_o
ON {self.table} (o);
""");
def insert(self, s, p, o):
self.session.execute(
"insert into triples (s, p, o) values (%s, %s, %s)",
f"insert into {self.table} (s, p, o) values (%s, %s, %s)",
(s, p, o)
)
def get_all(self, limit=50):
return self.session.execute(
f"select s, p, o from triples limit {limit}"
f"select s, p, o from {self.table} limit {limit}"
)
def get_s(self, s, limit=10):
return self.session.execute(
f"select p, o from triples where s = %s limit {limit}",
f"select p, o from {self.table} where s = %s limit {limit}",
(s,)
)
def get_p(self, p, limit=10):
return self.session.execute(
f"select s, o from triples where p = %s limit {limit}",
f"select s, o from {self.table} where p = %s limit {limit}",
(p,)
)
def get_o(self, o, limit=10):
return self.session.execute(
f"select s, p from triples where o = %s limit {limit}",
f"select s, p from {self.table} where o = %s limit {limit}",
(o,)
)
def get_sp(self, s, p, limit=10):
return self.session.execute(
f"select o from triples where s = %s and p = %s limit {limit}",
f"select o from {self.table} where s = %s and p = %s limit {limit}",
(s, p)
)
def get_po(self, p, o, limit=10):
return self.session.execute(
f"select s from triples where p = %s and o = %s allow filtering limit {limit}",
f"select s from {self.table} where p = %s and o = %s allow filtering limit {limit}",
(p, o)
)
def get_os(self, o, s, limit=10):
return self.session.execute(
f"select p from triples where o = %s and s = %s limit {limit}",
f"select p from {self.table} where o = %s and s = %s limit {limit}",
(o, s)
)
def get_spo(self, s, p, o, limit=10):
return self.session.execute(
f"""select s as x from triples where s = %s and p = %s and o = %s limit {limit}""",
f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""",
(s, p, o)
)

View file

@ -50,15 +50,15 @@ class Processor(ConsumerProducer):
subscriber=module + "-emb",
)
def emit(self, source, chunk, vectors):
def emit(self, metadata, chunk, vectors):
r = ChunkEmbeddings(source=source, chunk=chunk, vectors=vectors)
r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
self.producer.send(r)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
@ -67,7 +67,7 @@ class Processor(ConsumerProducer):
vectors = self.embeddings.request(chunk)
self.emit(
source=v.source,
metadata=v.metadata,
chunk=chunk.encode("utf-8"),
vectors=vectors
)

View file

@ -7,7 +7,7 @@ get entity definitions which are output as graph edges.
import urllib.parse
import json
from .... schema import ChunkEmbeddings, Triple, Source, Value
from .... schema import ChunkEmbeddings, Triple, Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
@ -69,15 +69,15 @@ class Processor(ConsumerProducer):
return self.prompt.request_definitions(chunk)
def emit_edge(self, s, p, o):
def emit_edge(self, metadata, s, p, o):
t = Triple(s=s, p=p, o=o)
t = Triple(metadata=metadata, s=s, p=p, o=o)
self.producer.send(t)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
@ -101,7 +101,7 @@ class Processor(ConsumerProducer):
s_value = Value(value=str(s_uri), is_uri=True)
o_value = Value(value=str(o), is_uri=False)
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
self.emit_edge(v.metadata, s_value, DEFINITION_VALUE, o_value)
except Exception as e:
print("Exception: ", e, flush=True)

View file

@ -9,7 +9,8 @@ import urllib.parse
import os
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings
from .... schema import Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import graph_embeddings_store_queue
from .... schema import prompt_request_queue
@ -91,20 +92,20 @@ class Processor(ConsumerProducer):
return self.prompt.request_relationships(chunk)
def emit_edge(self, s, p, o):
def emit_edge(self, metadata, s, p, o):
t = Triple(s=s, p=p, o=o)
t = Triple(metadata=metadata, s=s, p=p, o=o)
self.producer.send(t)
def emit_vec(self, ent, vec):
def emit_vec(self, metadata, ent, vec):
r = GraphEmbeddings(entity=ent, vectors=vec)
r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec)
self.vec_prod.send(r)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
@ -139,6 +140,7 @@ class Processor(ConsumerProducer):
o_value = Value(value=str(o), is_uri=False)
self.emit_edge(
v.metadata,
s_value,
p_value,
o_value
@ -146,6 +148,7 @@ class Processor(ConsumerProducer):
# Label for s
self.emit_edge(
v.metadata,
s_value,
RDF_LABEL_VALUE,
Value(value=str(s), is_uri=False)
@ -153,6 +156,7 @@ class Processor(ConsumerProducer):
# Label for p
self.emit_edge(
v.metadata,
p_value,
RDF_LABEL_VALUE,
Value(value=str(p), is_uri=False)
@ -161,15 +165,16 @@ class Processor(ConsumerProducer):
if rel.o_entity:
# Label for o
self.emit_edge(
v.metadata,
o_value,
RDF_LABEL_VALUE,
Value(value=str(o), is_uri=False)
)
self.emit_vec(s_value, v.vectors)
self.emit_vec(p_value, v.vectors)
self.emit_vec(v.metadata, s_value, v.vectors)
self.emit_vec(v.metadata, p_value, v.vectors)
if rel.o_entity:
self.emit_vec(o_value, v.vectors)
self.emit_vec(v.metadata, o_value, v.vectors)
except Exception as e:
print("Exception: ", e, flush=True)

View file

@ -7,7 +7,7 @@ get entity definitions which are output as graph edges.
import urllib.parse
import json
from .... schema import ChunkEmbeddings, Triple, Source, Value
from .... schema import ChunkEmbeddings, Triple, Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
@ -69,15 +69,15 @@ class Processor(ConsumerProducer):
return self.prompt.request_topics(chunk)
def emit_edge(self, s, p, o):
def emit_edge(self, metadata, s, p, o):
t = Triple(s=s, p=p, o=o)
t = Triple(metadata=metadata, s=s, p=p, o=o)
self.producer.send(t)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
@ -101,7 +101,7 @@ class Processor(ConsumerProducer):
s_value = Value(value=str(s_uri), is_uri=True)
o_value = Value(value=str(o), is_uri=False)
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
self.emit_edge(v. metadata, s_value, DEFINITION_VALUE, o_value)
except Exception as e:
print("Exception: ", e, flush=True)

View file

@ -8,7 +8,7 @@ import urllib.parse
import os
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Metadata
from .... schema import RowSchema, Field
from .... schema import chunk_embeddings_ingest_queue, rows_store_queue
from .... schema import object_embeddings_store_queue
@ -124,24 +124,24 @@ class Processor(ConsumerProducer):
def get_rows(self, chunk):
return self.prompt.request_rows(self.schema, chunk)
def emit_rows(self, source, rows):
def emit_rows(self, metadata, rows):
t = Rows(
source=source, row_schema=self.row_schema, rows=rows
metadata=metadata, row_schema=self.row_schema, rows=rows
)
self.producer.send(t)
def emit_vec(self, source, name, vec, key_name, key):
def emit_vec(self, metadata, name, vec, key_name, key):
r = ObjectEmbeddings(
source=source, vectors=vec, name=name, key_name=key_name, id=key
metadata=metadata, vectors=vec, name=name, key_name=key_name, id=key
)
self.vec_prod.send(r)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
print(f"Indexing {v.metadata.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
@ -150,13 +150,13 @@ class Processor(ConsumerProducer):
rows = self.get_rows(chunk)
self.emit_rows(
source=v.source,
metadata=v.metadata,
rows=rows
)
for row in rows:
self.emit_vec(
source=v.source, vec=v.vectors,
metadata=v.metadata, vec=v.vectors,
name=self.schema.name, key_name=self.primary.name,
key=row[self.primary.name]
)

View file

@ -18,6 +18,144 @@ from . schema import triples_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class Query:
    """A single graph-RAG retrieval, scoped to one user and collection.

    The object holds no clients of its own; it drives the ones exposed by
    the ``rag`` object it is given (``embeddings``, ``ge_client``,
    ``triples_client``) and reads that object's limits (``entity_limit``,
    ``query_limit``, ``max_subgraph_size``) and shared ``label_cache``.
    """

    def __init__(self, rag, user, collection, verbose):
        """Bind the query to its RAG context.

        rag: object providing the clients, limits and label cache.
        user, collection: passed on every graph-embedding / triples request.
        verbose: when true, progress is printed to stdout.
        """
        self.rag = rag
        self.user = user
        self.collection = collection
        self.verbose = verbose

    def get_vector(self, query):
        """Return the embeddings computed for the query text."""

        if self.verbose:
            print("Compute embeddings...", flush=True)

        qembeds = self.rag.embeddings.request(query)

        if self.verbose:
            print("Done.", flush=True)

        return qembeds

    def get_entities(self, query):
        """Return the entity values that match the query's embeddings."""

        vectors = self.get_vector(query)

        if self.verbose:
            print("Get entities...", flush=True)

        matches = self.rag.ge_client.request(
            user=self.user, collection=self.collection,
            vectors=vectors, limit=self.rag.entity_limit,
        )

        entities = [m.value for m in matches]

        if self.verbose:
            print("Entities:", flush=True)
            for ent in entities:
                print("  ", ent, flush=True)

        return entities

    def maybe_label(self, e):
        """Return the label stored for ``e``, or ``e`` itself if none exists.

        Results (including misses) are memoised in ``rag.label_cache``.
        """

        # The cache is shared by every Query created from the same rag
        # object, while the lookup below is scoped by user and collection.
        # Keying on the entity alone would hand a label resolved for one
        # user/collection to all the others, so the scope is part of the key.
        cache = self.rag.label_cache
        cache_key = (self.user, self.collection, e)

        if cache_key in cache:
            return cache[cache_key]

        res = self.rag.triples_client.request(
            user=self.user, collection=self.collection,
            s=e, p=LABEL, o=None, limit=1,
        )

        # No label triple: fall back to the raw value (and remember that).
        label = res[0].o.value if res else e
        cache[cache_key] = label

        return label

    def _edges(self, s=None, p=None, o=None):
        """Query triples matching the given pattern, as (s, p, o) tuples."""
        res = self.rag.triples_client.request(
            user=self.user, collection=self.collection,
            s=s, p=p, o=o,
            limit=self.rag.query_limit,
        )
        return [
            (triple.s.value, triple.p.value, triple.o.value)
            for triple in res
        ]

    def get_subgraph(self, query):
        """Return the edges touching any entity matched by the query.

        Each entity is looked up as subject, predicate and object in turn.
        Duplicate edges are collapsed and the result is cut to at most
        ``rag.max_subgraph_size`` edges; the edges come out of a set, so
        their order (and which ones survive the cut) is not defined.
        """

        entities = self.get_entities(query)

        subgraph = set()

        if self.verbose:
            print("Get subgraph...", flush=True)

        for e in entities:
            subgraph.update(self._edges(s=e))
            subgraph.update(self._edges(p=e))
            subgraph.update(self._edges(o=e))

        subgraph = list(subgraph)[0:self.rag.max_subgraph_size]

        if self.verbose:
            print("Subgraph:", flush=True)
            for edge in subgraph:
                print("  ", str(edge), flush=True)

        if self.verbose:
            print("Done.", flush=True)

        return subgraph

    def get_labelgraph(self, query):
        """Return the query's subgraph with each term replaced by its label.

        Edges whose predicate is the label property itself are dropped.
        """

        subgraph = self.get_subgraph(query)

        return [
            (
                self.maybe_label(s),
                self.maybe_label(p),
                self.maybe_label(o),
            )
            for s, p, o in subgraph
            if p != LABEL
        ]
class GraphRag:
def __init__(
@ -94,7 +232,7 @@ class GraphRag:
self.label_cache = {}
self.lang = PromptClient(
self.prompt = PromptClient(
pulsar_host=pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
@ -104,144 +242,23 @@ class GraphRag:
if self.verbose:
print("Initialised", flush=True)
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_entities(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
entities = self.ge_client.request(
vectors, self.entity_limit
)
entities = [
e.value
for e in entities
]
if self.verbose:
print("Entities:", flush=True)
for ent in entities:
print(" ", ent, flush=True)
return entities
def maybe_label(self, e):
if e in self.label_cache:
return self.label_cache[e]
res = self.triples_client.request(
e, LABEL, None, limit=1
)
if len(res) == 0:
self.label_cache[e] = e
return e
self.label_cache[e] = res[0].o.value
return self.label_cache[e]
def get_subgraph(self, query):
entities = self.get_entities(query)
subgraph = set()
if self.verbose:
print("Get subgraph...", flush=True)
for e in entities:
res = self.triples_client.request(
e, None, None,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.triples_client.request(
None, e, None,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.triples_client.request(
None, None, e,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
subgraph = list(subgraph)
subgraph = subgraph[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in subgraph:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return subgraph
def get_labelgraph(self, query):
subgraph = self.get_subgraph(query)
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = self.maybe_label(edge[0])
p = self.maybe_label(edge[1])
o = self.maybe_label(edge[2])
sg2.append((s, p, o))
return sg2
def query(self, query):
def query(self, query, user="trustgraph", collection="default"):
if self.verbose:
print("Construct prompt...", flush=True)
kg = self.get_labelgraph(query)
q = Query(
rag=self, user=user, collection=collection, verbose=self.verbose
)
kg = q.get_labelgraph(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(kg)
print(query)
resp = self.lang.request_kg_prompt(query, kg)
resp = self.prompt.request_kg_prompt(query, kg)
if self.verbose:
print("Done", flush=True)

View file

@ -60,7 +60,10 @@ class Processor(ConsumerProducer):
for vec in v.vectors:
dim = len(vec)
collection = "doc_" + str(dim)
collection = (
"d_" + v.user + "_" + v.collection + "_" +
str(dim)
)
search_result = self.client.query_points(
collection_name=collection,

View file

@ -66,7 +66,10 @@ class Processor(ConsumerProducer):
for vec in v.vectors:
dim = len(vec)
collection = "triples_" + str(dim)
collection = (
"t_" + v.user + "_" + v.collection + "_" +
str(dim)
)
search_result = self.client.query_points(
collection_name=collection,

View file

@ -38,7 +38,8 @@ class Processor(ConsumerProducer):
}
)
self.tg = TrustGraph([graph_host])
self.graph_host = [graph_host]
self.table = None
def create_value(self, ent):
if ent.startswith("http://") or ent.startswith("https://"):
@ -52,6 +53,15 @@ class Processor(ConsumerProducer):
v = msg.value()
table = (v.user, v.collection)
if table != self.table:
self.tg = TrustGraph(
hosts=self.graph_host,
keyspace=v.user, table=v.collection,
)
self.table = table
# Sender-produced ID
id = msg.properties()["id"]

View file

@ -108,7 +108,9 @@ class Processor(ConsumerProducer):
print(f"Handling input {id}...", flush=True)
response = self.rag.query(v.query)
response = self.rag.query(
query=v.query, user=v.user, collection=v.collection
)
print("Send response...", flush=True)
r = GraphRagResponse(response = response, error=None)

View file

@ -37,7 +37,6 @@ class Processor(Consumer):
)
self.last_collection = None
self.last_dim = None
self.client = QdrantClient(url=store_uri)
@ -52,9 +51,12 @@ class Processor(Consumer):
for vec in v.vectors:
dim = len(vec)
collection = "doc_" + str(dim)
collection = (
"d_" + v.metadata.user + "_" + v.metadata.collection + "_" +
str(dim)
)
if dim != self.last_dim:
if collection != self.last_collection:
if not self.client.collection_exists(collection):
@ -70,7 +72,6 @@ class Processor(Consumer):
raise e
self.last_collection = collection
self.last_dim = dim
self.client.upsert(
collection_name=collection,

View file

@ -37,7 +37,6 @@ class Processor(Consumer):
)
self.last_collection = None
self.last_dim = None
self.client = QdrantClient(url=store_uri)
@ -50,9 +49,12 @@ class Processor(Consumer):
for vec in v.vectors:
dim = len(vec)
collection = "triples_" + str(dim)
collection = (
"t_" + v.metadata.user + "_" + v.metadata.collection + "_" +
str(dim)
)
if dim != self.last_dim:
if collection != self.last_collection:
if not self.client.collection_exists(collection):
@ -68,7 +70,6 @@ class Processor(Consumer):
raise e
self.last_collection = collection
self.last_dim = dim
self.client.upsert(
collection_name=collection,

View file

@ -38,12 +38,31 @@ class Processor(Consumer):
}
)
self.tg = TrustGraph([graph_host])
self.graph_host = [graph_host]
self.table = None
def handle(self, msg):
v = msg.value()
table = (v.metadata.user, v.metadata.collection)
if self.table is None or self.table != table:
self.tg = None
try:
self.tg = TrustGraph(
hosts=self.graph_host,
keyspace=v.metadata.user, table=v.metadata.collection,
)
except Exception as e:
print("Exception", e, flush=True)
time.sleep(1)
raise e
self.table = table
self.tg.insert(
v.s.value,
v.p.value,