mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-04 12:25:13 +02:00
Feature / collections (#96)
* Update schema defs for source -> metadata * Migrate to use metadata part of schema, also add metadata to triples & vecs * Add user/collection metadata to query * Use user/collection in RAG * Write and query working on triples
This commit is contained in:
parent
709221fa10
commit
b0f4c58200
31 changed files with 459 additions and 251 deletions
|
|
@ -7,7 +7,7 @@ as text as separate output objects.
|
|||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk, Source
|
||||
from ... schema import TextDocument, Chunk, Metadata
|
||||
from ... schema import text_ingest_queue, chunk_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
|
@ -55,7 +55,7 @@ class Processor(ConsumerProducer):
|
|||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Chunking {v.source.id}...", flush=True)
|
||||
print(f"Chunking {v.metadata.id}...", flush=True)
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
|
|
@ -63,13 +63,15 @@ class Processor(ConsumerProducer):
|
|||
|
||||
for ix, chunk in enumerate(texts):
|
||||
|
||||
id = v.source.id + "-c" + str(ix)
|
||||
id = v.metadata.id + "-c" + str(ix)
|
||||
|
||||
r = Chunk(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
metadata=Metadata(
|
||||
source=v.metadata.source,
|
||||
id=id,
|
||||
title=v.source.title
|
||||
title=v.metadata.title,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ as text as separate output objects.
|
|||
from langchain_text_splitters import TokenTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk, Source
|
||||
from ... schema import TextDocument, Chunk, Metadata
|
||||
from ... schema import text_ingest_queue, chunk_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
|
@ -54,7 +54,7 @@ class Processor(ConsumerProducer):
|
|||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Chunking {v.source.id}...", flush=True)
|
||||
print(f"Chunking {v.metadata.id}...", flush=True)
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
|
|
@ -62,13 +62,15 @@ class Processor(ConsumerProducer):
|
|||
|
||||
for ix, chunk in enumerate(texts):
|
||||
|
||||
id = v.source.id + "-c" + str(ix)
|
||||
id = v.metadata.id + "-c" + str(ix)
|
||||
|
||||
r = Chunk(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
metadata=Metadata(
|
||||
source=v.metadata.source,
|
||||
id=id,
|
||||
title=v.source.title
|
||||
title=v.metadata.title,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import tempfile
|
|||
import base64
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
|
||||
from ... schema import Document, TextDocument, Source
|
||||
from ... schema import Document, TextDocument, Metadata
|
||||
from ... schema import document_ingest_queue, text_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
|
@ -45,7 +45,7 @@ class Processor(ConsumerProducer):
|
|||
|
||||
v = msg.value()
|
||||
|
||||
print(f"Decoding {v.source.id}...", flush=True)
|
||||
print(f"Decoding {v.metadata.id}...", flush=True)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
|
||||
|
||||
|
|
@ -59,12 +59,14 @@ class Processor(ConsumerProducer):
|
|||
|
||||
for ix, page in enumerate(pages):
|
||||
|
||||
id = v.source.id + "-p" + str(ix)
|
||||
id = v.metadata.id + "-p" + str(ix)
|
||||
r = TextDocument(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
title=v.source.title,
|
||||
metadata=Metadata(
|
||||
source=v.metadata.source,
|
||||
title=v.metadata.title,
|
||||
id=id,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
text=page.page_content.encode("utf-8"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,10 +4,16 @@ from cassandra.auth import PlainTextAuthProvider
|
|||
|
||||
class TrustGraph:
|
||||
|
||||
def __init__(self, hosts=None):
|
||||
def __init__(
|
||||
self, hosts=None,
|
||||
keyspace="trustgraph", table="default",
|
||||
):
|
||||
|
||||
if hosts is None:
|
||||
hosts = ["localhost"]
|
||||
|
||||
self.keyspace = keyspace
|
||||
self.table = table
|
||||
|
||||
self.cluster = Cluster(hosts)
|
||||
self.session = self.cluster.connect()
|
||||
|
|
@ -16,26 +22,26 @@ class TrustGraph:
|
|||
|
||||
def clear(self):
|
||||
|
||||
self.session.execute("""
|
||||
drop keyspace if exists trustgraph;
|
||||
self.session.execute(f"""
|
||||
drop keyspace if exists {self.keyspace};
|
||||
""");
|
||||
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
|
||||
self.session.execute("""
|
||||
create keyspace if not exists trustgraph
|
||||
with replication = {
|
||||
self.session.execute(f"""
|
||||
create keyspace if not exists {self.keyspace}
|
||||
with replication = {{
|
||||
'class' : 'SimpleStrategy',
|
||||
'replication_factor' : 1
|
||||
};
|
||||
}};
|
||||
""");
|
||||
|
||||
self.session.set_keyspace('trustgraph')
|
||||
self.session.set_keyspace(self.keyspace)
|
||||
|
||||
self.session.execute("""
|
||||
create table if not exists triples (
|
||||
self.session.execute(f"""
|
||||
create table if not exists {self.table} (
|
||||
s text,
|
||||
p text,
|
||||
o text,
|
||||
|
|
@ -43,66 +49,66 @@ class TrustGraph:
|
|||
);
|
||||
""");
|
||||
|
||||
self.session.execute("""
|
||||
create index if not exists triples_p
|
||||
ON triples (p);
|
||||
self.session.execute(f"""
|
||||
create index if not exists {self.table}_p
|
||||
ON {self.table} (p);
|
||||
""");
|
||||
|
||||
self.session.execute("""
|
||||
create index if not exists triples_o
|
||||
ON triples (o);
|
||||
self.session.execute(f"""
|
||||
create index if not exists {self.table}_o
|
||||
ON {self.table} (o);
|
||||
""");
|
||||
|
||||
def insert(self, s, p, o):
|
||||
|
||||
self.session.execute(
|
||||
"insert into triples (s, p, o) values (%s, %s, %s)",
|
||||
f"insert into {self.table} (s, p, o) values (%s, %s, %s)",
|
||||
(s, p, o)
|
||||
)
|
||||
|
||||
def get_all(self, limit=50):
|
||||
return self.session.execute(
|
||||
f"select s, p, o from triples limit {limit}"
|
||||
f"select s, p, o from {self.table} limit {limit}"
|
||||
)
|
||||
|
||||
def get_s(self, s, limit=10):
|
||||
return self.session.execute(
|
||||
f"select p, o from triples where s = %s limit {limit}",
|
||||
f"select p, o from {self.table} where s = %s limit {limit}",
|
||||
(s,)
|
||||
)
|
||||
|
||||
def get_p(self, p, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s, o from triples where p = %s limit {limit}",
|
||||
f"select s, o from {self.table} where p = %s limit {limit}",
|
||||
(p,)
|
||||
)
|
||||
|
||||
def get_o(self, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s, p from triples where o = %s limit {limit}",
|
||||
f"select s, p from {self.table} where o = %s limit {limit}",
|
||||
(o,)
|
||||
)
|
||||
|
||||
def get_sp(self, s, p, limit=10):
|
||||
return self.session.execute(
|
||||
f"select o from triples where s = %s and p = %s limit {limit}",
|
||||
f"select o from {self.table} where s = %s and p = %s limit {limit}",
|
||||
(s, p)
|
||||
)
|
||||
|
||||
def get_po(self, p, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s from triples where p = %s and o = %s allow filtering limit {limit}",
|
||||
f"select s from {self.table} where p = %s and o = %s allow filtering limit {limit}",
|
||||
(p, o)
|
||||
)
|
||||
|
||||
def get_os(self, o, s, limit=10):
|
||||
return self.session.execute(
|
||||
f"select p from triples where o = %s and s = %s limit {limit}",
|
||||
f"select p from {self.table} where o = %s and s = %s limit {limit}",
|
||||
(o, s)
|
||||
)
|
||||
|
||||
def get_spo(self, s, p, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"""select s as x from triples where s = %s and p = %s and o = %s limit {limit}""",
|
||||
f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""",
|
||||
(s, p, o)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -50,15 +50,15 @@ class Processor(ConsumerProducer):
|
|||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
def emit(self, source, chunk, vectors):
|
||||
def emit(self, metadata, chunk, vectors):
|
||||
|
||||
r = ChunkEmbeddings(source=source, chunk=chunk, vectors=vectors)
|
||||
r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
|
||||
self.producer.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ class Processor(ConsumerProducer):
|
|||
vectors = self.embeddings.request(chunk)
|
||||
|
||||
self.emit(
|
||||
source=v.source,
|
||||
metadata=v.metadata,
|
||||
chunk=chunk.encode("utf-8"),
|
||||
vectors=vectors
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ get entity definitions which are output as graph edges.
|
|||
import urllib.parse
|
||||
import json
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Source, Value
|
||||
from .... schema import ChunkEmbeddings, Triple, Metadata, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
|
|
@ -69,15 +69,15 @@ class Processor(ConsumerProducer):
|
|||
|
||||
return self.prompt.request_definitions(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
def emit_edge(self, metadata, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
t = Triple(metadata=metadata, s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ class Processor(ConsumerProducer):
|
|||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
|
||||
self.emit_edge(v.metadata, s_value, DEFINITION_VALUE, o_value)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ import urllib.parse
|
|||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
|
||||
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings
|
||||
from .... schema import Metadata, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import graph_embeddings_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
|
|
@ -91,20 +92,20 @@ class Processor(ConsumerProducer):
|
|||
|
||||
return self.prompt.request_relationships(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
def emit_edge(self, metadata, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
t = Triple(metadata=metadata, s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_vec(self, ent, vec):
|
||||
def emit_vec(self, metadata, ent, vec):
|
||||
|
||||
r = GraphEmbeddings(entity=ent, vectors=vec)
|
||||
r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec)
|
||||
self.vec_prod.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
|
|
@ -139,6 +140,7 @@ class Processor(ConsumerProducer):
|
|||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(
|
||||
v.metadata,
|
||||
s_value,
|
||||
p_value,
|
||||
o_value
|
||||
|
|
@ -146,6 +148,7 @@ class Processor(ConsumerProducer):
|
|||
|
||||
# Label for s
|
||||
self.emit_edge(
|
||||
v.metadata,
|
||||
s_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(s), is_uri=False)
|
||||
|
|
@ -153,6 +156,7 @@ class Processor(ConsumerProducer):
|
|||
|
||||
# Label for p
|
||||
self.emit_edge(
|
||||
v.metadata,
|
||||
p_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(p), is_uri=False)
|
||||
|
|
@ -161,15 +165,16 @@ class Processor(ConsumerProducer):
|
|||
if rel.o_entity:
|
||||
# Label for o
|
||||
self.emit_edge(
|
||||
v.metadata,
|
||||
o_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(o), is_uri=False)
|
||||
)
|
||||
|
||||
self.emit_vec(s_value, v.vectors)
|
||||
self.emit_vec(p_value, v.vectors)
|
||||
self.emit_vec(v.metadata, s_value, v.vectors)
|
||||
self.emit_vec(v.metadata, p_value, v.vectors)
|
||||
if rel.o_entity:
|
||||
self.emit_vec(o_value, v.vectors)
|
||||
self.emit_vec(v.metadata, o_value, v.vectors)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ get entity definitions which are output as graph edges.
|
|||
import urllib.parse
|
||||
import json
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Source, Value
|
||||
from .... schema import ChunkEmbeddings, Triple, Metadata, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
|
|
@ -69,15 +69,15 @@ class Processor(ConsumerProducer):
|
|||
|
||||
return self.prompt.request_topics(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
def emit_edge(self, metadata, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
t = Triple(metadata=metadata, s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ class Processor(ConsumerProducer):
|
|||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
|
||||
self.emit_edge(v. metadata, s_value, DEFINITION_VALUE, o_value)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import urllib.parse
|
|||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source
|
||||
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Metadata
|
||||
from .... schema import RowSchema, Field
|
||||
from .... schema import chunk_embeddings_ingest_queue, rows_store_queue
|
||||
from .... schema import object_embeddings_store_queue
|
||||
|
|
@ -124,24 +124,24 @@ class Processor(ConsumerProducer):
|
|||
def get_rows(self, chunk):
|
||||
return self.prompt.request_rows(self.schema, chunk)
|
||||
|
||||
def emit_rows(self, source, rows):
|
||||
def emit_rows(self, metadata, rows):
|
||||
|
||||
t = Rows(
|
||||
source=source, row_schema=self.row_schema, rows=rows
|
||||
metadata=metadata, row_schema=self.row_schema, rows=rows
|
||||
)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_vec(self, source, name, vec, key_name, key):
|
||||
def emit_vec(self, metadata, name, vec, key_name, key):
|
||||
|
||||
r = ObjectEmbeddings(
|
||||
source=source, vectors=vec, name=name, key_name=key_name, id=key
|
||||
metadata=metadata, vectors=vec, name=name, key_name=key_name, id=key
|
||||
)
|
||||
self.vec_prod.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
|
|
@ -150,13 +150,13 @@ class Processor(ConsumerProducer):
|
|||
rows = self.get_rows(chunk)
|
||||
|
||||
self.emit_rows(
|
||||
source=v.source,
|
||||
metadata=v.metadata,
|
||||
rows=rows
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
self.emit_vec(
|
||||
source=v.source, vec=v.vectors,
|
||||
metadata=v.metadata, vec=v.vectors,
|
||||
name=self.schema.name, key_name=self.primary.name,
|
||||
key=row[self.primary.name]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,144 @@ from . schema import triples_response_queue
|
|||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
||||
|
||||
class Query:
|
||||
|
||||
def __init__(self, rag, user, collection, verbose):
|
||||
self.rag = rag
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
self.verbose = verbose
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = self.rag.embeddings.request(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
def get_entities(self, query):
|
||||
|
||||
vectors = self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get entities...", flush=True)
|
||||
|
||||
entities = self.rag.ge_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
vectors=vectors, limit=self.rag.entity_limit,
|
||||
)
|
||||
|
||||
entities = [
|
||||
e.value
|
||||
for e in entities
|
||||
]
|
||||
|
||||
if self.verbose:
|
||||
print("Entities:", flush=True)
|
||||
for ent in entities:
|
||||
print(" ", ent, flush=True)
|
||||
|
||||
return entities
|
||||
|
||||
def maybe_label(self, e):
|
||||
|
||||
if e in self.rag.label_cache:
|
||||
return self.rag.label_cache[e]
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=e, p=LABEL, o=None, limit=1,
|
||||
)
|
||||
|
||||
if len(res) == 0:
|
||||
self.rag.label_cache[e] = e
|
||||
return e
|
||||
|
||||
self.rag.label_cache[e] = res[0].o.value
|
||||
return self.rag.label_cache[e]
|
||||
|
||||
def get_subgraph(self, query):
|
||||
|
||||
entities = self.get_entities(query)
|
||||
|
||||
subgraph = set()
|
||||
|
||||
if self.verbose:
|
||||
print("Get subgraph...", flush=True)
|
||||
|
||||
for e in entities:
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=e, p=None, o=None,
|
||||
limit=self.rag.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=None, p=e, o=None,
|
||||
limit=self.rag.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=None, p=None, o=e,
|
||||
limit=self.rag.query_limit,
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
subgraph = list(subgraph)
|
||||
|
||||
subgraph = subgraph[0:self.rag.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
print("Subgraph:", flush=True)
|
||||
for edge in subgraph:
|
||||
print(" ", str(edge), flush=True)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return subgraph
|
||||
|
||||
def get_labelgraph(self, query):
|
||||
|
||||
subgraph = self.get_subgraph(query)
|
||||
|
||||
sg2 = []
|
||||
|
||||
for edge in subgraph:
|
||||
|
||||
if edge[1] == LABEL:
|
||||
continue
|
||||
|
||||
s = self.maybe_label(edge[0])
|
||||
p = self.maybe_label(edge[1])
|
||||
o = self.maybe_label(edge[2])
|
||||
|
||||
sg2.append((s, p, o))
|
||||
|
||||
return sg2
|
||||
|
||||
class GraphRag:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -94,7 +232,7 @@ class GraphRag:
|
|||
|
||||
self.label_cache = {}
|
||||
|
||||
self.lang = PromptClient(
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
|
|
@ -104,144 +242,23 @@ class GraphRag:
|
|||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = self.embeddings.request(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
def get_entities(self, query):
|
||||
|
||||
vectors = self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get entities...", flush=True)
|
||||
|
||||
entities = self.ge_client.request(
|
||||
vectors, self.entity_limit
|
||||
)
|
||||
|
||||
entities = [
|
||||
e.value
|
||||
for e in entities
|
||||
]
|
||||
|
||||
if self.verbose:
|
||||
print("Entities:", flush=True)
|
||||
for ent in entities:
|
||||
print(" ", ent, flush=True)
|
||||
|
||||
return entities
|
||||
|
||||
def maybe_label(self, e):
|
||||
|
||||
if e in self.label_cache:
|
||||
return self.label_cache[e]
|
||||
|
||||
res = self.triples_client.request(
|
||||
e, LABEL, None, limit=1
|
||||
)
|
||||
|
||||
if len(res) == 0:
|
||||
self.label_cache[e] = e
|
||||
return e
|
||||
|
||||
self.label_cache[e] = res[0].o.value
|
||||
return self.label_cache[e]
|
||||
|
||||
def get_subgraph(self, query):
|
||||
|
||||
entities = self.get_entities(query)
|
||||
|
||||
subgraph = set()
|
||||
|
||||
if self.verbose:
|
||||
print("Get subgraph...", flush=True)
|
||||
|
||||
for e in entities:
|
||||
|
||||
res = self.triples_client.request(
|
||||
e, None, None,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.triples_client.request(
|
||||
None, e, None,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.triples_client.request(
|
||||
None, None, e,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
subgraph = list(subgraph)
|
||||
|
||||
subgraph = subgraph[0:self.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
print("Subgraph:", flush=True)
|
||||
for edge in subgraph:
|
||||
print(" ", str(edge), flush=True)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return subgraph
|
||||
|
||||
def get_labelgraph(self, query):
|
||||
|
||||
subgraph = self.get_subgraph(query)
|
||||
|
||||
sg2 = []
|
||||
|
||||
for edge in subgraph:
|
||||
|
||||
if edge[1] == LABEL:
|
||||
continue
|
||||
|
||||
s = self.maybe_label(edge[0])
|
||||
p = self.maybe_label(edge[1])
|
||||
o = self.maybe_label(edge[2])
|
||||
|
||||
sg2.append((s, p, o))
|
||||
|
||||
return sg2
|
||||
|
||||
def query(self, query):
|
||||
def query(self, query, user="trustgraph", collection="default"):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
kg = self.get_labelgraph(query)
|
||||
q = Query(
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose
|
||||
)
|
||||
|
||||
kg = q.get_labelgraph(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Invoke LLM...", flush=True)
|
||||
print(kg)
|
||||
print(query)
|
||||
|
||||
resp = self.lang.request_kg_prompt(query, kg)
|
||||
resp = self.prompt.request_kg_prompt(query, kg)
|
||||
|
||||
if self.verbose:
|
||||
print("Done", flush=True)
|
||||
|
|
|
|||
|
|
@ -60,7 +60,10 @@ class Processor(ConsumerProducer):
|
|||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = "doc_" + str(dim)
|
||||
collection = (
|
||||
"d_" + v.user + "_" + v.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
search_result = self.client.query_points(
|
||||
collection_name=collection,
|
||||
|
|
|
|||
|
|
@ -66,7 +66,10 @@ class Processor(ConsumerProducer):
|
|||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = "triples_" + str(dim)
|
||||
collection = (
|
||||
"t_" + v.user + "_" + v.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
search_result = self.client.query_points(
|
||||
collection_name=collection,
|
||||
|
|
|
|||
|
|
@ -38,7 +38,8 @@ class Processor(ConsumerProducer):
|
|||
}
|
||||
)
|
||||
|
||||
self.tg = TrustGraph([graph_host])
|
||||
self.graph_host = [graph_host]
|
||||
self.table = None
|
||||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
|
|
@ -52,6 +53,15 @@ class Processor(ConsumerProducer):
|
|||
|
||||
v = msg.value()
|
||||
|
||||
table = (v.user, v.collection)
|
||||
|
||||
if table != self.table:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
keyspace=v.user, table=v.collection,
|
||||
)
|
||||
self.table = table
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
|
|
|
|||
|
|
@ -108,7 +108,9 @@ class Processor(ConsumerProducer):
|
|||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
response = self.rag.query(v.query)
|
||||
response = self.rag.query(
|
||||
query=v.query, user=v.user, collection=v.collection
|
||||
)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = GraphRagResponse(response = response, error=None)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ class Processor(Consumer):
|
|||
)
|
||||
|
||||
self.last_collection = None
|
||||
self.last_dim = None
|
||||
|
||||
self.client = QdrantClient(url=store_uri)
|
||||
|
||||
|
|
@ -52,9 +51,12 @@ class Processor(Consumer):
|
|||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = "doc_" + str(dim)
|
||||
collection = (
|
||||
"d_" + v.metadata.user + "_" + v.metadata.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
if dim != self.last_dim:
|
||||
if collection != self.last_collection:
|
||||
|
||||
if not self.client.collection_exists(collection):
|
||||
|
||||
|
|
@ -70,7 +72,6 @@ class Processor(Consumer):
|
|||
raise e
|
||||
|
||||
self.last_collection = collection
|
||||
self.last_dim = dim
|
||||
|
||||
self.client.upsert(
|
||||
collection_name=collection,
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ class Processor(Consumer):
|
|||
)
|
||||
|
||||
self.last_collection = None
|
||||
self.last_dim = None
|
||||
|
||||
self.client = QdrantClient(url=store_uri)
|
||||
|
||||
|
|
@ -50,9 +49,12 @@ class Processor(Consumer):
|
|||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = "triples_" + str(dim)
|
||||
collection = (
|
||||
"t_" + v.metadata.user + "_" + v.metadata.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
if dim != self.last_dim:
|
||||
if collection != self.last_collection:
|
||||
|
||||
if not self.client.collection_exists(collection):
|
||||
|
||||
|
|
@ -68,7 +70,6 @@ class Processor(Consumer):
|
|||
raise e
|
||||
|
||||
self.last_collection = collection
|
||||
self.last_dim = dim
|
||||
|
||||
self.client.upsert(
|
||||
collection_name=collection,
|
||||
|
|
|
|||
|
|
@ -38,12 +38,31 @@ class Processor(Consumer):
|
|||
}
|
||||
)
|
||||
|
||||
self.tg = TrustGraph([graph_host])
|
||||
self.graph_host = [graph_host]
|
||||
self.table = None
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
table = (v.metadata.user, v.metadata.collection)
|
||||
|
||||
if self.table is None or self.table != table:
|
||||
|
||||
self.tg = None
|
||||
|
||||
try:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
keyspace=v.metadata.user, table=v.metadata.collection,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Exception", e, flush=True)
|
||||
time.sleep(1)
|
||||
raise e
|
||||
|
||||
self.table = table
|
||||
|
||||
self.tg.insert(
|
||||
v.s.value,
|
||||
v.p.value,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue