mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-09 06:45:13 +02:00
Feature / collections (#96)
* Update schema defs for source -> metadata * Migrate to use metadata part of schema, also add metadata to triples & vecs * Add user/collection metadata to query * Use user/collection in RAG * Write and query working on triples
This commit is contained in:
parent
709221fa10
commit
b0f4c58200
31 changed files with 459 additions and 251 deletions
|
|
@ -38,8 +38,12 @@ class GraphEmbeddingsClient(BaseClient):
|
|||
output_schema=GraphEmbeddingsResponse,
|
||||
)
|
||||
|
||||
def request(self, vectors, limit=10, timeout=300):
|
||||
def request(
|
||||
self, vectors, user="trustgraph", collection="default",
|
||||
limit=10, timeout=300
|
||||
):
|
||||
return self.call(
|
||||
user=user, collection=collection,
|
||||
vectors=vectors, limit=limit, timeout=timeout
|
||||
).entities
|
||||
|
||||
|
|
|
|||
|
|
@ -38,9 +38,12 @@ class GraphRagClient(BaseClient):
|
|||
output_schema=GraphRagResponse,
|
||||
)
|
||||
|
||||
def request(self, query, timeout=500):
|
||||
def request(
|
||||
self, query, user="trustgraph", collection="default",
|
||||
timeout=500
|
||||
):
|
||||
|
||||
return self.call(
|
||||
query=query, timeout=timeout
|
||||
user=user, collection=collection, query=query, timeout=timeout
|
||||
).response
|
||||
|
||||
|
|
|
|||
|
|
@ -48,11 +48,18 @@ class TriplesQueryClient(BaseClient):
|
|||
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
def request(self, s, p, o, limit=10, timeout=60):
|
||||
def request(
|
||||
self,
|
||||
s, p, o,
|
||||
user="trustgraph", collection="default",
|
||||
limit=10, timeout=60,
|
||||
):
|
||||
return self.call(
|
||||
s=self.create_value(s),
|
||||
p=self.create_value(p),
|
||||
o=self.create_value(o),
|
||||
user=user,
|
||||
collection=collection,
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
).triples
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from . object import *
|
|||
from . topic import *
|
||||
from . graph import *
|
||||
from . retrieval import *
|
||||
from . metadata import *
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,17 +2,13 @@
|
|||
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
|
||||
from . topic import topic
|
||||
from . types import Error
|
||||
|
||||
class Source(Record):
|
||||
source = String()
|
||||
id = String()
|
||||
title = String()
|
||||
from . metadata import Metadata
|
||||
|
||||
############################################################################
|
||||
|
||||
# PDF docs etc.
|
||||
class Document(Record):
|
||||
source = Source()
|
||||
metadata = Metadata()
|
||||
data = Bytes()
|
||||
|
||||
document_ingest_queue = topic('document-load')
|
||||
|
|
@ -22,7 +18,7 @@ document_ingest_queue = topic('document-load')
|
|||
# Text documents / text from PDF
|
||||
|
||||
class TextDocument(Record):
|
||||
source = Source()
|
||||
metadata = Metadata()
|
||||
text = Bytes()
|
||||
|
||||
text_ingest_queue = topic('text-document-load')
|
||||
|
|
@ -32,7 +28,7 @@ text_ingest_queue = topic('text-document-load')
|
|||
# Chunks of text
|
||||
|
||||
class Chunk(Record):
|
||||
source = Source()
|
||||
metadata = Metadata()
|
||||
chunk = Bytes()
|
||||
|
||||
chunk_ingest_queue = topic('chunk-load')
|
||||
|
|
@ -42,7 +38,7 @@ chunk_ingest_queue = topic('chunk-load')
|
|||
# Chunk embeddings are an embeddings associated with a text chunk
|
||||
|
||||
class ChunkEmbeddings(Record):
|
||||
source = Source()
|
||||
metadata = Metadata()
|
||||
vectors = Array(Array(Double()))
|
||||
chunk = Bytes()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,16 @@
|
|||
|
||||
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
|
||||
|
||||
from . documents import Source
|
||||
from . types import Error, Value
|
||||
from . topic import topic
|
||||
from . metadata import Metadata
|
||||
|
||||
############################################################################
|
||||
|
||||
# Graph embeddings are embeddings associated with a graph entity
|
||||
|
||||
class GraphEmbeddings(Record):
|
||||
source = Source()
|
||||
metadata = Metadata()
|
||||
vectors = Array(Array(Double()))
|
||||
entity = Value()
|
||||
|
||||
|
|
@ -23,6 +23,8 @@ graph_embeddings_store_queue = topic('graph-embeddings-store')
|
|||
class GraphEmbeddingsRequest(Record):
|
||||
vectors = Array(Array(Double()))
|
||||
limit = Integer()
|
||||
user = String()
|
||||
collection = String()
|
||||
|
||||
class GraphEmbeddingsResponse(Record):
|
||||
error = Error()
|
||||
|
|
@ -40,7 +42,7 @@ graph_embeddings_response_queue = topic(
|
|||
# Graph triples
|
||||
|
||||
class Triple(Record):
|
||||
source = Source()
|
||||
metadata = Metadata()
|
||||
s = Value()
|
||||
p = Value()
|
||||
o = Value()
|
||||
|
|
@ -56,6 +58,8 @@ class TriplesQueryRequest(Record):
|
|||
p = Value()
|
||||
o = Value()
|
||||
limit = Integer()
|
||||
user = String()
|
||||
collection = String()
|
||||
|
||||
class TriplesQueryResponse(Record):
|
||||
error = Error()
|
||||
|
|
|
|||
10
trustgraph-base/trustgraph/schema/metadata.py
Normal file
10
trustgraph-base/trustgraph/schema/metadata.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
|
||||
from pulsar.schema import Record, String
|
||||
|
||||
class Metadata(Record):
|
||||
source = String()
|
||||
id = String()
|
||||
title = String()
|
||||
user = String()
|
||||
collection = String()
|
||||
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array
|
||||
from pulsar.schema import Double, Map
|
||||
|
||||
from . documents import Source
|
||||
from . metadata import Metadata
|
||||
from . types import Value, RowSchema
|
||||
from . topic import topic
|
||||
|
||||
|
|
@ -12,7 +12,7 @@ from . topic import topic
|
|||
# object
|
||||
|
||||
class ObjectEmbeddings(Record):
|
||||
source = Source()
|
||||
metadata = Metadata()
|
||||
vectors = Array(Array(Double()))
|
||||
name = String()
|
||||
key_name = String()
|
||||
|
|
@ -25,7 +25,7 @@ object_embeddings_store_queue = topic('object-embeddings-store')
|
|||
# Stores rows of information
|
||||
|
||||
class Rows(Record):
|
||||
source = Source()
|
||||
metadata = Metadata()
|
||||
row_schema = RowSchema()
|
||||
rows = Array(Map(String()))
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from . types import Error, Value
|
|||
|
||||
class GraphRagQuery(Record):
|
||||
query = String()
|
||||
user = String()
|
||||
collection = String()
|
||||
|
||||
class GraphRagResponse(Record):
|
||||
error = Error()
|
||||
|
|
@ -27,6 +29,8 @@ graph_rag_response_queue = topic(
|
|||
|
||||
class DocumentRagQuery(Record):
|
||||
query = String()
|
||||
user = String()
|
||||
collection = String()
|
||||
|
||||
class DocumentRagResponse(Record):
|
||||
error = Error()
|
||||
|
|
|
|||
|
|
@ -9,12 +9,17 @@ import os
|
|||
from trustgraph.clients.triples_query_client import TriplesQueryClient
|
||||
|
||||
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
def show_graph(pulsar):
|
||||
def show_graph(pulsar, user, collection):
|
||||
|
||||
tq = TriplesQueryClient(pulsar_host=pulsar)
|
||||
|
||||
rows = tq.request(None, None, None, limit=10_000_000)
|
||||
rows = tq.request(
|
||||
user=user, collection=collection,
|
||||
s=None, p=None, o=None, limit=10_000_000
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
print(row.s.value, row.p.value, row.o.value)
|
||||
|
|
@ -32,11 +37,26 @@ def main():
|
|||
help=f'Pulsar host (default: {default_pulsar_host})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-u', '--user',
|
||||
default=default_user,
|
||||
help=f'User ID (default: {default_user})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--collection',
|
||||
default=default_collection,
|
||||
help=f'Collection ID (default: {default_collection})'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
show_graph(args.pulsar_host)
|
||||
show_graph(
|
||||
pulsar=args.pulsar_host, user=args.user,
|
||||
collection=args.collection,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ Loads a PDF document into TrustGraph processing.
|
|||
|
||||
import pulsar
|
||||
from pulsar.schema import JsonSchema
|
||||
from trustgraph.schema import Document, Source, document_ingest_queue
|
||||
from trustgraph.schema import Document, document_ingest_queue, Metadata
|
||||
import base64
|
||||
import hashlib
|
||||
import argparse
|
||||
|
|
@ -15,12 +15,17 @@ import time
|
|||
|
||||
from trustgraph.log_level import LogLevel
|
||||
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
class Loader:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host,
|
||||
output_queue,
|
||||
user,
|
||||
collection,
|
||||
log_level,
|
||||
):
|
||||
|
||||
|
|
@ -35,6 +40,9 @@ class Loader:
|
|||
chunking_enabled=True,
|
||||
)
|
||||
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
|
||||
def load(self, files):
|
||||
|
||||
for file in files:
|
||||
|
|
@ -50,10 +58,12 @@ class Loader:
|
|||
id = hashlib.sha256(path.encode("utf-8")).hexdigest()[0:8]
|
||||
|
||||
r = Document(
|
||||
source=Source(
|
||||
metadata=Metadata(
|
||||
source=path,
|
||||
title=path,
|
||||
id=id,
|
||||
user=self.user,
|
||||
collection=self.collection,
|
||||
),
|
||||
data=base64.b64encode(data),
|
||||
)
|
||||
|
|
@ -90,6 +100,18 @@ def main():
|
|||
help=f'Output queue (default: {default_output_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-u', '--user',
|
||||
default=default_user,
|
||||
help=f'User ID (default: {default_user})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--collection',
|
||||
default=default_collection,
|
||||
help=f'Collection ID (default: {default_collection})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-l', '--log-level',
|
||||
type=LogLevel,
|
||||
|
|
@ -112,6 +134,8 @@ def main():
|
|||
p = Loader(
|
||||
pulsar_host=args.pulsar_host,
|
||||
output_queue=args.output_queue,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
log_level=args.log_level,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ Loads a text document into TrustGraph processing.
|
|||
|
||||
import pulsar
|
||||
from pulsar.schema import JsonSchema
|
||||
from trustgraph.schema import TextDocument, Source, text_ingest_queue
|
||||
from trustgraph.schema import TextDocument, text_ingest_queue, Metadata
|
||||
import base64
|
||||
import hashlib
|
||||
import argparse
|
||||
|
|
@ -15,12 +15,17 @@ import time
|
|||
|
||||
from trustgraph.log_level import LogLevel
|
||||
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
class Loader:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host,
|
||||
output_queue,
|
||||
user,
|
||||
collection,
|
||||
log_level,
|
||||
):
|
||||
|
||||
|
|
@ -35,6 +40,9 @@ class Loader:
|
|||
chunking_enabled=True,
|
||||
)
|
||||
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
|
||||
def load(self, files):
|
||||
|
||||
for file in files:
|
||||
|
|
@ -50,10 +58,12 @@ class Loader:
|
|||
id = hashlib.sha256(path.encode("utf-8")).hexdigest()[0:8]
|
||||
|
||||
r = TextDocument(
|
||||
source=Source(
|
||||
metadata=Metadata(
|
||||
source=path,
|
||||
title=path,
|
||||
id=id,
|
||||
user=self.user,
|
||||
collection=self.collection,
|
||||
),
|
||||
text=data,
|
||||
)
|
||||
|
|
@ -90,6 +100,18 @@ def main():
|
|||
help=f'Output queue (default: {default_output_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-u', '--user',
|
||||
default=default_user,
|
||||
help=f'User ID (default: {default_user})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--collection',
|
||||
default=default_collection,
|
||||
help=f'Collection ID (default: {default_collection})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-l', '--log-level',
|
||||
type=LogLevel,
|
||||
|
|
@ -112,6 +134,8 @@ def main():
|
|||
p = Loader(
|
||||
pulsar_host=args.pulsar_host,
|
||||
output_queue=args.output_queue,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
log_level=args.log_level,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,17 +9,19 @@ import os
|
|||
from trustgraph.clients.document_rag_client import DocumentRagClient
|
||||
|
||||
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
def query(pulsar, query):
|
||||
def query(pulsar_host, query, user, collection):
|
||||
|
||||
rag = DocumentRagClient(pulsar_host=pulsar)
|
||||
resp = rag.request(query)
|
||||
resp = rag.request(user=user, collection=collection, query=query)
|
||||
print(resp)
|
||||
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='graph-show',
|
||||
prog='tg-query-document-rag',
|
||||
description=__doc__,
|
||||
)
|
||||
|
||||
|
|
@ -35,11 +37,28 @@ def main():
|
|||
help=f'Query to execute',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-u', '--user',
|
||||
default=default_user,
|
||||
help=f'User ID (default: {default_user})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--collection',
|
||||
default=default_collection,
|
||||
help=f'Collection ID (default: {default_collection})'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
query(args.pulsar_host, args.query)
|
||||
query(
|
||||
pulsar_host=args.pulsar_host,
|
||||
query=args.query,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
|
|
@ -9,17 +9,19 @@ import os
|
|||
from trustgraph.clients.graph_rag_client import GraphRagClient
|
||||
|
||||
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://localhost:6650')
|
||||
default_user = 'trustgraph'
|
||||
default_collection = 'default'
|
||||
|
||||
def query(pulsar, query):
|
||||
def query(pulsar_host, query, user, collection):
|
||||
|
||||
rag = GraphRagClient(pulsar_host=pulsar)
|
||||
resp = rag.request(query)
|
||||
rag = GraphRagClient(pulsar_host=pulsar_host)
|
||||
resp = rag.request(user=user, collection=collection, query=query)
|
||||
print(resp)
|
||||
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='graph-show',
|
||||
prog='tg-graph-query-rag',
|
||||
description=__doc__,
|
||||
)
|
||||
|
||||
|
|
@ -35,11 +37,28 @@ def main():
|
|||
help=f'Query to execute',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-u', '--user',
|
||||
default=default_user,
|
||||
help=f'User ID (default: {default_user})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--collection',
|
||||
default=default_collection,
|
||||
help=f'Collection ID (default: {default_collection})'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
|
||||
query(args.pulsar_host, args.query)
|
||||
query(
|
||||
pulsar_host=args.pulsar_host,
|
||||
query=args.query,
|
||||
user=args.user,
|
||||
collection=args.collection,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ as text as separate output objects.
|
|||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk, Source
|
||||
from ... schema import TextDocument, Chunk, Metadata
|
||||
from ... schema import text_ingest_queue, chunk_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
|
@ -55,7 +55,7 @@ class Processor(ConsumerProducer):
|
|||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Chunking {v.source.id}...", flush=True)
|
||||
print(f"Chunking {v.metadata.id}...", flush=True)
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
|
|
@ -63,13 +63,15 @@ class Processor(ConsumerProducer):
|
|||
|
||||
for ix, chunk in enumerate(texts):
|
||||
|
||||
id = v.source.id + "-c" + str(ix)
|
||||
id = v.metadata.id + "-c" + str(ix)
|
||||
|
||||
r = Chunk(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
metadata=Metadata(
|
||||
source=v.metadata.source,
|
||||
id=id,
|
||||
title=v.source.title
|
||||
title=v.metadata.title,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ as text as separate output objects.
|
|||
from langchain_text_splitters import TokenTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk, Source
|
||||
from ... schema import TextDocument, Chunk, Metadata
|
||||
from ... schema import text_ingest_queue, chunk_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
|
@ -54,7 +54,7 @@ class Processor(ConsumerProducer):
|
|||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Chunking {v.source.id}...", flush=True)
|
||||
print(f"Chunking {v.metadata.id}...", flush=True)
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
|
|
@ -62,13 +62,15 @@ class Processor(ConsumerProducer):
|
|||
|
||||
for ix, chunk in enumerate(texts):
|
||||
|
||||
id = v.source.id + "-c" + str(ix)
|
||||
id = v.metadata.id + "-c" + str(ix)
|
||||
|
||||
r = Chunk(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
metadata=Metadata(
|
||||
source=v.metadata.source,
|
||||
id=id,
|
||||
title=v.source.title
|
||||
title=v.metadata.title,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import tempfile
|
|||
import base64
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
|
||||
from ... schema import Document, TextDocument, Source
|
||||
from ... schema import Document, TextDocument, Metadata
|
||||
from ... schema import document_ingest_queue, text_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
|
@ -45,7 +45,7 @@ class Processor(ConsumerProducer):
|
|||
|
||||
v = msg.value()
|
||||
|
||||
print(f"Decoding {v.source.id}...", flush=True)
|
||||
print(f"Decoding {v.metadata.id}...", flush=True)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
|
||||
|
||||
|
|
@ -59,12 +59,14 @@ class Processor(ConsumerProducer):
|
|||
|
||||
for ix, page in enumerate(pages):
|
||||
|
||||
id = v.source.id + "-p" + str(ix)
|
||||
id = v.metadata.id + "-p" + str(ix)
|
||||
r = TextDocument(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
title=v.source.title,
|
||||
metadata=Metadata(
|
||||
source=v.metadata.source,
|
||||
title=v.metadata.title,
|
||||
id=id,
|
||||
user=v.metadata.user,
|
||||
collection=v.metadata.collection,
|
||||
),
|
||||
text=page.page_content.encode("utf-8"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,10 +4,16 @@ from cassandra.auth import PlainTextAuthProvider
|
|||
|
||||
class TrustGraph:
|
||||
|
||||
def __init__(self, hosts=None):
|
||||
def __init__(
|
||||
self, hosts=None,
|
||||
keyspace="trustgraph", table="default",
|
||||
):
|
||||
|
||||
if hosts is None:
|
||||
hosts = ["localhost"]
|
||||
|
||||
self.keyspace = keyspace
|
||||
self.table = table
|
||||
|
||||
self.cluster = Cluster(hosts)
|
||||
self.session = self.cluster.connect()
|
||||
|
|
@ -16,26 +22,26 @@ class TrustGraph:
|
|||
|
||||
def clear(self):
|
||||
|
||||
self.session.execute("""
|
||||
drop keyspace if exists trustgraph;
|
||||
self.session.execute(f"""
|
||||
drop keyspace if exists {self.keyspace};
|
||||
""");
|
||||
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
|
||||
self.session.execute("""
|
||||
create keyspace if not exists trustgraph
|
||||
with replication = {
|
||||
self.session.execute(f"""
|
||||
create keyspace if not exists {self.keyspace}
|
||||
with replication = {{
|
||||
'class' : 'SimpleStrategy',
|
||||
'replication_factor' : 1
|
||||
};
|
||||
}};
|
||||
""");
|
||||
|
||||
self.session.set_keyspace('trustgraph')
|
||||
self.session.set_keyspace(self.keyspace)
|
||||
|
||||
self.session.execute("""
|
||||
create table if not exists triples (
|
||||
self.session.execute(f"""
|
||||
create table if not exists {self.table} (
|
||||
s text,
|
||||
p text,
|
||||
o text,
|
||||
|
|
@ -43,66 +49,66 @@ class TrustGraph:
|
|||
);
|
||||
""");
|
||||
|
||||
self.session.execute("""
|
||||
create index if not exists triples_p
|
||||
ON triples (p);
|
||||
self.session.execute(f"""
|
||||
create index if not exists {self.table}_p
|
||||
ON {self.table} (p);
|
||||
""");
|
||||
|
||||
self.session.execute("""
|
||||
create index if not exists triples_o
|
||||
ON triples (o);
|
||||
self.session.execute(f"""
|
||||
create index if not exists {self.table}_o
|
||||
ON {self.table} (o);
|
||||
""");
|
||||
|
||||
def insert(self, s, p, o):
|
||||
|
||||
self.session.execute(
|
||||
"insert into triples (s, p, o) values (%s, %s, %s)",
|
||||
f"insert into {self.table} (s, p, o) values (%s, %s, %s)",
|
||||
(s, p, o)
|
||||
)
|
||||
|
||||
def get_all(self, limit=50):
|
||||
return self.session.execute(
|
||||
f"select s, p, o from triples limit {limit}"
|
||||
f"select s, p, o from {self.table} limit {limit}"
|
||||
)
|
||||
|
||||
def get_s(self, s, limit=10):
|
||||
return self.session.execute(
|
||||
f"select p, o from triples where s = %s limit {limit}",
|
||||
f"select p, o from {self.table} where s = %s limit {limit}",
|
||||
(s,)
|
||||
)
|
||||
|
||||
def get_p(self, p, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s, o from triples where p = %s limit {limit}",
|
||||
f"select s, o from {self.table} where p = %s limit {limit}",
|
||||
(p,)
|
||||
)
|
||||
|
||||
def get_o(self, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s, p from triples where o = %s limit {limit}",
|
||||
f"select s, p from {self.table} where o = %s limit {limit}",
|
||||
(o,)
|
||||
)
|
||||
|
||||
def get_sp(self, s, p, limit=10):
|
||||
return self.session.execute(
|
||||
f"select o from triples where s = %s and p = %s limit {limit}",
|
||||
f"select o from {self.table} where s = %s and p = %s limit {limit}",
|
||||
(s, p)
|
||||
)
|
||||
|
||||
def get_po(self, p, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s from triples where p = %s and o = %s allow filtering limit {limit}",
|
||||
f"select s from {self.table} where p = %s and o = %s allow filtering limit {limit}",
|
||||
(p, o)
|
||||
)
|
||||
|
||||
def get_os(self, o, s, limit=10):
|
||||
return self.session.execute(
|
||||
f"select p from triples where o = %s and s = %s limit {limit}",
|
||||
f"select p from {self.table} where o = %s and s = %s limit {limit}",
|
||||
(o, s)
|
||||
)
|
||||
|
||||
def get_spo(self, s, p, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"""select s as x from triples where s = %s and p = %s and o = %s limit {limit}""",
|
||||
f"""select s as x from {self.table} where s = %s and p = %s and o = %s limit {limit}""",
|
||||
(s, p, o)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -50,15 +50,15 @@ class Processor(ConsumerProducer):
|
|||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
def emit(self, source, chunk, vectors):
|
||||
def emit(self, metadata, chunk, vectors):
|
||||
|
||||
r = ChunkEmbeddings(source=source, chunk=chunk, vectors=vectors)
|
||||
r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
|
||||
self.producer.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ class Processor(ConsumerProducer):
|
|||
vectors = self.embeddings.request(chunk)
|
||||
|
||||
self.emit(
|
||||
source=v.source,
|
||||
metadata=v.metadata,
|
||||
chunk=chunk.encode("utf-8"),
|
||||
vectors=vectors
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ get entity definitions which are output as graph edges.
|
|||
import urllib.parse
|
||||
import json
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Source, Value
|
||||
from .... schema import ChunkEmbeddings, Triple, Metadata, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
|
|
@ -69,15 +69,15 @@ class Processor(ConsumerProducer):
|
|||
|
||||
return self.prompt.request_definitions(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
def emit_edge(self, metadata, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
t = Triple(metadata=metadata, s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ class Processor(ConsumerProducer):
|
|||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
|
||||
self.emit_edge(v.metadata, s_value, DEFINITION_VALUE, o_value)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ import urllib.parse
|
|||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
|
||||
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings
|
||||
from .... schema import Metadata, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import graph_embeddings_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
|
|
@ -91,20 +92,20 @@ class Processor(ConsumerProducer):
|
|||
|
||||
return self.prompt.request_relationships(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
def emit_edge(self, metadata, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
t = Triple(metadata=metadata, s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_vec(self, ent, vec):
|
||||
def emit_vec(self, metadata, ent, vec):
|
||||
|
||||
r = GraphEmbeddings(entity=ent, vectors=vec)
|
||||
r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec)
|
||||
self.vec_prod.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
|
|
@ -139,6 +140,7 @@ class Processor(ConsumerProducer):
|
|||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(
|
||||
v.metadata,
|
||||
s_value,
|
||||
p_value,
|
||||
o_value
|
||||
|
|
@ -146,6 +148,7 @@ class Processor(ConsumerProducer):
|
|||
|
||||
# Label for s
|
||||
self.emit_edge(
|
||||
v.metadata,
|
||||
s_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(s), is_uri=False)
|
||||
|
|
@ -153,6 +156,7 @@ class Processor(ConsumerProducer):
|
|||
|
||||
# Label for p
|
||||
self.emit_edge(
|
||||
v.metadata,
|
||||
p_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(p), is_uri=False)
|
||||
|
|
@ -161,15 +165,16 @@ class Processor(ConsumerProducer):
|
|||
if rel.o_entity:
|
||||
# Label for o
|
||||
self.emit_edge(
|
||||
v.metadata,
|
||||
o_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(o), is_uri=False)
|
||||
)
|
||||
|
||||
self.emit_vec(s_value, v.vectors)
|
||||
self.emit_vec(p_value, v.vectors)
|
||||
self.emit_vec(v.metadata, s_value, v.vectors)
|
||||
self.emit_vec(v.metadata, p_value, v.vectors)
|
||||
if rel.o_entity:
|
||||
self.emit_vec(o_value, v.vectors)
|
||||
self.emit_vec(v.metadata, o_value, v.vectors)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ get entity definitions which are output as graph edges.
|
|||
import urllib.parse
|
||||
import json
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Source, Value
|
||||
from .... schema import ChunkEmbeddings, Triple, Metadata, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
|
|
@ -69,15 +69,15 @@ class Processor(ConsumerProducer):
|
|||
|
||||
return self.prompt.request_topics(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
def emit_edge(self, metadata, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
t = Triple(metadata=metadata, s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
|
|
@ -101,7 +101,7 @@ class Processor(ConsumerProducer):
|
|||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
|
||||
self.emit_edge(v. metadata, s_value, DEFINITION_VALUE, o_value)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import urllib.parse
|
|||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source
|
||||
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Metadata
|
||||
from .... schema import RowSchema, Field
|
||||
from .... schema import chunk_embeddings_ingest_queue, rows_store_queue
|
||||
from .... schema import object_embeddings_store_queue
|
||||
|
|
@ -124,24 +124,24 @@ class Processor(ConsumerProducer):
|
|||
def get_rows(self, chunk):
|
||||
return self.prompt.request_rows(self.schema, chunk)
|
||||
|
||||
def emit_rows(self, source, rows):
|
||||
def emit_rows(self, metadata, rows):
|
||||
|
||||
t = Rows(
|
||||
source=source, row_schema=self.row_schema, rows=rows
|
||||
metadata=metadata, row_schema=self.row_schema, rows=rows
|
||||
)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_vec(self, source, name, vec, key_name, key):
|
||||
def emit_vec(self, metadata, name, vec, key_name, key):
|
||||
|
||||
r = ObjectEmbeddings(
|
||||
source=source, vectors=vec, name=name, key_name=key_name, id=key
|
||||
metadata=metadata, vectors=vec, name=name, key_name=key_name, id=key
|
||||
)
|
||||
self.vec_prod.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
|
|
@ -150,13 +150,13 @@ class Processor(ConsumerProducer):
|
|||
rows = self.get_rows(chunk)
|
||||
|
||||
self.emit_rows(
|
||||
source=v.source,
|
||||
metadata=v.metadata,
|
||||
rows=rows
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
self.emit_vec(
|
||||
source=v.source, vec=v.vectors,
|
||||
metadata=v.metadata, vec=v.vectors,
|
||||
name=self.schema.name, key_name=self.primary.name,
|
||||
key=row[self.primary.name]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,144 @@ from . schema import triples_response_queue
|
|||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
||||
|
||||
class Query:
|
||||
|
||||
def __init__(self, rag, user, collection, verbose):
|
||||
self.rag = rag
|
||||
self.user = user
|
||||
self.collection = collection
|
||||
self.verbose = verbose
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = self.rag.embeddings.request(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
def get_entities(self, query):
|
||||
|
||||
vectors = self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get entities...", flush=True)
|
||||
|
||||
entities = self.rag.ge_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
vectors=vectors, limit=self.rag.entity_limit,
|
||||
)
|
||||
|
||||
entities = [
|
||||
e.value
|
||||
for e in entities
|
||||
]
|
||||
|
||||
if self.verbose:
|
||||
print("Entities:", flush=True)
|
||||
for ent in entities:
|
||||
print(" ", ent, flush=True)
|
||||
|
||||
return entities
|
||||
|
||||
def maybe_label(self, e):
|
||||
|
||||
if e in self.rag.label_cache:
|
||||
return self.rag.label_cache[e]
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=e, p=LABEL, o=None, limit=1,
|
||||
)
|
||||
|
||||
if len(res) == 0:
|
||||
self.rag.label_cache[e] = e
|
||||
return e
|
||||
|
||||
self.rag.label_cache[e] = res[0].o.value
|
||||
return self.rag.label_cache[e]
|
||||
|
||||
def get_subgraph(self, query):
|
||||
|
||||
entities = self.get_entities(query)
|
||||
|
||||
subgraph = set()
|
||||
|
||||
if self.verbose:
|
||||
print("Get subgraph...", flush=True)
|
||||
|
||||
for e in entities:
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=e, p=None, o=None,
|
||||
limit=self.rag.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=None, p=e, o=None,
|
||||
limit=self.rag.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.rag.triples_client.request(
|
||||
user=self.user, collection=self.collection,
|
||||
s=None, p=None, o=e,
|
||||
limit=self.rag.query_limit,
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
subgraph = list(subgraph)
|
||||
|
||||
subgraph = subgraph[0:self.rag.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
print("Subgraph:", flush=True)
|
||||
for edge in subgraph:
|
||||
print(" ", str(edge), flush=True)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return subgraph
|
||||
|
||||
def get_labelgraph(self, query):
|
||||
|
||||
subgraph = self.get_subgraph(query)
|
||||
|
||||
sg2 = []
|
||||
|
||||
for edge in subgraph:
|
||||
|
||||
if edge[1] == LABEL:
|
||||
continue
|
||||
|
||||
s = self.maybe_label(edge[0])
|
||||
p = self.maybe_label(edge[1])
|
||||
o = self.maybe_label(edge[2])
|
||||
|
||||
sg2.append((s, p, o))
|
||||
|
||||
return sg2
|
||||
|
||||
class GraphRag:
|
||||
|
||||
def __init__(
|
||||
|
|
@ -94,7 +232,7 @@ class GraphRag:
|
|||
|
||||
self.label_cache = {}
|
||||
|
||||
self.lang = PromptClient(
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
|
|
@ -104,144 +242,23 @@ class GraphRag:
|
|||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = self.embeddings.request(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
def get_entities(self, query):
|
||||
|
||||
vectors = self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get entities...", flush=True)
|
||||
|
||||
entities = self.ge_client.request(
|
||||
vectors, self.entity_limit
|
||||
)
|
||||
|
||||
entities = [
|
||||
e.value
|
||||
for e in entities
|
||||
]
|
||||
|
||||
if self.verbose:
|
||||
print("Entities:", flush=True)
|
||||
for ent in entities:
|
||||
print(" ", ent, flush=True)
|
||||
|
||||
return entities
|
||||
|
||||
def maybe_label(self, e):
|
||||
|
||||
if e in self.label_cache:
|
||||
return self.label_cache[e]
|
||||
|
||||
res = self.triples_client.request(
|
||||
e, LABEL, None, limit=1
|
||||
)
|
||||
|
||||
if len(res) == 0:
|
||||
self.label_cache[e] = e
|
||||
return e
|
||||
|
||||
self.label_cache[e] = res[0].o.value
|
||||
return self.label_cache[e]
|
||||
|
||||
def get_subgraph(self, query):
|
||||
|
||||
entities = self.get_entities(query)
|
||||
|
||||
subgraph = set()
|
||||
|
||||
if self.verbose:
|
||||
print("Get subgraph...", flush=True)
|
||||
|
||||
for e in entities:
|
||||
|
||||
res = self.triples_client.request(
|
||||
e, None, None,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.triples_client.request(
|
||||
None, e, None,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.triples_client.request(
|
||||
None, None, e,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
subgraph = list(subgraph)
|
||||
|
||||
subgraph = subgraph[0:self.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
print("Subgraph:", flush=True)
|
||||
for edge in subgraph:
|
||||
print(" ", str(edge), flush=True)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return subgraph
|
||||
|
||||
def get_labelgraph(self, query):
|
||||
|
||||
subgraph = self.get_subgraph(query)
|
||||
|
||||
sg2 = []
|
||||
|
||||
for edge in subgraph:
|
||||
|
||||
if edge[1] == LABEL:
|
||||
continue
|
||||
|
||||
s = self.maybe_label(edge[0])
|
||||
p = self.maybe_label(edge[1])
|
||||
o = self.maybe_label(edge[2])
|
||||
|
||||
sg2.append((s, p, o))
|
||||
|
||||
return sg2
|
||||
|
||||
def query(self, query):
|
||||
def query(self, query, user="trustgraph", collection="default"):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
kg = self.get_labelgraph(query)
|
||||
q = Query(
|
||||
rag=self, user=user, collection=collection, verbose=self.verbose
|
||||
)
|
||||
|
||||
kg = q.get_labelgraph(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Invoke LLM...", flush=True)
|
||||
print(kg)
|
||||
print(query)
|
||||
|
||||
resp = self.lang.request_kg_prompt(query, kg)
|
||||
resp = self.prompt.request_kg_prompt(query, kg)
|
||||
|
||||
if self.verbose:
|
||||
print("Done", flush=True)
|
||||
|
|
|
|||
|
|
@ -60,7 +60,10 @@ class Processor(ConsumerProducer):
|
|||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = "doc_" + str(dim)
|
||||
collection = (
|
||||
"d_" + v.user + "_" + v.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
search_result = self.client.query_points(
|
||||
collection_name=collection,
|
||||
|
|
|
|||
|
|
@ -66,7 +66,10 @@ class Processor(ConsumerProducer):
|
|||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = "triples_" + str(dim)
|
||||
collection = (
|
||||
"t_" + v.user + "_" + v.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
search_result = self.client.query_points(
|
||||
collection_name=collection,
|
||||
|
|
|
|||
|
|
@ -38,7 +38,8 @@ class Processor(ConsumerProducer):
|
|||
}
|
||||
)
|
||||
|
||||
self.tg = TrustGraph([graph_host])
|
||||
self.graph_host = [graph_host]
|
||||
self.table = None
|
||||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
|
|
@ -52,6 +53,15 @@ class Processor(ConsumerProducer):
|
|||
|
||||
v = msg.value()
|
||||
|
||||
table = (v.user, v.collection)
|
||||
|
||||
if table != self.table:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
keyspace=v.user, table=v.collection,
|
||||
)
|
||||
self.table = table
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
|
|
|
|||
|
|
@ -108,7 +108,9 @@ class Processor(ConsumerProducer):
|
|||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
response = self.rag.query(v.query)
|
||||
response = self.rag.query(
|
||||
query=v.query, user=v.user, collection=v.collection
|
||||
)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = GraphRagResponse(response = response, error=None)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ class Processor(Consumer):
|
|||
)
|
||||
|
||||
self.last_collection = None
|
||||
self.last_dim = None
|
||||
|
||||
self.client = QdrantClient(url=store_uri)
|
||||
|
||||
|
|
@ -52,9 +51,12 @@ class Processor(Consumer):
|
|||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = "doc_" + str(dim)
|
||||
collection = (
|
||||
"d_" + v.metadata.user + "_" + v.metadata.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
if dim != self.last_dim:
|
||||
if collection != self.last_collection:
|
||||
|
||||
if not self.client.collection_exists(collection):
|
||||
|
||||
|
|
@ -70,7 +72,6 @@ class Processor(Consumer):
|
|||
raise e
|
||||
|
||||
self.last_collection = collection
|
||||
self.last_dim = dim
|
||||
|
||||
self.client.upsert(
|
||||
collection_name=collection,
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ class Processor(Consumer):
|
|||
)
|
||||
|
||||
self.last_collection = None
|
||||
self.last_dim = None
|
||||
|
||||
self.client = QdrantClient(url=store_uri)
|
||||
|
||||
|
|
@ -50,9 +49,12 @@ class Processor(Consumer):
|
|||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = "triples_" + str(dim)
|
||||
collection = (
|
||||
"t_" + v.metadata.user + "_" + v.metadata.collection + "_" +
|
||||
str(dim)
|
||||
)
|
||||
|
||||
if dim != self.last_dim:
|
||||
if collection != self.last_collection:
|
||||
|
||||
if not self.client.collection_exists(collection):
|
||||
|
||||
|
|
@ -68,7 +70,6 @@ class Processor(Consumer):
|
|||
raise e
|
||||
|
||||
self.last_collection = collection
|
||||
self.last_dim = dim
|
||||
|
||||
self.client.upsert(
|
||||
collection_name=collection,
|
||||
|
|
|
|||
|
|
@ -38,12 +38,31 @@ class Processor(Consumer):
|
|||
}
|
||||
)
|
||||
|
||||
self.tg = TrustGraph([graph_host])
|
||||
self.graph_host = [graph_host]
|
||||
self.table = None
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
table = (v.metadata.user, v.metadata.collection)
|
||||
|
||||
if self.table is None or self.table != table:
|
||||
|
||||
self.tg = None
|
||||
|
||||
try:
|
||||
self.tg = TrustGraph(
|
||||
hosts=self.graph_host,
|
||||
keyspace=v.metadata.user, table=v.metadata.collection,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Exception", e, flush=True)
|
||||
time.sleep(1)
|
||||
raise e
|
||||
|
||||
self.table = table
|
||||
|
||||
self.tg.insert(
|
||||
v.s.value,
|
||||
v.p.value,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue