mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-30 10:56:23 +02:00
Extract rows and apply object embeddings (#42)
* - Restructured the extract directories - Added an extractor for 'rows' == a row of a table - Added a row extractor prompt to prompter. * Add row support to template prompter * Row extraction working * Bump version * Emit extracted info * Object embeddings store * Invocation script * Add script to package, remove cruft output * Write rows to Cassandra * Remove output cruft
This commit is contained in:
parent
b574ba26a8
commit
e4c4774b5d
70 changed files with 1624 additions and 520 deletions
|
|
@ -116,4 +116,4 @@ class BaseProcessor:
|
|||
print("Exception:", e, flush=True)
|
||||
print("Will retry...", flush=True)
|
||||
|
||||
time.sleep(10)
|
||||
time.sleep(4)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ class Consumer(BaseProcessor):
|
|||
|
||||
def __init__(self, **params):
|
||||
|
||||
print("HERE2")
|
||||
super(Consumer, self).__init__(**params)
|
||||
|
||||
input_queue = params.get("input_queue")
|
||||
|
|
@ -29,6 +30,7 @@ class Consumer(BaseProcessor):
|
|||
'pubsub', 'Pub/sub configuration'
|
||||
)
|
||||
|
||||
print("HERE")
|
||||
if not hasattr(__class__, "processing_metric"):
|
||||
__class__.processing_metric = Counter(
|
||||
'processing_count', 'Processing count', ["status"]
|
||||
|
|
@ -65,7 +67,7 @@ class Consumer(BaseProcessor):
|
|||
self.consumer.negative_acknowledge(msg)
|
||||
print("TooManyRequests: will retry")
|
||||
__class__.processing_metric.labels(status="rate-limit").inc()
|
||||
time.sleep(2)
|
||||
time.sleep(5)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class ConsumerProducer(BaseProcessor):
|
|||
self.consumer.negative_acknowledge(msg)
|
||||
print("TooManyRequests: will retry")
|
||||
__class__.processing_metric.labels(status="rate-limit").inc()
|
||||
time.sleep(2)
|
||||
time.sleep(5)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import PromptRequest, PromptResponse, Fact
|
||||
from .. schema import PromptRequest, PromptResponse, Fact, RowSchema, Field
|
||||
from .. schema import prompt_request_queue
|
||||
from .. schema import prompt_response_queue
|
||||
from . base import BaseClient
|
||||
|
|
@ -52,6 +52,24 @@ class PromptClient(BaseClient):
|
|||
timeout=timeout
|
||||
).relationships
|
||||
|
||||
def request_rows(self, schema, chunk, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="extract-rows", chunk=chunk,
|
||||
row_schema=RowSchema(
|
||||
name=schema.name,
|
||||
description=schema.description,
|
||||
fields=[
|
||||
Field(
|
||||
name=f.name, type=str(f.type), size=f.size,
|
||||
primary=f.primary, description=f.description,
|
||||
)
|
||||
for f in schema.fields
|
||||
]
|
||||
),
|
||||
timeout=timeout
|
||||
).rows
|
||||
|
||||
def request_kg_prompt(self, query, kg, timeout=300):
|
||||
|
||||
return self.call(
|
||||
|
|
|
|||
154
trustgraph/direct/milvus_object_embeddings.py
Normal file
154
trustgraph/direct/milvus_object_embeddings.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
|
||||
class ObjectVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='obj'):
|
||||
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
# Strategy is to create collections per dimension. Probably only
|
||||
# going to be using 1 anyway, but that means we don't need to
|
||||
# hard-code the dimension anywhere, and no big deal if more than
|
||||
# one are created.
|
||||
self.collections = {}
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
# Time between reloads
|
||||
self.reload_time = 90
|
||||
|
||||
# Next time to reload - this forces a reload at next window
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
print("Reload at", self.next_reload)
|
||||
|
||||
def init_collection(self, dimension, name):
|
||||
|
||||
collection_name = self.prefix + "_" + name + "_" + str(dimension)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
is_primary=True,
|
||||
auto_id=True,
|
||||
)
|
||||
|
||||
vec_field = FieldSchema(
|
||||
name="vector",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
)
|
||||
|
||||
name_field = FieldSchema(
|
||||
name="name",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
key_name_field = FieldSchema(
|
||||
name="key_name",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
key_field = FieldSchema(
|
||||
name="key",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [
|
||||
pkey_field, vec_field, name_field, key_name_field, key_field
|
||||
],
|
||||
description = "Object embedding schema",
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
metric_type="COSINE",
|
||||
index_type="IVF_SQ8",
|
||||
index_name="vector_index",
|
||||
params={ "nlist": 128 }
|
||||
)
|
||||
|
||||
self.client.create_index(
|
||||
collection_name=collection_name,
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[(dimension, name)] = collection_name
|
||||
|
||||
def insert(self, embeds, name, key_name, key):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if (dim, name) not in self.collections:
|
||||
self.init_collection(dim, name)
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"name": name,
|
||||
"key_name": key_name,
|
||||
"key": key,
|
||||
}
|
||||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[(dim, name)],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, name, fields=["key_name", "name"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim, name)
|
||||
|
||||
coll = self.collections[(dim, name)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
print("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
|
||||
print("Searching...")
|
||||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
)[0]
|
||||
|
||||
|
||||
# If reload time has passed, unload collection
|
||||
if time.time() > self.next_reload:
|
||||
print("Unloading, reload at", self.next_reload)
|
||||
self.client.release_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
|
||||
return res
|
||||
|
||||
0
trustgraph/extract/kg/__init__.py
Normal file
0
trustgraph/extract/kg/__init__.py
Normal file
|
|
@ -7,14 +7,14 @@ get entity definitions which are output as graph edges.
|
|||
import urllib.parse
|
||||
import json
|
||||
|
||||
from ... schema import ChunkEmbeddings, Triple, Source, Value
|
||||
from ... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from ... schema import prompt_request_queue
|
||||
from ... schema import prompt_response_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... clients.prompt_client import PromptClient
|
||||
from ... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
|
||||
from ... base import ConsumerProducer
|
||||
from .... schema import ChunkEmbeddings, Triple, Source, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
||||
|
||||
|
|
@ -9,15 +9,15 @@ import urllib.parse
|
|||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from ... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
|
||||
from ... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from ... schema import graph_embeddings_store_queue
|
||||
from ... schema import prompt_request_queue
|
||||
from ... schema import prompt_response_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... clients.prompt_client import PromptClient
|
||||
from ... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
|
||||
from ... base import ConsumerProducer
|
||||
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import graph_embeddings_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
||||
|
||||
0
trustgraph/extract/object/__init__.py
Normal file
0
trustgraph/extract/object/__init__.py
Normal file
3
trustgraph/extract/object/row/__init__.py
Normal file
3
trustgraph/extract/object/row/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . extract import *
|
||||
|
||||
7
trustgraph/extract/object/row/__main__.py
Executable file
7
trustgraph/extract/object/row/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . extract import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
220
trustgraph/extract/object/row/extract.py
Executable file
220
trustgraph/extract/object/row/extract.py
Executable file
|
|
@ -0,0 +1,220 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts vector+text chunks input, applies analysis to pull
|
||||
out a row of fields. Output as a vector plus object.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source
|
||||
from .... schema import RowSchema, Field
|
||||
from .... schema import chunk_embeddings_ingest_queue, rows_store_queue
|
||||
from .... schema import object_embeddings_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
from .... objects.field import Field as FieldParser
|
||||
from .... objects.object import Schema
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_output_queue = rows_store_queue
|
||||
default_vector_queue = object_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
vector_queue = params.get("vector_queue", default_vector_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"output_schema": Rows,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.vec_prod = self.client.create_producer(
|
||||
topic=vector_queue,
|
||||
schema=JsonSchema(ObjectEmbeddings),
|
||||
)
|
||||
|
||||
__class__.pubsub_metric.info({
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"vector_queue": vector_queue,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings.__name__,
|
||||
"output_schema": Rows.__name__,
|
||||
"vector_schema": ObjectEmbeddings.__name__,
|
||||
})
|
||||
|
||||
flds = __class__.parse_fields(params["field"])
|
||||
|
||||
for fld in flds:
|
||||
print(fld)
|
||||
|
||||
self.primary = None
|
||||
|
||||
for f in flds:
|
||||
if f.primary:
|
||||
if self.primary:
|
||||
raise RuntimeError(
|
||||
"Only one primary key field is supported"
|
||||
)
|
||||
self.primary = f
|
||||
|
||||
if self.primary == None:
|
||||
raise RuntimeError(
|
||||
"Must have exactly one primary key field"
|
||||
)
|
||||
|
||||
self.schema = Schema(
|
||||
name = params["name"],
|
||||
description = params["description"],
|
||||
fields = flds
|
||||
)
|
||||
|
||||
self.row_schema=RowSchema(
|
||||
name=self.schema.name,
|
||||
description=self.schema.description,
|
||||
fields=[
|
||||
Field(
|
||||
name=f.name, type=str(f.type), size=f.size,
|
||||
primary=f.primary, description=f.description,
|
||||
)
|
||||
for f in self.schema.fields
|
||||
]
|
||||
)
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_fields(fields):
|
||||
return [ FieldParser.parse(f) for f in fields ]
|
||||
|
||||
def get_rows(self, chunk):
|
||||
return self.prompt.request_rows(self.schema, chunk)
|
||||
|
||||
def emit_rows(self, source, rows):
|
||||
|
||||
t = Rows(
|
||||
source=source, row_schema=self.row_schema, rows=rows
|
||||
)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_vec(self, source, name, vec, key_name, key):
|
||||
|
||||
r = ObjectEmbeddings(
|
||||
source=source, vectors=vec, name=name, key_name=key_name, id=key
|
||||
)
|
||||
self.vec_prod.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
rows = self.get_rows(chunk)
|
||||
|
||||
self.emit_rows(
|
||||
source=v.source,
|
||||
rows=rows
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
self.emit_vec(
|
||||
source=v.source, vec=v.vectors,
|
||||
name=self.schema.name, key_name=self.primary.name,
|
||||
key=row[self.primary.name]
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
print(row)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--vector-queue',
|
||||
default=default_vector_queue,
|
||||
help=f'Vector output queue (default: {default_vector_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-f', '--field',
|
||||
required=True,
|
||||
action='append',
|
||||
help=f'Field definition, format name:type:size:pri:descriptionn',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-n', '--name',
|
||||
required=True,
|
||||
help=f'Name of row object',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-d', '--description',
|
||||
required=True,
|
||||
help=f'Description of object',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -48,6 +48,44 @@ or headers or prefixes. Do not include null or unknown definitions.
|
|||
|
||||
return prompt
|
||||
|
||||
def to_rows(schema, text):
|
||||
|
||||
field_schema = [
|
||||
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
|
||||
for f in schema.fields
|
||||
]
|
||||
|
||||
field_schema = "\n".join(field_schema)
|
||||
|
||||
schema = f"""Object name: {schema.name}
|
||||
Description: {schema.description}
|
||||
|
||||
Fields:
|
||||
{field_schema}"""
|
||||
|
||||
prompt = f"""<instructions>
|
||||
Study the following text and derive objects which match the schema provided.
|
||||
|
||||
You must output an array of JSON objects for each object you discover
|
||||
which matches the schema. For each object, output a JSON object whose fields
|
||||
carry the name field specified in the schema.
|
||||
</instructions>
|
||||
|
||||
<schema>
|
||||
{schema}
|
||||
</schema>
|
||||
|
||||
<text>
|
||||
{text}
|
||||
</text>
|
||||
|
||||
<requirements>
|
||||
You will respond only with raw JSON format data. Do not provide
|
||||
explanations. Do not add markdown formatting or headers or prefixes.
|
||||
</requirements>"""
|
||||
|
||||
return prompt
|
||||
|
||||
def get_cypher(kg):
|
||||
|
||||
sg2 = []
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from .... base import ConsumerProducer
|
|||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships
|
||||
from . prompts import to_kg_query, to_document_query
|
||||
from . prompts import to_kg_query, to_document_query, to_rows
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
|
|
@ -77,6 +77,11 @@ class Processor(ConsumerProducer):
|
|||
self.handle_extract_relationships(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-rows":
|
||||
|
||||
self.handle_extract_rows(id, v)
|
||||
return
|
||||
|
||||
elif kind == "kg-prompt":
|
||||
|
||||
self.handle_kg_prompt(id, v)
|
||||
|
|
@ -222,6 +227,77 @@ class Processor(ConsumerProducer):
|
|||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_rows(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
fields = v.row_schema.fields
|
||||
|
||||
prompt = to_rows(v.row_schema, v.chunk)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
objs = json.loads(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
objs = []
|
||||
|
||||
output = []
|
||||
|
||||
for obj in objs:
|
||||
|
||||
try:
|
||||
|
||||
row = {}
|
||||
|
||||
for f in fields:
|
||||
|
||||
if f.name not in obj:
|
||||
print(f"Object ignored, missing field {f.name}")
|
||||
row = {}
|
||||
break
|
||||
|
||||
row[f.name] = obj[f.name]
|
||||
|
||||
if row == {}:
|
||||
continue
|
||||
|
||||
output.append(row)
|
||||
|
||||
except Exception as e:
|
||||
print("row fields missing, ignored", flush=True)
|
||||
|
||||
for row in output:
|
||||
print(row)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(rows=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_kg_prompt(self, id, v):
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,27 @@ def to_relationships(template, text):
|
|||
def to_definitions(template, text):
|
||||
return template.format(text=text)
|
||||
|
||||
def to_rows(template, schema, text):
|
||||
|
||||
field_schema = [
|
||||
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
|
||||
for f in schema.fields
|
||||
]
|
||||
|
||||
field_schema = "\n".join(field_schema)
|
||||
|
||||
return template.format(schema=schema, text=text)
|
||||
|
||||
schema = f"""Object name: {schema.name}
|
||||
Description: {schema.description}
|
||||
|
||||
Fields:
|
||||
{schema}"""
|
||||
|
||||
prompt = f""""""
|
||||
|
||||
return prompt
|
||||
|
||||
def get_cypher(kg):
|
||||
sg2 = []
|
||||
for f in kg:
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from .... schema import prompt_request_queue, prompt_response_queue
|
|||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships
|
||||
from . prompts import to_definitions, to_relationships, to_rows
|
||||
from . prompts import to_kg_query, to_document_query
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
|
@ -38,6 +38,7 @@ class Processor(ConsumerProducer):
|
|||
)
|
||||
definition_template = params.get("definition_template")
|
||||
relationship_template = params.get("relationship_template")
|
||||
rows_template = params.get("rows_template")
|
||||
knowledge_query_template = params.get("knowledge_query_template")
|
||||
document_query_template = params.get("document_query_template")
|
||||
|
||||
|
|
@ -62,6 +63,7 @@ class Processor(ConsumerProducer):
|
|||
|
||||
self.definition_template = definition_template
|
||||
self.relationship_template = relationship_template
|
||||
self.rows_template = rows_template
|
||||
self.knowledge_query_template = knowledge_query_template
|
||||
self.document_query_template = document_query_template
|
||||
|
||||
|
|
@ -87,6 +89,11 @@ class Processor(ConsumerProducer):
|
|||
self.handle_extract_relationships(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-rows":
|
||||
|
||||
self.handle_extract_rows(id, v)
|
||||
return
|
||||
|
||||
elif kind == "kg-prompt":
|
||||
|
||||
self.handle_kg_prompt(id, v)
|
||||
|
|
@ -232,6 +239,77 @@ class Processor(ConsumerProducer):
|
|||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_rows(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
fields = v.row_schema.fields
|
||||
|
||||
prompt = to_rows(self.rows_template, v.row_schema, v.chunk)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
objs = json.loads(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
objs = []
|
||||
|
||||
output = []
|
||||
|
||||
for obj in objs:
|
||||
|
||||
try:
|
||||
|
||||
row = {}
|
||||
|
||||
for f in fields:
|
||||
|
||||
if f.name not in obj:
|
||||
print(f"Object ignored, missing field {f.name}")
|
||||
row = {}
|
||||
break
|
||||
|
||||
row[f.name] = obj[f.name]
|
||||
|
||||
if row == {}:
|
||||
continue
|
||||
|
||||
output.append(row)
|
||||
|
||||
except Exception as e:
|
||||
print("row fields missing, ignored", flush=True)
|
||||
|
||||
for row in output:
|
||||
print(row)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(rows=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_kg_prompt(self, id, v):
|
||||
|
||||
|
|
@ -329,6 +407,12 @@ class Processor(ConsumerProducer):
|
|||
help=f'Definition extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rows-template',
|
||||
required=True,
|
||||
help=f'Rows extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--relationship-template',
|
||||
required=True,
|
||||
|
|
|
|||
0
trustgraph/objects/__init__.py
Normal file
0
trustgraph/objects/__init__.py
Normal file
72
trustgraph/objects/field.py
Normal file
72
trustgraph/objects/field.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
class FieldType(Enum):
|
||||
STRING = 0
|
||||
INT = 1
|
||||
LONG = 2
|
||||
BOOL = 3
|
||||
FLOAT = 4
|
||||
DOUBLE = 5
|
||||
|
||||
def __str__(self):
|
||||
return self.name.lower()
|
||||
|
||||
@dataclass
|
||||
class Field:
|
||||
name: str
|
||||
size: int = -1
|
||||
primary: bool = False
|
||||
type: str = "undefined"
|
||||
description: str = ""
|
||||
|
||||
@staticmethod
|
||||
def parse(defn):
|
||||
|
||||
if defn == "" or defn is None:
|
||||
raise RuntimeError("Field definition cannot be empty")
|
||||
|
||||
parts = defn.split(":")
|
||||
|
||||
if len(parts) == 0:
|
||||
raise RuntimeError("Field definition cannot be empty")
|
||||
|
||||
if len(parts) == 1: parts.append("string")
|
||||
if len(parts) == 2: parts.append("0")
|
||||
if len(parts) == 3: parts.append("")
|
||||
if len(parts) == 4: parts.append("")
|
||||
|
||||
name, type, size, pri, description = parts
|
||||
|
||||
size = int(size)
|
||||
|
||||
try:
|
||||
type = FieldType[type.upper()]
|
||||
except:
|
||||
raise RuntimeError(f"Field type {type} is not known")
|
||||
|
||||
pri = True if pri == "pri" else False
|
||||
|
||||
return Field(
|
||||
name=name, type=type, size=size, primary=pri,
|
||||
description=description
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
name = self.name
|
||||
type = self.type
|
||||
size = self.size
|
||||
pri = "pri" if self.primary else ""
|
||||
description = self.description
|
||||
|
||||
return f"{name}:{type}:{size}:{pri}:{description}"
|
||||
|
||||
def __str__(self):
|
||||
name = self.name
|
||||
type = self.type
|
||||
size = self.size
|
||||
pri = "pri" if self.primary else ""
|
||||
description = self.description
|
||||
|
||||
return f"{name}:{type}:{size}:{pri}:{description}"
|
||||
8
trustgraph/objects/object.py
Normal file
8
trustgraph/objects/object.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
|
||||
class Schema:
|
||||
def __init__(self, name, description, fields):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.fields = fields
|
||||
|
||||
|
||||
|
|
@ -1,274 +0,0 @@
|
|||
|
||||
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
|
||||
|
||||
from enum import Enum
|
||||
|
||||
def topic(topic, kind='persistent', tenant='tg', namespace='flow'):
|
||||
return f"{kind}://{tenant}/{namespace}/{topic}"
|
||||
|
||||
############################################################################
|
||||
|
||||
class Error(Record):
|
||||
type = String()
|
||||
message = String()
|
||||
|
||||
############################################################################
|
||||
|
||||
class Value(Record):
|
||||
value = String()
|
||||
is_uri = Boolean()
|
||||
type = String()
|
||||
|
||||
class Source(Record):
|
||||
source = String()
|
||||
id = String()
|
||||
title = String()
|
||||
|
||||
############################################################################
|
||||
|
||||
# PDF docs etc.
|
||||
class Document(Record):
|
||||
source = Source()
|
||||
data = Bytes()
|
||||
|
||||
document_ingest_queue = topic('document-load')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Text documents / text from PDF
|
||||
|
||||
class TextDocument(Record):
|
||||
source = Source()
|
||||
text = Bytes()
|
||||
|
||||
text_ingest_queue = topic('text-document-load')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Chunks of text
|
||||
|
||||
class Chunk(Record):
|
||||
source = Source()
|
||||
chunk = Bytes()
|
||||
|
||||
chunk_ingest_queue = topic('chunk-load')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Chunk embeddings are an embeddings associated with a text chunk
|
||||
|
||||
class ChunkEmbeddings(Record):
|
||||
source = Source()
|
||||
vectors = Array(Array(Double()))
|
||||
chunk = Bytes()
|
||||
|
||||
chunk_embeddings_ingest_queue = topic('chunk-embeddings-load')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Graph embeddings are embeddings associated with a graph entity
|
||||
|
||||
class GraphEmbeddings(Record):
|
||||
source = Source()
|
||||
vectors = Array(Array(Double()))
|
||||
entity = Value()
|
||||
|
||||
graph_embeddings_store_queue = topic('graph-embeddings-store')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Graph embeddings query
|
||||
|
||||
class GraphEmbeddingsRequest(Record):
|
||||
vectors = Array(Array(Double()))
|
||||
limit = Integer()
|
||||
|
||||
class GraphEmbeddingsResponse(Record):
|
||||
error = Error()
|
||||
entities = Array(Value())
|
||||
|
||||
graph_embeddings_request_queue = topic(
|
||||
'graph-embeddings', kind='non-persistent', namespace='request'
|
||||
)
|
||||
graph_embeddings_response_queue = topic(
|
||||
'graph-embeddings-response', kind='non-persistent', namespace='response',
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
# Doc embeddings query
|
||||
|
||||
class DocumentEmbeddingsRequest(Record):
|
||||
vectors = Array(Array(Double()))
|
||||
limit = Integer()
|
||||
|
||||
class DocumentEmbeddingsResponse(Record):
|
||||
error = Error()
|
||||
documents = Array(Bytes())
|
||||
|
||||
document_embeddings_request_queue = topic(
|
||||
'doc-embeddings', kind='non-persistent', namespace='request'
|
||||
)
|
||||
document_embeddings_response_queue = topic(
|
||||
'doc-embeddings-response', kind='non-persistent', namespace='response',
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
# Graph triples
|
||||
|
||||
class Triple(Record):
|
||||
source = Source()
|
||||
s = Value()
|
||||
p = Value()
|
||||
o = Value()
|
||||
|
||||
triples_store_queue = topic('triples-store')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Triples query
|
||||
|
||||
class TriplesQueryRequest(Record):
|
||||
s = Value()
|
||||
p = Value()
|
||||
o = Value()
|
||||
limit = Integer()
|
||||
|
||||
class TriplesQueryResponse(Record):
|
||||
error = Error()
|
||||
triples = Array(Triple())
|
||||
|
||||
triples_request_queue = topic(
|
||||
'triples', kind='non-persistent', namespace='request'
|
||||
)
|
||||
triples_response_queue = topic(
|
||||
'triples-response', kind='non-persistent', namespace='response',
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
# chunk_embeddings_store_queue = topic('chunk-embeddings-store')
|
||||
|
||||
############################################################################
|
||||
|
||||
# LLM text completion
|
||||
|
||||
class TextCompletionRequest(Record):
|
||||
prompt = String()
|
||||
|
||||
class TextCompletionResponse(Record):
|
||||
error = Error()
|
||||
response = String()
|
||||
|
||||
text_completion_request_queue = topic(
|
||||
'text-completion', kind='non-persistent', namespace='request'
|
||||
)
|
||||
text_completion_response_queue = topic(
|
||||
'text-completion-response', kind='non-persistent', namespace='response',
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
# Embeddings
|
||||
|
||||
class EmbeddingsRequest(Record):
|
||||
text = String()
|
||||
|
||||
class EmbeddingsResponse(Record):
|
||||
error = Error()
|
||||
vectors = Array(Array(Double()))
|
||||
|
||||
embeddings_request_queue = topic(
|
||||
'embeddings', kind='non-persistent', namespace='request'
|
||||
)
|
||||
embeddings_response_queue = topic(
|
||||
'embeddings-response', kind='non-persistent', namespace='response'
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
# Graph RAG text retrieval
|
||||
|
||||
class GraphRagQuery(Record):
|
||||
query = String()
|
||||
|
||||
class GraphRagResponse(Record):
|
||||
error = Error()
|
||||
response = String()
|
||||
|
||||
graph_rag_request_queue = topic(
|
||||
'graph-rag', kind='non-persistent', namespace='request'
|
||||
)
|
||||
graph_rag_response_queue = topic(
|
||||
'graph-rag-response', kind='non-persistent', namespace='response'
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
# Document RAG text retrieval
|
||||
|
||||
class DocumentRagQuery(Record):
|
||||
query = String()
|
||||
|
||||
class DocumentRagResponse(Record):
|
||||
error = Error()
|
||||
response = String()
|
||||
|
||||
document_rag_request_queue = topic(
|
||||
'doc-rag', kind='non-persistent', namespace='request'
|
||||
)
|
||||
document_rag_response_queue = topic(
|
||||
'doc-rag-response', kind='non-persistent', namespace='response'
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
# Prompt services, abstract the prompt generation
|
||||
|
||||
class Definition(Record):
|
||||
name = String()
|
||||
definition = String()
|
||||
|
||||
class Relationship(Record):
|
||||
s = String()
|
||||
p = String()
|
||||
o = String()
|
||||
o_entity = Boolean()
|
||||
|
||||
class Fact(Record):
|
||||
s = String()
|
||||
p = String()
|
||||
o = String()
|
||||
|
||||
# extract-definitions:
|
||||
# chunk -> definitions
|
||||
# extract-relationships:
|
||||
# chunk -> relationships
|
||||
# kg-prompt:
|
||||
# query, triples -> answer
|
||||
# document-prompt:
|
||||
# query, documents -> answer
|
||||
|
||||
class PromptRequest(Record):
|
||||
kind = String()
|
||||
chunk = String()
|
||||
query = String()
|
||||
kg = Array(Fact())
|
||||
documents = Array(Bytes())
|
||||
|
||||
class PromptResponse(Record):
|
||||
error = Error()
|
||||
answer = String()
|
||||
definitions = Array(Definition())
|
||||
relationships = Array(Relationship())
|
||||
|
||||
prompt_request_queue = topic(
|
||||
'prompt', kind='non-persistent', namespace='request'
|
||||
)
|
||||
prompt_response_queue = topic(
|
||||
'prompt-response', kind='non-persistent', namespace='response'
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
12
trustgraph/schema/__init__.py
Normal file
12
trustgraph/schema/__init__.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
|
||||
from . types import *
|
||||
from . prompt import *
|
||||
from . documents import *
|
||||
from . models import *
|
||||
from . object import *
|
||||
from . topic import *
|
||||
from . graph import *
|
||||
from . retrieval import *
|
||||
|
||||
|
||||
|
||||
68
trustgraph/schema/documents.py
Normal file
68
trustgraph/schema/documents.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
|
||||
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
|
||||
from . topic import topic
|
||||
from . types import Error
|
||||
|
||||
class Source(Record):
    """Identifies the originating document for a message."""

    # Source label — origin of the document; exact semantics not shown here.
    source = String()
    # Document identifier.
    id = String()
    # Document title.
    title = String()
|
||||
|
||||
############################################################################
|
||||
|
||||
# PDF docs etc.
|
||||
class Document(Record):
    """A binary document (PDF etc.) to be ingested."""

    # Originating document metadata.
    source = Source()
    # Raw document bytes.
    data = Bytes()
|
||||
|
||||
document_ingest_queue = topic('document-load')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Text documents / text from PDF
|
||||
|
||||
class TextDocument(Record):
    """A text document, e.g. text extracted from a PDF."""

    # Originating document metadata.
    source = Source()
    # Document text as bytes.
    text = Bytes()
|
||||
|
||||
text_ingest_queue = topic('text-document-load')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Chunks of text
|
||||
|
||||
class Chunk(Record):
    """A chunk of text split out of a document."""

    # Originating document metadata.
    source = Source()
    # The chunk text as bytes.
    chunk = Bytes()
|
||||
|
||||
chunk_ingest_queue = topic('chunk-load')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Chunk embeddings are an embeddings associated with a text chunk
|
||||
|
||||
class ChunkEmbeddings(Record):
    """Embeddings associated with a text chunk."""

    # Originating document metadata.
    source = Source()
    # One or more embedding vectors for the chunk.
    vectors = Array(Array(Double()))
    # The chunk text the vectors describe.
    chunk = Bytes()
|
||||
|
||||
chunk_embeddings_ingest_queue = topic('chunk-embeddings-load')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Doc embeddings query
|
||||
|
||||
class DocumentEmbeddingsRequest(Record):
|
||||
vectors = Array(Array(Double()))
|
||||
limit = Integer()
|
||||
|
||||
class DocumentEmbeddingsResponse(Record):
|
||||
error = Error()
|
||||
documents = Array(Bytes())
|
||||
|
||||
document_embeddings_request_queue = topic(
|
||||
'doc-embeddings', kind='non-persistent', namespace='request'
|
||||
)
|
||||
document_embeddings_response_queue = topic(
|
||||
'doc-embeddings-response', kind='non-persistent', namespace='response',
|
||||
)
|
||||
69
trustgraph/schema/graph.py
Normal file
69
trustgraph/schema/graph.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
|
||||
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
|
||||
|
||||
from . documents import Source
|
||||
from . types import Error, Value
|
||||
from . topic import topic
|
||||
|
||||
############################################################################
|
||||
|
||||
# Graph embeddings are embeddings associated with a graph entity
|
||||
|
||||
class GraphEmbeddings(Record):
    """Embeddings associated with a graph entity."""

    # Originating document metadata.
    source = Source()
    # One or more embedding vectors for the entity.
    vectors = Array(Array(Double()))
    # The graph entity the vectors describe.
    entity = Value()
|
||||
|
||||
graph_embeddings_store_queue = topic('graph-embeddings-store')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Graph embeddings query
|
||||
|
||||
class GraphEmbeddingsRequest(Record):
    """Query for graph entities matching the given embedding vectors."""

    # Query embedding vectors.
    vectors = Array(Array(Double()))
    # Maximum number of entities to return.
    limit = Integer()
|
||||
|
||||
class GraphEmbeddingsResponse(Record):
    """Response to a graph-embeddings query: matching entities or an error."""

    # Populated when the query failed.
    error = Error()
    # Entities matching the query vectors.
    entities = Array(Value())
|
||||
|
||||
graph_embeddings_request_queue = topic(
|
||||
'graph-embeddings', kind='non-persistent', namespace='request'
|
||||
)
|
||||
graph_embeddings_response_queue = topic(
|
||||
'graph-embeddings-response', kind='non-persistent', namespace='response',
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
# Graph triples
|
||||
|
||||
class Triple(Record):
    """A subject/predicate/object graph edge plus its source document."""

    # Document the triple was extracted from.
    source = Source()
    # Subject term.
    s = Value()
    # Predicate term.
    p = Value()
    # Object term.
    o = Value()
|
||||
|
||||
triples_store_queue = topic('triples-store')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Triples query
|
||||
|
||||
class TriplesQueryRequest(Record):
    """Triple-pattern query; presumably unset terms act as wildcards — confirm against the query service."""

    # Subject pattern.
    s = Value()
    # Predicate pattern.
    p = Value()
    # Object pattern.
    o = Value()
    # Maximum number of triples to return.
    limit = Integer()
|
||||
|
||||
class TriplesQueryResponse(Record):
    """Response to a triples query: matching triples or an error."""

    # Populated when the query failed.
    error = Error()
    # Triples matching the query pattern.
    triples = Array(Triple())
|
||||
|
||||
triples_request_queue = topic(
|
||||
'triples', kind='non-persistent', namespace='request'
|
||||
)
|
||||
triples_response_queue = topic(
|
||||
'triples-response', kind='non-persistent', namespace='response',
|
||||
)
|
||||
41
trustgraph/schema/models.py
Normal file
41
trustgraph/schema/models.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
|
||||
from pulsar.schema import Record, String, Array, Double
|
||||
|
||||
from . topic import topic
|
||||
from . types import Error
|
||||
|
||||
############################################################################
|
||||
|
||||
# LLM text completion
|
||||
|
||||
class TextCompletionRequest(Record):
    """Request to the LLM text-completion service."""

    # Prompt text to complete.
    prompt = String()
|
||||
|
||||
class TextCompletionResponse(Record):
    """Response from the LLM text-completion service."""

    # Populated when the request failed.
    error = Error()
    # The completion text.
    response = String()
|
||||
|
||||
text_completion_request_queue = topic(
|
||||
'text-completion', kind='non-persistent', namespace='request'
|
||||
)
|
||||
text_completion_response_queue = topic(
|
||||
'text-completion-response', kind='non-persistent', namespace='response',
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
# Embeddings
|
||||
|
||||
class EmbeddingsRequest(Record):
    """Request to the embeddings service."""

    # Text to embed.
    text = String()
|
||||
|
||||
class EmbeddingsResponse(Record):
    """Response from the embeddings service."""

    # Populated when the request failed.
    error = Error()
    # Embedding vectors for the requested text.
    vectors = Array(Array(Double()))
|
||||
|
||||
embeddings_request_queue = topic(
|
||||
'embeddings', kind='non-persistent', namespace='request'
|
||||
)
|
||||
embeddings_response_queue = topic(
|
||||
'embeddings-response', kind='non-persistent', namespace='response'
|
||||
)
|
||||
33
trustgraph/schema/object.py
Normal file
33
trustgraph/schema/object.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
|
||||
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array
|
||||
from pulsar.schema import Double, Map
|
||||
|
||||
from . documents import Source
|
||||
from . types import Value, RowSchema
|
||||
from . topic import topic
|
||||
|
||||
############################################################################
|
||||
|
||||
# Object embeddings are embeddings associated with the primary key of an
|
||||
# object
|
||||
|
||||
class ObjectEmbeddings(Record):
    """Embeddings associated with the primary key of an object."""

    # Originating document metadata.
    source = Source()
    # One or more embedding vectors for the object.
    vectors = Array(Array(Double()))
    # Object (table/schema) name.
    name = String()
    # Name of the primary-key field.
    key_name = String()
    # Primary-key value identifying the object.
    id = String()
|
||||
|
||||
object_embeddings_store_queue = topic('object-embeddings-store')
|
||||
|
||||
############################################################################
|
||||
|
||||
# Stores rows of information
|
||||
|
||||
class Rows(Record):
    """A batch of extracted table rows plus the schema describing them."""

    # Originating document metadata.
    source = Source()
    # Schema describing the row fields.
    row_schema = RowSchema()
    # Rows as field-name -> string-value maps.
    rows = Array(Map(String()))
|
||||
|
||||
rows_store_queue = topic('rows-store')
|
||||
|
||||
60
trustgraph/schema/prompt.py
Normal file
60
trustgraph/schema/prompt.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
|
||||
from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
|
||||
|
||||
from . topic import topic
|
||||
from . types import Error, RowSchema
|
||||
|
||||
############################################################################
|
||||
|
||||
# Prompt services, abstract the prompt generation
|
||||
|
||||
class Definition(Record):
    """A term and its extracted definition."""

    # The term being defined.
    name = String()
    # The definition text.
    definition = String()
|
||||
|
||||
class Relationship(Record):
    """An extracted subject/predicate/object relationship."""

    # Subject.
    s = String()
    # Predicate.
    p = String()
    # Object.
    o = String()
    # True when the object is an entity rather than a literal — assumed from
    # the field name; confirm against the extractor.
    o_entity = Boolean()
|
||||
|
||||
class Fact(Record):
    """A subject/predicate/object fact supplied as context to a prompt."""

    # Subject.
    s = String()
    # Predicate.
    p = String()
    # Object.
    o = String()
|
||||
|
||||
# extract-definitions:
|
||||
# chunk -> definitions
|
||||
# extract-relationships:
|
||||
# chunk -> relationships
|
||||
# kg-prompt:
|
||||
# query, triples -> answer
|
||||
# document-prompt:
|
||||
# query, documents -> answer
|
||||
# extract-rows
|
||||
# schema, chunk -> rows
|
||||
|
||||
class PromptRequest(Record):
    """Request to the prompt service; 'kind' selects the operation.

    Kinds cover extraction (chunk -> definitions / relationships / rows)
    and question answering (query plus kg or documents -> answer).  Only
    the fields relevant to the chosen kind need be populated.
    """

    # Operation selector, e.g. "extract-rows".
    kind = String()
    # Text chunk for the extraction kinds.
    chunk = String()
    # Question for the kg-prompt / document-prompt kinds.
    query = String()
    # Knowledge-graph facts supporting the query.
    kg = Array(Fact())
    # Raw documents supporting the query.
    documents = Array(Bytes())
    # Target schema for extract-rows.
    row_schema = RowSchema()
|
||||
|
||||
class PromptResponse(Record):
    """Response from the prompt service; populated fields depend on the
    request kind."""

    # Populated when the request failed.
    error = Error()
    # Answer text for the question-answering kinds.
    answer = String()
    # Results of extract-definitions.
    definitions = Array(Definition())
    # Results of extract-relationships.
    relationships = Array(Relationship())
    # Results of extract-rows: field-name -> value maps.
    rows = Array(Map(String()))
|
||||
|
||||
prompt_request_queue = topic(
|
||||
'prompt', kind='non-persistent', namespace='request'
|
||||
)
|
||||
prompt_response_queue = topic(
|
||||
'prompt-response', kind='non-persistent', namespace='response'
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
40
trustgraph/schema/retrieval.py
Normal file
40
trustgraph/schema/retrieval.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
|
||||
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
|
||||
from . topic import topic
|
||||
from . types import Error, Value
|
||||
|
||||
############################################################################
|
||||
|
||||
# Graph RAG text retrieval
|
||||
|
||||
class GraphRagQuery(Record):
    """Query to the Graph RAG text-retrieval service."""

    # Natural-language query text.
    query = String()
|
||||
|
||||
class GraphRagResponse(Record):
    """Response from the Graph RAG text-retrieval service."""

    # Populated when the query failed.
    error = Error()
    # Answer text.
    response = String()
|
||||
|
||||
graph_rag_request_queue = topic(
|
||||
'graph-rag', kind='non-persistent', namespace='request'
|
||||
)
|
||||
graph_rag_response_queue = topic(
|
||||
'graph-rag-response', kind='non-persistent', namespace='response'
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
# Document RAG text retrieval
|
||||
|
||||
class DocumentRagQuery(Record):
    """Query to the Document RAG text-retrieval service."""

    # Natural-language query text.
    query = String()
|
||||
|
||||
class DocumentRagResponse(Record):
    """Response from the Document RAG text-retrieval service."""

    # Populated when the query failed.
    error = Error()
    # Answer text.
    response = String()
|
||||
|
||||
document_rag_request_queue = topic(
|
||||
'doc-rag', kind='non-persistent', namespace='request'
|
||||
)
|
||||
document_rag_response_queue = topic(
|
||||
'doc-rag-response', kind='non-persistent', namespace='response'
|
||||
)
|
||||
4
trustgraph/schema/topic.py
Normal file
4
trustgraph/schema/topic.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
|
||||
def topic(topic, kind='persistent', tenant='tg', namespace='flow'):
    """Build a fully-qualified Pulsar topic name.

    Format: <kind>://<tenant>/<namespace>/<topic>, e.g.
    persistent://tg/flow/document-load.
    """
    path = "/".join((tenant, namespace, topic))
    return kind + "://" + path
|
||||
|
||||
25
trustgraph/schema/types.py
Normal file
25
trustgraph/schema/types.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
|
||||
from pulsar.schema import Record, String, Boolean, Array, Integer
|
||||
|
||||
class Error(Record):
    """Error details carried in service response messages."""

    # Error type label.
    type = String()
    # Human-readable error description.
    message = String()
|
||||
|
||||
class Value(Record):
    """A graph term: a string value flagged as either a URI or a literal."""

    # The term itself (URI string or literal text).
    value = String()
    # True when value is a URI rather than a literal.
    is_uri = Boolean()
    # Type tag for the value — semantics not shown here; confirm at use sites.
    type = String()
|
||||
|
||||
class Field(Record):
    """Describes one column of a row schema."""

    # Column name.
    name = String()
    # int, string, long, bool, float, double
    type = String()
    # Field size — semantics presumably depend on type; confirm with producers.
    size = Integer()
    # True when this field is part of the primary key.
    primary = Boolean()
    # Human-readable description of the field.
    description = String()
|
||||
|
||||
class RowSchema(Record):
    """Describes a table: its name, description and column fields."""

    # Table / schema name.
    name = String()
    # Human-readable description.
    description = String()
    # Ordered column definitions.
    fields = Array(Field())
|
||||
|
||||
0
trustgraph/storage/object_embeddings/__init__.py
Normal file
0
trustgraph/storage/object_embeddings/__init__.py
Normal file
3
trustgraph/storage/object_embeddings/milvus/__init__.py
Normal file
3
trustgraph/storage/object_embeddings/milvus/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . write import *
|
||||
|
||||
7
trustgraph/storage/object_embeddings/milvus/__main__.py
Executable file
7
trustgraph/storage/object_embeddings/milvus/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . write import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
61
trustgraph/storage/object_embeddings/milvus/write.py
Executable file
61
trustgraph/storage/object_embeddings/milvus/write.py
Executable file
|
|
@ -0,0 +1,61 @@
|
|||
|
||||
"""
|
||||
Accepts entity/vector pairs and writes them to a Milvus store.
|
||||
"""
|
||||
|
||||
from .... schema import ObjectEmbeddings
|
||||
from .... schema import object_embeddings_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... direct.milvus_object_embeddings import ObjectVectors
|
||||
from .... base import Consumer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = object_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(Consumer):
    """Consumes ObjectEmbeddings messages and writes each vector to a
    Milvus store, keyed by object name, key-field name and id."""

    def __init__(self, **params):
        """Configure the consumer and open the Milvus vector store."""

        overrides = {
            "input_queue": params.get("input_queue", default_input_queue),
            "subscriber": params.get("subscriber", default_subscriber),
            "store_uri": params.get("store_uri", default_store_uri),
            "input_schema": ObjectEmbeddings,
        }

        super(Processor, self).__init__(**params | overrides)

        self.vecstore = ObjectVectors(overrides["store_uri"])

    def handle(self, msg):
        """Insert every vector of the message into the store.

        Messages without a usable primary-key id are silently skipped.
        """

        emb = msg.value()

        if emb.id == "" or emb.id is None:
            return

        for vector in emb.vectors:
            self.vecstore.insert(vector, emb.name, emb.key_name, emb.id)

    @staticmethod
    def add_args(parser):
        """Register command-line arguments for this processor."""

        Consumer.add_args(parser, default_input_queue, default_subscriber)

        parser.add_argument(
            '-t', '--store-uri',
            default=default_store_uri,
            help=f'Milvus store URI (default: {default_store_uri})'
        )
|
||||
|
||||
def run():
    """Module entry point: start the object-embeddings writer service."""

    Processor.start(module, __doc__)
|
||||
|
||||
0
trustgraph/storage/rows/__init__.py
Normal file
0
trustgraph/storage/rows/__init__.py
Normal file
3
trustgraph/storage/rows/cassandra/__init__.py
Normal file
3
trustgraph/storage/rows/cassandra/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . write import *
|
||||
|
||||
7
trustgraph/storage/rows/cassandra/__main__.py
Executable file
7
trustgraph/storage/rows/cassandra/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . write import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
127
trustgraph/storage/rows/cassandra/write.py
Executable file
127
trustgraph/storage/rows/cassandra/write.py
Executable file
|
|
@ -0,0 +1,127 @@
|
|||
|
||||
"""
|
||||
Graph writer. Input is graph edge. Writes edges to Cassandra graph.
|
||||
"""
|
||||
|
||||
import pulsar
|
||||
import base64
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
|
||||
from .... schema import Rows
|
||||
from .... schema import rows_store_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import Consumer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = rows_store_queue
|
||||
default_subscriber = module
|
||||
default_graph_host='localhost'
|
||||
|
||||
class Processor(Consumer):
    """Consumes Rows messages and writes each row into a Cassandra table.

    Tables are created on demand from the message's row schema (every
    column is text); created table names are cached in self.tables so the
    DDL is only issued once per table per process lifetime.
    """

    def __init__(self, **params):
        """Connect to Cassandra and ensure the trustgraph keyspace exists.

        Params:
            input_queue: Pulsar queue to consume from.
            subscriber: subscription name.
            graph_host: comma-separated list of Cassandra contact points.
        """

        input_queue = params.get("input_queue", default_input_queue)
        subscriber = params.get("subscriber", default_subscriber)
        graph_host = params.get("graph_host", default_graph_host)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "subscriber": subscriber,
                "input_schema": Rows,
                "graph_host": graph_host,
            }
        )

        self.cluster = Cluster(graph_host.split(","))
        self.session = self.cluster.connect()

        # Table names already known to exist, so DDL runs once per table.
        self.tables = set()

        self.session.execute("""
            create keyspace if not exists trustgraph
            with replication = {
                'class' : 'SimpleStrategy',
                'replication_factor' : 1
            };
        """)

        self.session.execute("use trustgraph")

    @staticmethod
    def _check_identifier(name):
        """Reject identifiers that cannot safely be embedded in CQL.

        Cassandra cannot bind table/column names as parameters, so the
        names are restricted to a conservative identifier alphabet to
        prevent CQL injection (this addresses the former FIXME).
        """
        if not name or not name.isidentifier():
            raise ValueError(f"Unsafe CQL identifier: {name!r}")

    def _ensure_table(self, schema):
        """Create the table described by schema if it does not exist, and
        remember it in the table cache."""

        name = schema.name

        pkey = []
        stmt = "create table if not exists " + name + " ( "

        for field in schema.fields:
            stmt += field.name + " text, "
            if field.primary:
                pkey.append(field.name)

        stmt += "PRIMARY KEY (" + ", ".join(pkey) + "));"

        self.session.execute(stmt)
        self.tables.add(name)

    def handle(self, msg):
        """Write all rows in the message, creating the table on first use.

        On failure the table-name cache entry is dropped so the next
        message retries table creation, then the exception propagates so
        the consumer framework can retry the message.
        """

        name = None

        try:

            v = msg.value()
            name = v.row_schema.name

            # Validate every identifier before interpolating it into CQL.
            self._check_identifier(name)
            for field in v.row_schema.fields:
                self._check_identifier(field.name)

            if name not in self.tables:
                self._ensure_table(v.row_schema)

            # The insert statement is identical for every row: build it once.
            # Identifiers were validated above; values are bound parameters.
            field_names = [field.name for field in v.row_schema.fields]
            stmt = (
                "insert into " + name + " (" + ", ".join(field_names) +
                ") values (" + ",".join(["%s"] * len(field_names)) + ")"
            )

            for row in v.rows:
                values = [row[fn] for fn in field_names]
                self.session.execute(stmt, values)

        except Exception as e:

            print("Exception:", str(e), flush=True)

            # If there's an error make sure to redo table creation next
            # time.  discard (not remove) and the None guard avoid masking
            # the real error when the failure happened before name was
            # bound or cached.
            if name is not None:
                self.tables.discard(name)

            raise

    @staticmethod
    def add_args(parser):
        """Register command-line arguments for this processor."""

        Consumer.add_args(
            parser, default_input_queue, default_subscriber,
        )

        parser.add_argument(
            '-g', '--graph-host',
            default=default_graph_host,
            help=f'Graph host (default: {default_graph_host})'
        )
|
||||
|
||||
def run():
    """Module entry point: start the rows writer service."""

    Processor.start(module, __doc__)
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue