mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 02:23:44 +02:00
Feature/refactor entity embeddings (#235)
* Make schema changes * Core entity context flow in place * extract-def outputs entity contexts * Refactored qdrant write * Refactoring of all vector stores in place
This commit is contained in:
parent
9942f63773
commit
a458d57af2
9 changed files with 230 additions and 169 deletions
|
|
@ -35,17 +35,6 @@ chunk_ingest_queue = topic('chunk-load')
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
|
||||||
# Chunk embeddings are an embeddings associated with a text chunk
|
|
||||||
|
|
||||||
class ChunkEmbeddings(Record):
|
|
||||||
metadata = Metadata()
|
|
||||||
vectors = Array(Array(Double()))
|
|
||||||
chunk = Bytes()
|
|
||||||
|
|
||||||
chunk_embeddings_ingest_queue = topic('chunk-embeddings-load')
|
|
||||||
|
|
||||||
############################################################################
|
|
||||||
|
|
||||||
# Doc embeddings query
|
# Doc embeddings query
|
||||||
|
|
||||||
class DocumentEmbeddingsRequest(Record):
|
class DocumentEmbeddingsRequest(Record):
|
||||||
|
|
@ -62,3 +51,4 @@ document_embeddings_request_queue = topic(
|
||||||
document_embeddings_response_queue = topic(
|
document_embeddings_response_queue = topic(
|
||||||
'doc-embeddings', kind='non-persistent', namespace='response',
|
'doc-embeddings', kind='non-persistent', namespace='response',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,12 +7,31 @@ from . metadata import Metadata
|
||||||
|
|
||||||
############################################################################
|
############################################################################
|
||||||
|
|
||||||
|
# Entity context are an entity associated with textual context
|
||||||
|
|
||||||
|
class EntityContext(Record):
|
||||||
|
entity = Value()
|
||||||
|
context = String()
|
||||||
|
|
||||||
|
# This is a 'batching' mechanism for the above data
|
||||||
|
class EntityContexts(Record):
|
||||||
|
metadata = Metadata()
|
||||||
|
entities = Array(EntityContext())
|
||||||
|
|
||||||
|
entity_contexts_ingest_queue = topic('entity-contexts-load')
|
||||||
|
|
||||||
|
############################################################################
|
||||||
|
|
||||||
# Graph embeddings are embeddings associated with a graph entity
|
# Graph embeddings are embeddings associated with a graph entity
|
||||||
|
|
||||||
|
class EntityEmbeddings(Record):
|
||||||
|
entity = Value()
|
||||||
|
vectors = Array(Array(Double()))
|
||||||
|
|
||||||
|
# This is a 'batching' mechanism for the above data
|
||||||
class GraphEmbeddings(Record):
|
class GraphEmbeddings(Record):
|
||||||
metadata = Metadata()
|
metadata = Metadata()
|
||||||
vectors = Array(Array(Double()))
|
entities = Array(EntityEmbeddings())
|
||||||
entity = Value()
|
|
||||||
|
|
||||||
graph_embeddings_store_queue = topic('graph-embeddings-store')
|
graph_embeddings_store_queue = topic('graph-embeddings-store')
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,9 @@ Vectorizer, calls the embeddings service to get embeddings for a chunk.
|
||||||
Input is text chunk, output is chunk and vectors.
|
Input is text chunk, output is chunk and vectors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from ... schema import Chunk, ChunkEmbeddings
|
from ... schema import EntityContexts, EntityEmbeddings, GraphEmbeddings
|
||||||
from ... schema import chunk_ingest_queue, chunk_embeddings_ingest_queue
|
from ... schema import entity_contexts_ingest_queue
|
||||||
|
from ... schema import graph_embeddings_store_queue
|
||||||
from ... schema import embeddings_request_queue, embeddings_response_queue
|
from ... schema import embeddings_request_queue, embeddings_response_queue
|
||||||
from ... clients.embeddings_client import EmbeddingsClient
|
from ... clients.embeddings_client import EmbeddingsClient
|
||||||
from ... log_level import LogLevel
|
from ... log_level import LogLevel
|
||||||
|
|
@ -13,8 +14,8 @@ from ... base import ConsumerProducer
|
||||||
|
|
||||||
module = ".".join(__name__.split(".")[1:-1])
|
module = ".".join(__name__.split(".")[1:-1])
|
||||||
|
|
||||||
default_input_queue = chunk_ingest_queue
|
default_input_queue = entity_contexts_ingest_queue
|
||||||
default_output_queue = chunk_embeddings_ingest_queue
|
default_output_queue = graph_embeddings_store_queue
|
||||||
default_subscriber = module
|
default_subscriber = module
|
||||||
|
|
||||||
class Processor(ConsumerProducer):
|
class Processor(ConsumerProducer):
|
||||||
|
|
@ -38,8 +39,8 @@ class Processor(ConsumerProducer):
|
||||||
"embeddings_request_queue": emb_request_queue,
|
"embeddings_request_queue": emb_request_queue,
|
||||||
"embeddings_response_queue": emb_response_queue,
|
"embeddings_response_queue": emb_response_queue,
|
||||||
"subscriber": subscriber,
|
"subscriber": subscriber,
|
||||||
"input_schema": Chunk,
|
"input_schema": EntityContexts,
|
||||||
"output_schema": ChunkEmbeddings,
|
"output_schema": GraphEmbeddings,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -50,9 +51,9 @@ class Processor(ConsumerProducer):
|
||||||
subscriber=module + "-emb",
|
subscriber=module + "-emb",
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit(self, metadata, chunk, vectors):
|
def emit(self, rec, vectors):
|
||||||
|
|
||||||
r = ChunkEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
|
r = GraphEmbeddings(metadata=metadata, chunk=chunk, vectors=vectors)
|
||||||
self.producer.send(r)
|
self.producer.send(r)
|
||||||
|
|
||||||
def handle(self, msg):
|
def handle(self, msg):
|
||||||
|
|
@ -60,21 +61,34 @@ class Processor(ConsumerProducer):
|
||||||
v = msg.value()
|
v = msg.value()
|
||||||
print(f"Indexing {v.metadata.id}...", flush=True)
|
print(f"Indexing {v.metadata.id}...", flush=True)
|
||||||
|
|
||||||
chunk = v.chunk.decode("utf-8")
|
entities = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
vectors = self.embeddings.request(chunk)
|
for entity in v.entities:
|
||||||
|
|
||||||
self.emit(
|
vectors = self.embeddings.request(entity.context)
|
||||||
|
|
||||||
|
entities.append(
|
||||||
|
EntityEmbeddings(
|
||||||
|
entity=entity.entity,
|
||||||
|
vectors=vectors
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
r = GraphEmbeddings(
|
||||||
metadata=v.metadata,
|
metadata=v.metadata,
|
||||||
chunk=chunk.encode("utf-8"),
|
entities=entities,
|
||||||
vectors=vectors
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.producer.send(r)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Exception:", e, flush=True)
|
print("Exception:", e, flush=True)
|
||||||
|
|
||||||
|
# Retry
|
||||||
|
raise e
|
||||||
|
|
||||||
print("Done.", flush=True)
|
print("Done.", flush=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,17 @@
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
|
Simple decoder, accepts text chunks input, applies entity analysis to
|
||||||
get entity definitions which are output as graph edges.
|
get entity definitions which are output as graph edges along with
|
||||||
|
entity/context definitions for embedding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import json
|
from pulsar.schema import JsonSchema
|
||||||
|
|
||||||
from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value
|
from .... schema import Chunk, Triple, Triples, Metadata, Value
|
||||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
from .... schema import EntityContext, EntityContexts
|
||||||
|
from .... schema import chunk_ingest_queue, triples_store_queue
|
||||||
|
from .... schema import entity_contexts_ingest_queue
|
||||||
from .... schema import prompt_request_queue
|
from .... schema import prompt_request_queue
|
||||||
from .... schema import prompt_response_queue
|
from .... schema import prompt_response_queue
|
||||||
from .... log_level import LogLevel
|
from .... log_level import LogLevel
|
||||||
|
|
@ -22,8 +25,9 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
||||||
|
|
||||||
module = ".".join(__name__.split(".")[1:-1])
|
module = ".".join(__name__.split(".")[1:-1])
|
||||||
|
|
||||||
default_input_queue = chunk_embeddings_ingest_queue
|
default_input_queue = chunk_ingest_queue
|
||||||
default_output_queue = triples_store_queue
|
default_output_queue = triples_store_queue
|
||||||
|
default_entity_context_queue = entity_contexts_ingest_queue
|
||||||
default_subscriber = module
|
default_subscriber = module
|
||||||
|
|
||||||
class Processor(ConsumerProducer):
|
class Processor(ConsumerProducer):
|
||||||
|
|
@ -32,6 +36,10 @@ class Processor(ConsumerProducer):
|
||||||
|
|
||||||
input_queue = params.get("input_queue", default_input_queue)
|
input_queue = params.get("input_queue", default_input_queue)
|
||||||
output_queue = params.get("output_queue", default_output_queue)
|
output_queue = params.get("output_queue", default_output_queue)
|
||||||
|
ec_queue = params.get(
|
||||||
|
"entity_context_queue",
|
||||||
|
default_entity_context_queue
|
||||||
|
)
|
||||||
subscriber = params.get("subscriber", default_subscriber)
|
subscriber = params.get("subscriber", default_subscriber)
|
||||||
pr_request_queue = params.get(
|
pr_request_queue = params.get(
|
||||||
"prompt_request_queue", prompt_request_queue
|
"prompt_request_queue", prompt_request_queue
|
||||||
|
|
@ -45,13 +53,30 @@ class Processor(ConsumerProducer):
|
||||||
"input_queue": input_queue,
|
"input_queue": input_queue,
|
||||||
"output_queue": output_queue,
|
"output_queue": output_queue,
|
||||||
"subscriber": subscriber,
|
"subscriber": subscriber,
|
||||||
"input_schema": ChunkEmbeddings,
|
"input_schema": Chunk,
|
||||||
"output_schema": Triples,
|
"output_schema": Triples,
|
||||||
"prompt_request_queue": pr_request_queue,
|
"prompt_request_queue": pr_request_queue,
|
||||||
"prompt_response_queue": pr_response_queue,
|
"prompt_response_queue": pr_response_queue,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.ec_prod = self.client.create_producer(
|
||||||
|
topic=ec_queue,
|
||||||
|
schema=JsonSchema(EntityContexts),
|
||||||
|
)
|
||||||
|
|
||||||
|
__class__.pubsub_metric.info({
|
||||||
|
"input_queue": input_queue,
|
||||||
|
"output_queue": output_queue,
|
||||||
|
"entity_context_queue": ec_queue,
|
||||||
|
"prompt_request_queue": pr_request_queue,
|
||||||
|
"prompt_response_queue": pr_response_queue,
|
||||||
|
"subscriber": subscriber,
|
||||||
|
"input_schema": Chunk.__name__,
|
||||||
|
"output_schema": Triples.__name__,
|
||||||
|
"vector_schema": EntityContexts.__name__,
|
||||||
|
})
|
||||||
|
|
||||||
self.prompt = PromptClient(
|
self.prompt = PromptClient(
|
||||||
pulsar_host=self.pulsar_host,
|
pulsar_host=self.pulsar_host,
|
||||||
input_queue=pr_request_queue,
|
input_queue=pr_request_queue,
|
||||||
|
|
@ -79,6 +104,14 @@ class Processor(ConsumerProducer):
|
||||||
)
|
)
|
||||||
self.producer.send(t)
|
self.producer.send(t)
|
||||||
|
|
||||||
|
def emit_ecs(self, metadata, entities):
|
||||||
|
|
||||||
|
t = EntityContexts(
|
||||||
|
metadata=metadata,
|
||||||
|
entities=entities,
|
||||||
|
)
|
||||||
|
self.ec_prod.send(t)
|
||||||
|
|
||||||
def handle(self, msg):
|
def handle(self, msg):
|
||||||
|
|
||||||
v = msg.value()
|
v = msg.value()
|
||||||
|
|
@ -91,6 +124,7 @@ class Processor(ConsumerProducer):
|
||||||
defs = self.get_definitions(chunk)
|
defs = self.get_definitions(chunk)
|
||||||
|
|
||||||
triples = []
|
triples = []
|
||||||
|
entities = []
|
||||||
|
|
||||||
# FIXME: Putting metadata into triples store is duplicated in
|
# FIXME: Putting metadata into triples store is duplicated in
|
||||||
# relationships extractor too
|
# relationships extractor too
|
||||||
|
|
@ -129,6 +163,14 @@ class Processor(ConsumerProducer):
|
||||||
o=Value(value=v.metadata.id, is_uri=True)
|
o=Value(value=v.metadata.id, is_uri=True)
|
||||||
))
|
))
|
||||||
|
|
||||||
|
ec = EntityContext(
|
||||||
|
entity=s_value,
|
||||||
|
context=defn.definition,
|
||||||
|
)
|
||||||
|
|
||||||
|
entities.append(ec)
|
||||||
|
|
||||||
|
|
||||||
self.emit_edges(
|
self.emit_edges(
|
||||||
Metadata(
|
Metadata(
|
||||||
id=v.metadata.id,
|
id=v.metadata.id,
|
||||||
|
|
@ -139,6 +181,16 @@ class Processor(ConsumerProducer):
|
||||||
triples
|
triples
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.emit_ecs(
|
||||||
|
Metadata(
|
||||||
|
id=v.metadata.id,
|
||||||
|
metadata=[],
|
||||||
|
user=v.metadata.user,
|
||||||
|
collection=v.metadata.collection,
|
||||||
|
),
|
||||||
|
entities
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Exception: ", e, flush=True)
|
print("Exception: ", e, flush=True)
|
||||||
|
|
||||||
|
|
@ -152,6 +204,12 @@ class Processor(ConsumerProducer):
|
||||||
default_output_queue,
|
default_output_queue,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'-e', '--entity-context-queue',
|
||||||
|
default=default_entity_context_queue,
|
||||||
|
help=f'Entity context queue (default: {default_entity_context_queue})'
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--prompt-request-queue',
|
'--prompt-request-queue',
|
||||||
default=prompt_request_queue,
|
default=prompt_request_queue,
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,15 @@
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Simple decoder, accepts vector+text chunks input, applies entity
|
Simple decoder, accepts text chunks input, applies entity
|
||||||
relationship analysis to get entity relationship edges which are output as
|
relationship analysis to get entity relationship edges which are output as
|
||||||
graph edges.
|
graph edges.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import os
|
|
||||||
from pulsar.schema import JsonSchema
|
|
||||||
|
|
||||||
from .... schema import ChunkEmbeddings, Triple, Triples, GraphEmbeddings
|
from .... schema import Chunk, Triple, Triples
|
||||||
from .... schema import Metadata, Value
|
from .... schema import Metadata, Value
|
||||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
from .... schema import chunk_ingest_queue, triples_store_queue
|
||||||
from .... schema import graph_embeddings_store_queue
|
|
||||||
from .... schema import prompt_request_queue
|
from .... schema import prompt_request_queue
|
||||||
from .... schema import prompt_response_queue
|
from .... schema import prompt_response_queue
|
||||||
from .... log_level import LogLevel
|
from .... log_level import LogLevel
|
||||||
|
|
@ -25,9 +22,8 @@ SUBJECT_OF_VALUE = Value(value=SUBJECT_OF, is_uri=True)
|
||||||
|
|
||||||
module = ".".join(__name__.split(".")[1:-1])
|
module = ".".join(__name__.split(".")[1:-1])
|
||||||
|
|
||||||
default_input_queue = chunk_embeddings_ingest_queue
|
default_input_queue = chunk_ingest_queue
|
||||||
default_output_queue = triples_store_queue
|
default_output_queue = triples_store_queue
|
||||||
default_vector_queue = graph_embeddings_store_queue
|
|
||||||
default_subscriber = module
|
default_subscriber = module
|
||||||
|
|
||||||
class Processor(ConsumerProducer):
|
class Processor(ConsumerProducer):
|
||||||
|
|
@ -36,7 +32,6 @@ class Processor(ConsumerProducer):
|
||||||
|
|
||||||
input_queue = params.get("input_queue", default_input_queue)
|
input_queue = params.get("input_queue", default_input_queue)
|
||||||
output_queue = params.get("output_queue", default_output_queue)
|
output_queue = params.get("output_queue", default_output_queue)
|
||||||
vector_queue = params.get("vector_queue", default_vector_queue)
|
|
||||||
subscriber = params.get("subscriber", default_subscriber)
|
subscriber = params.get("subscriber", default_subscriber)
|
||||||
pr_request_queue = params.get(
|
pr_request_queue = params.get(
|
||||||
"prompt_request_queue", prompt_request_queue
|
"prompt_request_queue", prompt_request_queue
|
||||||
|
|
@ -50,30 +45,13 @@ class Processor(ConsumerProducer):
|
||||||
"input_queue": input_queue,
|
"input_queue": input_queue,
|
||||||
"output_queue": output_queue,
|
"output_queue": output_queue,
|
||||||
"subscriber": subscriber,
|
"subscriber": subscriber,
|
||||||
"input_schema": ChunkEmbeddings,
|
"input_schema": Chunk,
|
||||||
"output_schema": Triples,
|
"output_schema": Triples,
|
||||||
"prompt_request_queue": pr_request_queue,
|
"prompt_request_queue": pr_request_queue,
|
||||||
"prompt_response_queue": pr_response_queue,
|
"prompt_response_queue": pr_response_queue,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.vec_prod = self.client.create_producer(
|
|
||||||
topic=vector_queue,
|
|
||||||
schema=JsonSchema(GraphEmbeddings),
|
|
||||||
)
|
|
||||||
|
|
||||||
__class__.pubsub_metric.info({
|
|
||||||
"input_queue": input_queue,
|
|
||||||
"output_queue": output_queue,
|
|
||||||
"vector_queue": vector_queue,
|
|
||||||
"prompt_request_queue": pr_request_queue,
|
|
||||||
"prompt_response_queue": pr_response_queue,
|
|
||||||
"subscriber": subscriber,
|
|
||||||
"input_schema": ChunkEmbeddings.__name__,
|
|
||||||
"output_schema": Triples.__name__,
|
|
||||||
"vector_schema": GraphEmbeddings.__name__,
|
|
||||||
})
|
|
||||||
|
|
||||||
self.prompt = PromptClient(
|
self.prompt = PromptClient(
|
||||||
pulsar_host=self.pulsar_host,
|
pulsar_host=self.pulsar_host,
|
||||||
input_queue=pr_request_queue,
|
input_queue=pr_request_queue,
|
||||||
|
|
@ -101,11 +79,6 @@ class Processor(ConsumerProducer):
|
||||||
)
|
)
|
||||||
self.producer.send(t)
|
self.producer.send(t)
|
||||||
|
|
||||||
def emit_vec(self, metadata, ent, vec):
|
|
||||||
|
|
||||||
r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec)
|
|
||||||
self.vec_prod.send(r)
|
|
||||||
|
|
||||||
def handle(self, msg):
|
def handle(self, msg):
|
||||||
|
|
||||||
v = msg.value()
|
v = msg.value()
|
||||||
|
|
@ -193,12 +166,6 @@ class Processor(ConsumerProducer):
|
||||||
o=Value(value=v.metadata.id, is_uri=True)
|
o=Value(value=v.metadata.id, is_uri=True)
|
||||||
))
|
))
|
||||||
|
|
||||||
self.emit_vec(v.metadata, s_value, v.vectors)
|
|
||||||
self.emit_vec(v.metadata, p_value, v.vectors)
|
|
||||||
|
|
||||||
if rel.o_entity:
|
|
||||||
self.emit_vec(v.metadata, o_value, v.vectors)
|
|
||||||
|
|
||||||
self.emit_edges(
|
self.emit_edges(
|
||||||
Metadata(
|
Metadata(
|
||||||
id=v.metadata.id,
|
id=v.metadata.id,
|
||||||
|
|
@ -222,12 +189,6 @@ class Processor(ConsumerProducer):
|
||||||
default_output_queue,
|
default_output_queue,
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'-c', '--vector-queue',
|
|
||||||
default=default_vector_queue,
|
|
||||||
help=f'Vector output queue (default: {default_vector_queue})'
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--prompt-request-queue',
|
'--prompt-request-queue',
|
||||||
default=prompt_request_queue,
|
default=prompt_request_queue,
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
|
Simple decoder, accepts text chunks input, applies entity analysis to
|
||||||
get entity definitions which are output as graph edges.
|
get topics which are output as graph edges.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from .... schema import ChunkEmbeddings, Triple, Triples, Metadata, Value
|
from .... schema import Chunk, Triple, Triples, Metadata, Value
|
||||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
from .... schema import chunk_ingest_queue, triples_store_queue
|
||||||
from .... schema import prompt_request_queue
|
from .... schema import prompt_request_queue
|
||||||
from .... schema import prompt_response_queue
|
from .... schema import prompt_response_queue
|
||||||
from .... log_level import LogLevel
|
from .... log_level import LogLevel
|
||||||
|
|
@ -20,7 +20,7 @@ DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
||||||
|
|
||||||
module = ".".join(__name__.split(".")[1:-1])
|
module = ".".join(__name__.split(".")[1:-1])
|
||||||
|
|
||||||
default_input_queue = chunk_embeddings_ingest_queue
|
default_input_queue = chunk_ingest_queue
|
||||||
default_output_queue = triples_store_queue
|
default_output_queue = triples_store_queue
|
||||||
default_subscriber = module
|
default_subscriber = module
|
||||||
|
|
||||||
|
|
@ -43,7 +43,7 @@ class Processor(ConsumerProducer):
|
||||||
"input_queue": input_queue,
|
"input_queue": input_queue,
|
||||||
"output_queue": output_queue,
|
"output_queue": output_queue,
|
||||||
"subscriber": subscriber,
|
"subscriber": subscriber,
|
||||||
"input_schema": ChunkEmbeddings,
|
"input_schema": Chunk,
|
||||||
"output_schema": Triples,
|
"output_schema": Triples,
|
||||||
"prompt_request_queue": pr_request_queue,
|
"prompt_request_queue": pr_request_queue,
|
||||||
"prompt_response_queue": pr_response_queue,
|
"prompt_response_queue": pr_response_queue,
|
||||||
|
|
|
||||||
|
|
@ -38,9 +38,11 @@ class Processor(Consumer):
|
||||||
|
|
||||||
v = msg.value()
|
v = msg.value()
|
||||||
|
|
||||||
if v.entity.value != "":
|
for entity in v.entities:
|
||||||
for vec in v.vectors:
|
|
||||||
self.vecstore.insert(vec, v.entity.value)
|
if entity.entity.value != "" and entity.entity.value is not None:
|
||||||
|
for vec in entity.vectors:
|
||||||
|
self.vecstore.insert(vec, entity.entity.value)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -60,76 +60,83 @@ class Processor(Consumer):
|
||||||
|
|
||||||
self.last_index_name = None
|
self.last_index_name = None
|
||||||
|
|
||||||
|
def create_index(self, index_name, dim):
|
||||||
|
|
||||||
|
self.pinecone.create_index(
|
||||||
|
name = index_name,
|
||||||
|
dimension = dim,
|
||||||
|
metric = "cosine",
|
||||||
|
spec = ServerlessSpec(
|
||||||
|
cloud = self.cloud,
|
||||||
|
region = self.region,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(0, 1000):
|
||||||
|
|
||||||
|
if self.pinecone.describe_index(
|
||||||
|
index_name
|
||||||
|
).status["ready"]:
|
||||||
|
break
|
||||||
|
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
if not self.pinecone.describe_index(
|
||||||
|
index_name
|
||||||
|
).status["ready"]:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Gave up waiting for index creation"
|
||||||
|
)
|
||||||
|
|
||||||
def handle(self, msg):
|
def handle(self, msg):
|
||||||
|
|
||||||
v = msg.value()
|
v = msg.value()
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
|
|
||||||
if v.entity.value == "" or v.entity.value is None: return
|
for entity in v.entities:
|
||||||
|
|
||||||
for vec in v.vectors:
|
if entity.entity.value == "" or entity.entity.value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
dim = len(vec)
|
for vec in entity.vectors:
|
||||||
|
|
||||||
index_name = (
|
dim = len(vec)
|
||||||
"t-" + v.metadata.user + "-" + str(dim)
|
|
||||||
)
|
|
||||||
|
|
||||||
if index_name != self.last_index_name:
|
index_name = (
|
||||||
|
"t-" + v.metadata.user + "-" + str(dim)
|
||||||
|
)
|
||||||
|
|
||||||
if not self.pinecone.has_index(index_name):
|
if index_name != self.last_index_name:
|
||||||
|
|
||||||
try:
|
if not self.pinecone.has_index(index_name):
|
||||||
|
|
||||||
self.pinecone.create_index(
|
try:
|
||||||
name = index_name,
|
|
||||||
dimension = dim,
|
|
||||||
metric = "cosine",
|
|
||||||
spec = ServerlessSpec(
|
|
||||||
cloud = self.cloud,
|
|
||||||
region = self.region,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for i in range(0, 1000):
|
self.create_index(index_name, dim)
|
||||||
|
|
||||||
if self.pinecone.describe_index(
|
except Exception as e:
|
||||||
index_name
|
print("Pinecone index creation failed")
|
||||||
).status["ready"]:
|
raise e
|
||||||
break
|
|
||||||
|
|
||||||
time.sleep(1)
|
print(f"Index {index_name} created", flush=True)
|
||||||
|
|
||||||
if not self.pinecone.describe_index(
|
self.last_index_name = index_name
|
||||||
index_name
|
|
||||||
).status["ready"]:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Gave up waiting for index creation"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
index = self.pinecone.Index(index_name)
|
||||||
print("Pinecone index creation failed")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
print(f"Index {index_name} created", flush=True)
|
records = [
|
||||||
|
{
|
||||||
|
"id": id,
|
||||||
|
"values": vec,
|
||||||
|
"metadata": { "entity": entity.entity.value },
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
self.last_index_name = index_name
|
index.upsert(
|
||||||
|
vectors = records,
|
||||||
index = self.pinecone.Index(index_name)
|
namespace = v.metadata.collection,
|
||||||
|
)
|
||||||
records = [
|
|
||||||
{
|
|
||||||
"id": id,
|
|
||||||
"values": vec,
|
|
||||||
"metadata": { "entity": v.entity.value },
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
index.upsert(
|
|
||||||
vectors = records,
|
|
||||||
namespace = v.metadata.collection,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
|
|
@ -40,49 +40,59 @@ class Processor(Consumer):
|
||||||
|
|
||||||
self.client = QdrantClient(url=store_uri)
|
self.client = QdrantClient(url=store_uri)
|
||||||
|
|
||||||
|
def get_collection(self, dim, user, collection):
|
||||||
|
|
||||||
|
cname = (
|
||||||
|
"t_" + user + "_" + collection + "_" + str(dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
if cname != self.last_collection:
|
||||||
|
|
||||||
|
if not self.client.collection_exists(cname):
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.client.create_collection(
|
||||||
|
collection_name=cname,
|
||||||
|
vectors_config=VectorParams(
|
||||||
|
size=dim, distance=Distance.COSINE
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print("Qdrant collection creation failed")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
self.last_collection = cname
|
||||||
|
|
||||||
|
return cname
|
||||||
|
|
||||||
def handle(self, msg):
|
def handle(self, msg):
|
||||||
|
|
||||||
v = msg.value()
|
v = msg.value()
|
||||||
|
|
||||||
if v.entity.value == "" or v.entity.value is None: return
|
for entity in v.entities:
|
||||||
|
|
||||||
for vec in v.vectors:
|
if entity.entity.value == "" or entity.entity.value is None: return
|
||||||
|
|
||||||
dim = len(vec)
|
for vec in entity.vectors:
|
||||||
collection = (
|
|
||||||
"t_" + v.metadata.user + "_" + v.metadata.collection + "_" +
|
|
||||||
str(dim)
|
|
||||||
)
|
|
||||||
|
|
||||||
if collection != self.last_collection:
|
dim = len(vec)
|
||||||
|
|
||||||
if not self.client.collection_exists(collection):
|
collection = self.get_collection(
|
||||||
|
dim, v.metadata.user, v.metadata.collection
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
self.client.upsert(
|
||||||
self.client.create_collection(
|
collection_name=collection,
|
||||||
collection_name=collection,
|
points=[
|
||||||
vectors_config=VectorParams(
|
PointStruct(
|
||||||
size=dim, distance=Distance.COSINE
|
id=str(uuid.uuid4()),
|
||||||
),
|
vector=vec,
|
||||||
|
payload={
|
||||||
|
"entity": entity.entity.value,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
]
|
||||||
print("Qdrant collection creation failed")
|
)
|
||||||
raise e
|
|
||||||
|
|
||||||
self.last_collection = collection
|
|
||||||
|
|
||||||
self.client.upsert(
|
|
||||||
collection_name=collection,
|
|
||||||
points=[
|
|
||||||
PointStruct(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
vector=vec,
|
|
||||||
payload={
|
|
||||||
"entity": v.entity.value,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_args(parser):
|
def add_args(parser):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue