trustgraph/trustgraph-flow/trustgraph/extract/kg/relationships/extract.py
cybermaggedon 7954e863cc
Feature: document metadata (#123)
* Rework metadata structure in processing messages to be a subgraph
* Add subgraph creation for tg-load-pdf and tg-load-text based on command-line passing of doc attributes
* Document metadata is added to knowledge graph with subjectOf linkage to extracted entities
2024-10-23 18:04:04 +01:00

246 lines
7.4 KiB
Python
Executable file

"""
Simple decoder, accepts vector+text chunks input, applies entity
relationship analysis to get entity relationship edges which are output as
graph edges.
"""
import urllib.parse
import os
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Triple, Triples, GraphEmbeddings
from .... schema import Metadata, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import graph_embeddings_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES, SUBJECT_OF
from .... base import ConsumerProducer
# Pre-built URI values for the predicates attached to extracted entities
# (label and 'subject of' document linkage).
RDF_LABEL_VALUE = Value(is_uri=True, value=RDF_LABEL)
SUBJECT_OF_VALUE = Value(is_uri=True, value=SUBJECT_OF)

# Dotted module path with the top-level package and the file name removed
# (presumably "extract.kg.relationships" for this file); used as the
# default subscriber name and passed to Processor.start() by run().
module = ".".join(__name__.split(".")[1:-1])

# Default queue / subscription names, overridable via command-line args.
default_input_queue = chunk_embeddings_ingest_queue    # ChunkEmbeddings in
default_output_queue = triples_store_queue             # Triples out
default_vector_queue = graph_embeddings_store_queue    # GraphEmbeddings out
default_subscriber = module
class Processor(ConsumerProducer):
    """Extract entity relationships from chunk embeddings.

    Consumes ``ChunkEmbeddings`` messages and asks the prompt service for
    the relationships found in each chunk's text. For every chunk it emits:

    * one ``Triples`` message on the output queue, containing the document
      metadata triples, one edge per relationship, ``RDF_LABEL`` triples
      for each URI term and ``SUBJECT_OF`` triples linking entities to the
      source document;
    * ``GraphEmbeddings`` messages on the vector queue for each subject,
      predicate and entity object, carrying the chunk's vectors.
    """

    def __init__(self, **params):
        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        vector_queue = params.get("vector_queue", default_vector_queue)
        subscriber = params.get("subscriber", default_subscriber)
        pr_request_queue = params.get(
            "prompt_request_queue", prompt_request_queue
        )
        pr_response_queue = params.get(
            "prompt_response_queue", prompt_response_queue
        )

        super().__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": ChunkEmbeddings,
                "output_schema": Triples,
                "prompt_request_queue": pr_request_queue,
                "prompt_response_queue": pr_response_queue,
            }
        )

        # The base-class producer is configured for Triples (output_schema
        # above), so graph embeddings need a producer of their own.
        self.vec_prod = self.client.create_producer(
            topic=vector_queue,
            schema=JsonSchema(GraphEmbeddings),
        )

        __class__.pubsub_metric.info({
            "input_queue": input_queue,
            "output_queue": output_queue,
            "vector_queue": vector_queue,
            "prompt_request_queue": pr_request_queue,
            "prompt_response_queue": pr_response_queue,
            "subscriber": subscriber,
            "input_schema": ChunkEmbeddings.__name__,
            "output_schema": Triples.__name__,
            "vector_schema": GraphEmbeddings.__name__,
        })

        self.prompt = PromptClient(
            pulsar_host=self.pulsar_host,
            input_queue=pr_request_queue,
            output_queue=pr_response_queue,
            subscriber=module + "-prompt",
        )

    def to_uri(self, text):
        """Map entity/predicate text to a URI under TRUSTGRAPH_ENTITIES.

        Spaces become hyphens, the text is lower-cased, UTF-8 encoded and
        percent-quoted before being appended to the namespace.
        """
        part = text.replace(" ", "-").lower().encode("utf-8")
        quoted = urllib.parse.quote(part)
        uri = TRUSTGRAPH_ENTITIES + quoted
        return uri

    def get_relationships(self, chunk):
        """Ask the prompt service for relationships in *chunk* text.

        The result is iterated by ``handle``; each item is expected to
        expose ``s``, ``p``, ``o`` and ``o_entity`` attributes.
        """
        return self.prompt.request_relationships(chunk)

    def emit_edges(self, metadata, triples):
        """Send *triples* with *metadata* as one Triples message."""
        t = Triples(
            metadata=metadata,
            triples=triples,
        )
        self.producer.send(t)

    def emit_vec(self, metadata, ent, vec):
        """Send a GraphEmbeddings message associating *vec* with *ent*."""
        r = GraphEmbeddings(metadata=metadata, entity=ent, vectors=vec)
        self.vec_prod.send(r)

    @staticmethod
    def _is_blank(term):
        """Return True when an extracted term is missing (None) or empty."""
        return term is None or term == ""

    def _relationship_triples(self, rel, doc_value):
        """Build the triples for one extracted relationship.

        Returns a ``(triples, entities)`` pair.

        * ``triples`` holds the relationship edge, ``RDF_LABEL`` triples
          for its URI terms and ``SUBJECT_OF`` links to *doc_value*.
        * ``entities`` holds the URI values that should receive graph
          embeddings: the subject, the predicate and, when it is an entity,
          the object.
        """
        s, p, o = rel.s, rel.p, rel.o

        s_value = Value(value=str(self.to_uri(s)), is_uri=True)
        p_value = Value(value=str(self.to_uri(p)), is_uri=True)

        # Only entity objects become URIs; anything else is a literal.
        if rel.o_entity:
            o_value = Value(value=str(self.to_uri(o)), is_uri=True)
        else:
            o_value = Value(value=str(o), is_uri=False)

        triples = [
            # The relationship edge itself
            Triple(s=s_value, p=p_value, o=o_value),
            # Label for s
            Triple(
                s=s_value,
                p=RDF_LABEL_VALUE,
                o=Value(value=str(s), is_uri=False),
            ),
            # Label for p
            Triple(
                s=p_value,
                p=RDF_LABEL_VALUE,
                o=Value(value=str(p), is_uri=False),
            ),
        ]

        if rel.o_entity:
            # Label for o
            triples.append(Triple(
                s=o_value,
                p=RDF_LABEL_VALUE,
                o=Value(value=str(o), is_uri=False),
            ))

        # 'Subject of' for s
        triples.append(Triple(s=s_value, p=SUBJECT_OF_VALUE, o=doc_value))

        if rel.o_entity:
            # 'Subject of' for o.  Must be SUBJECT_OF, not RDF_LABEL:
            # the object is linked to the document, not labelled with its id.
            triples.append(Triple(s=o_value, p=SUBJECT_OF_VALUE, o=doc_value))

        entities = [s_value, p_value]
        if rel.o_entity:
            entities.append(o_value)

        return triples, entities

    def handle(self, msg):
        """Extract relationships from one ChunkEmbeddings message.

        Graph embeddings are emitted per relationship as it is processed,
        and all triples are emitted together at the end. Any exception
        raised after decoding is printed and the message is otherwise
        dropped. "Done." is printed in either case.
        """
        v = msg.value()
        print(f"Indexing {v.metadata.id}...", flush=True)

        chunk = v.chunk.decode("utf-8")

        try:
            rels = self.get_relationships(chunk)

            # FIXME: Putting document metadata into the triples store is
            # duplicated in the other KG extractors (presumably the
            # definitions extractor) — verify and factor out.
            triples = list(v.metadata.metadata)

            # Every entity from this chunk is linked to the same document
            # URI, so build it once (shared Values are already used, e.g.
            # RDF_LABEL_VALUE).
            doc_value = Value(value=v.metadata.id, is_uri=True)

            for rel in rels:

                # Skip relationships with any missing or empty term.
                if any(self._is_blank(t) for t in (rel.s, rel.p, rel.o)):
                    continue

                rel_triples, entities = self._relationship_triples(
                    rel, doc_value
                )
                triples.extend(rel_triples)

                for ent in entities:
                    self.emit_vec(v.metadata, ent, v.vectors)

            # Document metadata travels as triples (added above), so the
            # outgoing Metadata carries an empty metadata list.
            self.emit_edges(
                Metadata(
                    id=v.metadata.id,
                    metadata=[],
                    user=v.metadata.user,
                    collection=v.metadata.collection,
                ),
                triples
            )

        except Exception as e:
            print("Exception: ", e, flush=True)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):
        """Register this processor's command-line arguments on *parser*."""
        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-c', '--vector-queue',
            default=default_vector_queue,
            help=f'Vector output queue (default: {default_vector_queue})'
        )

        parser.add_argument(
            '--prompt-request-queue',
            default=prompt_request_queue,
            help=f'Prompt request queue (default: {prompt_request_queue})',
        )

        parser.add_argument(
            '--prompt-response-queue',
            default=prompt_response_queue,
            help=f'Prompt response queue (default: {prompt_response_queue})',
        )
def run():
    """Start the relationship-extraction Processor.

    Passes the module path and the module docstring to Processor.start()
    (the docstring is presumably used as the command-line description).
    """
    description = __doc__
    Processor.start(module, description)