Extract rows and apply object embeddings (#42)

* - Restructured the extract directories
- Added an extractor for 'rows' == a row of a table
- Added a row extractor prompt to prompter.
* Add row support to template prompter
* Row extraction working
* Bump version
* Emit extracted info
* Object embeddings store
* Invocation script
* Add script to package, remove cruft output
* Write rows to Cassandra
* Remove output cruft
This commit is contained in:
cybermaggedon 2024-08-27 21:55:12 +01:00 committed by GitHub
parent b574ba26a8
commit e4c4774b5d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
70 changed files with 1624 additions and 520 deletions

View file

@ -116,4 +116,4 @@ class BaseProcessor:
print("Exception:", e, flush=True)
print("Will retry...", flush=True)
time.sleep(10)
time.sleep(4)

View file

@ -10,6 +10,7 @@ class Consumer(BaseProcessor):
def __init__(self, **params):
print("HERE2")
super(Consumer, self).__init__(**params)
input_queue = params.get("input_queue")
@ -29,6 +30,7 @@ class Consumer(BaseProcessor):
'pubsub', 'Pub/sub configuration'
)
print("HERE")
if not hasattr(__class__, "processing_metric"):
__class__.processing_metric = Counter(
'processing_count', 'Processing count', ["status"]
@ -65,7 +67,7 @@ class Consumer(BaseProcessor):
self.consumer.negative_acknowledge(msg)
print("TooManyRequests: will retry")
__class__.processing_metric.labels(status="rate-limit").inc()
time.sleep(2)
time.sleep(5)
continue
except Exception as e:

View file

@ -84,7 +84,7 @@ class ConsumerProducer(BaseProcessor):
self.consumer.negative_acknowledge(msg)
print("TooManyRequests: will retry")
__class__.processing_metric.labels(status="rate-limit").inc()
time.sleep(2)
time.sleep(5)
continue
except Exception as e:

View file

@ -1,7 +1,7 @@
import _pulsar
from .. schema import PromptRequest, PromptResponse, Fact
from .. schema import PromptRequest, PromptResponse, Fact, RowSchema, Field
from .. schema import prompt_request_queue
from .. schema import prompt_response_queue
from . base import BaseClient
@ -52,6 +52,24 @@ class PromptClient(BaseClient):
timeout=timeout
).relationships
def request_rows(self, schema, chunk, timeout=300):
return self.call(
kind="extract-rows", chunk=chunk,
row_schema=RowSchema(
name=schema.name,
description=schema.description,
fields=[
Field(
name=f.name, type=str(f.type), size=f.size,
primary=f.primary, description=f.description,
)
for f in schema.fields
]
),
timeout=timeout
).rows
def request_kg_prompt(self, query, kg, timeout=300):
return self.call(

View file

@ -0,0 +1,154 @@
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
class ObjectVectors:
    """Milvus-backed store for object embeddings.

    A separate collection is created per (dimension, name) pair, so the
    vector dimension never needs to be configured up front; creating a
    few extra collections is harmless.
    """

    def __init__(self, uri="http://localhost:19530", prefix='obj'):
        self.client = MilvusClient(uri=uri)

        # Maps (dimension, name) -> Milvus collection name, filled in
        # lazily by init_collection()
        self.collections = {}

        self.prefix = prefix

        # Seconds a collection stays loaded before being released again
        self.reload_time = 90

        # Next time to unload - this forces a release at next window
        self.next_reload = time.time() + self.reload_time

        print("Reload at", self.next_reload)

    def init_collection(self, dimension, name):
        """Create the collection and vector index for (dimension, name)
        and register it in self.collections."""

        collection_name = self.prefix + "_" + name + "_" + str(dimension)

        pkey_field = FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )

        vec_field = FieldSchema(
            name="vector",
            dtype=DataType.FLOAT_VECTOR,
            dim=dimension,
        )

        name_field = FieldSchema(
            name="name",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        key_name_field = FieldSchema(
            name="key_name",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        key_field = FieldSchema(
            name="key",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        schema = CollectionSchema(
            fields = [
                pkey_field, vec_field, name_field, key_name_field, key_field
            ],
            description = "Object embedding schema",
        )

        self.client.create_collection(
            collection_name=collection_name,
            schema=schema,
            metric_type="COSINE",
        )

        index_params = MilvusClient.prepare_index_params()

        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="IVF_SQ8",
            index_name="vector_index",
            params={ "nlist": 128 }
        )

        self.client.create_index(
            collection_name=collection_name,
            index_params=index_params
        )

        self.collections[(dimension, name)] = collection_name

    def insert(self, embeds, name, key_name, key):
        """Insert one embedding vector tagged with the object name,
        primary-key field name and key value."""

        dim = len(embeds)

        if (dim, name) not in self.collections:
            self.init_collection(dim, name)

        data = [
            {
                "vector": embeds,
                "name": name,
                "key_name": key_name,
                "key": key,
            }
        ]

        self.client.insert(
            collection_name=self.collections[(dim, name)],
            data=data
        )

    def search(self, embeds, name, fields=["key_name", "name"], limit=10):
        """Cosine-similarity search; returns the raw Milvus hit list."""

        dim = len(embeds)

        # Bug fix: the registry is keyed by (dimension, name) tuples, not
        # by the bare dimension; the old test never matched, so the
        # subsequent tuple lookup could raise KeyError.
        if (dim, name) not in self.collections:
            self.init_collection(dim, name)

        coll = self.collections[(dim, name)]

        search_params = {
            "metric_type": "COSINE",
            "params": {
                "radius": 0.1,
                "range_filter": 0.8
            }
        }

        print("Loading...")

        self.client.load_collection(
            collection_name=coll,
        )

        print("Searching...")

        res = self.client.search(
            collection_name=coll,
            data=[embeds],
            limit=limit,
            output_fields=fields,
            search_params=search_params,
        )[0]

        # If the unload window has passed, release the collection to free
        # server memory; the next search() loads it again
        if time.time() > self.next_reload:
            print("Unloading, reload at", self.next_reload)
            self.client.release_collection(
                collection_name=coll,
            )
            self.next_reload = time.time() + self.reload_time

        return res

View file

View file

@ -7,14 +7,14 @@ get entity definitions which are output as graph edges.
import urllib.parse
import json
from ... schema import ChunkEmbeddings, Triple, Source, Value
from ... schema import chunk_embeddings_ingest_queue, triples_store_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... log_level import LogLevel
from ... clients.prompt_client import PromptClient
from ... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
from ... base import ConsumerProducer
from .... schema import ChunkEmbeddings, Triple, Source, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
from .... base import ConsumerProducer
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)

View file

@ -9,15 +9,15 @@ import urllib.parse
import os
from pulsar.schema import JsonSchema
from ... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
from ... schema import chunk_embeddings_ingest_queue, triples_store_queue
from ... schema import graph_embeddings_store_queue
from ... schema import prompt_request_queue
from ... schema import prompt_response_queue
from ... log_level import LogLevel
from ... clients.prompt_client import PromptClient
from ... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
from ... base import ConsumerProducer
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import graph_embeddings_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
from .... base import ConsumerProducer
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)

View file

View file

@ -0,0 +1,3 @@
from . extract import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . extract import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,220 @@
"""
Simple decoder, accepts vector+text chunks input, applies analysis to pull
out a row of fields. Output as a vector plus object.
"""
import urllib.parse
import os
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source
from .... schema import RowSchema, Field
from .... schema import chunk_embeddings_ingest_queue, rows_store_queue
from .... schema import object_embeddings_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... base import ConsumerProducer
from .... objects.field import Field as FieldParser
from .... objects.object import Schema
module = ".".join(__name__.split(".")[1:-1])

default_input_queue = chunk_embeddings_ingest_queue
default_output_queue = rows_store_queue
default_vector_queue = object_embeddings_store_queue
default_subscriber = module

class Processor(ConsumerProducer):
    """Extracts schema-conformant rows from chunk embeddings.

    For each incoming chunk, asks the prompt service to extract rows
    matching the user-supplied schema, writes the rows to the output
    queue, and emits one object embedding per row keyed on the schema's
    primary key field.
    """

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        vector_queue = params.get("vector_queue", default_vector_queue)
        subscriber = params.get("subscriber", default_subscriber)

        pr_request_queue = params.get(
            "prompt_request_queue", prompt_request_queue
        )
        pr_response_queue = params.get(
            "prompt_response_queue", prompt_response_queue
        )

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": ChunkEmbeddings,
                "output_schema": Rows,
                "prompt_request_queue": pr_request_queue,
                "prompt_response_queue": pr_response_queue,
            }
        )

        # Secondary producer: object embeddings go to their own queue
        self.vec_prod = self.client.create_producer(
            topic=vector_queue,
            schema=JsonSchema(ObjectEmbeddings),
        )

        __class__.pubsub_metric.info({
            "input_queue": input_queue,
            "output_queue": output_queue,
            "vector_queue": vector_queue,
            "prompt_request_queue": pr_request_queue,
            "prompt_response_queue": pr_response_queue,
            "subscriber": subscriber,
            "input_schema": ChunkEmbeddings.__name__,
            "output_schema": Rows.__name__,
            "vector_schema": ObjectEmbeddings.__name__,
        })

        flds = __class__.parse_fields(params["field"])

        for fld in flds:
            print(fld)

        # Exactly one field must be marked primary; its value keys the
        # object embedding emitted for each extracted row
        self.primary = None

        for f in flds:
            if f.primary:
                if self.primary:
                    raise RuntimeError(
                        "Only one primary key field is supported"
                    )
                self.primary = f

        if self.primary is None:
            raise RuntimeError(
                "Must have exactly one primary key field"
            )

        self.schema = Schema(
            name = params["name"],
            description = params["description"],
            fields = flds
        )

        # Pre-built wire-format copy of the schema, attached to every
        # outgoing Rows message
        self.row_schema = RowSchema(
            name=self.schema.name,
            description=self.schema.description,
            fields=[
                Field(
                    name=f.name, type=str(f.type), size=f.size,
                    primary=f.primary, description=f.description,
                )
                for f in self.schema.fields
            ]
        )

        self.prompt = PromptClient(
            pulsar_host=self.pulsar_host,
            input_queue=pr_request_queue,
            output_queue=pr_response_queue,
            subscriber = module + "-prompt",
        )

    @staticmethod
    def parse_fields(fields):
        """Parse each name:type:size:pri:description definition string."""
        return [ FieldParser.parse(f) for f in fields ]

    def get_rows(self, chunk):
        """Ask the prompt service to extract rows from a text chunk."""
        return self.prompt.request_rows(self.schema, chunk)

    def emit_rows(self, source, rows):
        """Send a Rows message carrying the extracted rows."""
        t = Rows(
            source=source, row_schema=self.row_schema, rows=rows
        )
        self.producer.send(t)

    def emit_vec(self, source, name, vec, key_name, key):
        """Send an ObjectEmbeddings message for one row's primary key."""
        r = ObjectEmbeddings(
            source=source, vectors=vec, name=name, key_name=key_name, id=key
        )
        self.vec_prod.send(r)

    def handle(self, msg):
        """Process one ChunkEmbeddings message: extract rows, emit them,
        then emit an object embedding per row.  Failures are logged and
        the message is considered handled (best-effort)."""

        v = msg.value()

        print(f"Indexing {v.source.id}...", flush=True)

        chunk = v.chunk.decode("utf-8")

        try:

            rows = self.get_rows(chunk)

            self.emit_rows(
                source=v.source,
                rows=rows
            )

            # Re-use the chunk's vectors as the embedding for each row
            for row in rows:
                self.emit_vec(
                    source=v.source, vec=v.vectors,
                    name=self.schema.name, key_name=self.primary.name,
                    key=row[self.primary.name]
                )

            for row in rows:
                print(row)

        except Exception as e:
            print("Exception: ", e, flush=True)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):
        """Register this processor's command-line arguments."""

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-c', '--vector-queue',
            default=default_vector_queue,
            help=f'Vector output queue (default: {default_vector_queue})'
        )

        parser.add_argument(
            '--prompt-request-queue',
            default=prompt_request_queue,
            help=f'Prompt request queue (default: {prompt_request_queue})',
        )

        parser.add_argument(
            '--prompt-response-queue',
            default=prompt_response_queue,
            help=f'Prompt response queue (default: {prompt_response_queue})',
        )

        # Typo fix in help text: was 'descriptionn'
        parser.add_argument(
            '-f', '--field',
            required=True,
            action='append',
            help='Field definition, format name:type:size:pri:description',
        )

        parser.add_argument(
            '-n', '--name',
            required=True,
            help='Name of row object',
        )

        parser.add_argument(
            '-d', '--description',
            required=True,
            help='Description of object',
        )

def run():

    Processor.start(module, __doc__)

View file

@ -48,6 +48,44 @@ or headers or prefixes. Do not include null or unknown definitions.
return prompt
def to_rows(schema, text):
field_schema = [
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
for f in schema.fields
]
field_schema = "\n".join(field_schema)
schema = f"""Object name: {schema.name}
Description: {schema.description}
Fields:
{field_schema}"""
prompt = f"""<instructions>
Study the following text and derive objects which match the schema provided.
You must output an array of JSON objects for each object you discover
which matches the schema. For each object, output a JSON object whose fields
carry the name field specified in the schema.
</instructions>
<schema>
{schema}
</schema>
<text>
{text}
</text>
<requirements>
You will respond only with raw JSON format data. Do not provide
explanations. Do not add markdown formatting or headers or prefixes.
</requirements>"""
return prompt
def get_cypher(kg):
sg2 = []

View file

@ -14,7 +14,7 @@ from .... base import ConsumerProducer
from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships
from . prompts import to_kg_query, to_document_query
from . prompts import to_kg_query, to_document_query, to_rows
module = ".".join(__name__.split(".")[1:-1])
@ -77,6 +77,11 @@ class Processor(ConsumerProducer):
self.handle_extract_relationships(id, v)
return
elif kind == "extract-rows":
self.handle_extract_rows(id, v)
return
elif kind == "kg-prompt":
self.handle_kg_prompt(id, v)
@ -222,6 +227,77 @@ class Processor(ConsumerProducer):
)
self.producer.send(r, properties={"id": id})
def handle_extract_rows(self, id, v):
    """Handle an extract-rows prompt request.

    Builds the extraction prompt from v.row_schema and v.chunk, asks the
    LLM, keeps only returned objects that carry every schema field, and
    replies with a PromptResponse tagged with the request id.
    """

    try:

        fields = v.row_schema.fields

        prompt = to_rows(v.row_schema, v.chunk)

        print(prompt)

        ans = self.llm.request(prompt)

        print(ans)

        # A malformed LLM answer is deliberately treated as "no rows",
        # not as an error; only JSON decode failures are swallowed
        try:
            objs = json.loads(ans)
        except json.JSONDecodeError:
            print("JSON parse error, ignored", flush=True)
            objs = []

        output = []

        for obj in objs:
            try:
                row = {}
                for f in fields:
                    if f.name not in obj:
                        print(f"Object ignored, missing field {f.name}")
                        row = {}
                        break
                    row[f.name] = obj[f.name]
                # Empty row marks an object rejected above
                if row == {}:
                    continue
                output.append(row)
            except Exception as e:
                print("row fields missing, ignored", flush=True)

        for row in output:
            print(row)

        print("Send response...", flush=True)

        r = PromptResponse(rows=output, error=None)

        self.producer.send(r, properties={"id": id})

        print("Done.", flush=True)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        # Bug fix: PromptResponse has no 'response' field; null out
        # 'rows' (the field this handler populates) on the error path
        r = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            rows=None,
        )

        self.producer.send(r, properties={"id": id})
def handle_kg_prompt(self, id, v):

View file

@ -5,6 +5,27 @@ def to_relationships(template, text):
def to_definitions(template, text):
return template.format(text=text)
def to_rows(template, schema, text):
    """Render the rows-extraction template.

    Builds a human-readable description of *schema* (object name,
    description and one name/type/definition entry per field) and
    substitutes it, together with *text*, into *template*, which uses
    {schema} and {text} placeholders.
    """

    field_schema = "\n".join(
        f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
        for f in schema.fields
    )

    # Bug fix: the original returned before rendering the schema
    # description, passing the raw schema object into the template and
    # leaving the rendering code below the return unreachable.
    schema_block = f"""Object name: {schema.name}
Description: {schema.description}
Fields:
{field_schema}"""

    return template.format(schema=schema_block, text=text)
def get_cypher(kg):
sg2 = []
for f in kg:

View file

@ -14,7 +14,7 @@ from .... schema import prompt_request_queue, prompt_response_queue
from .... base import ConsumerProducer
from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships
from . prompts import to_definitions, to_relationships, to_rows
from . prompts import to_kg_query, to_document_query
module = ".".join(__name__.split(".")[1:-1])
@ -38,6 +38,7 @@ class Processor(ConsumerProducer):
)
definition_template = params.get("definition_template")
relationship_template = params.get("relationship_template")
rows_template = params.get("rows_template")
knowledge_query_template = params.get("knowledge_query_template")
document_query_template = params.get("document_query_template")
@ -62,6 +63,7 @@ class Processor(ConsumerProducer):
self.definition_template = definition_template
self.relationship_template = relationship_template
self.rows_template = rows_template
self.knowledge_query_template = knowledge_query_template
self.document_query_template = document_query_template
@ -87,6 +89,11 @@ class Processor(ConsumerProducer):
self.handle_extract_relationships(id, v)
return
elif kind == "extract-rows":
self.handle_extract_rows(id, v)
return
elif kind == "kg-prompt":
self.handle_kg_prompt(id, v)
@ -232,6 +239,77 @@ class Processor(ConsumerProducer):
)
self.producer.send(r, properties={"id": id})
def handle_extract_rows(self, id, v):
    """Handle an extract-rows prompt request using the configured
    rows template.

    Renders the template with v.row_schema and v.chunk, asks the LLM,
    keeps only returned objects that carry every schema field, and
    replies with a PromptResponse tagged with the request id.
    """

    try:

        fields = v.row_schema.fields

        prompt = to_rows(self.rows_template, v.row_schema, v.chunk)

        print(prompt)

        ans = self.llm.request(prompt)

        print(ans)

        # A malformed LLM answer is deliberately treated as "no rows",
        # not as an error; only JSON decode failures are swallowed
        try:
            objs = json.loads(ans)
        except json.JSONDecodeError:
            print("JSON parse error, ignored", flush=True)
            objs = []

        output = []

        for obj in objs:
            try:
                row = {}
                for f in fields:
                    if f.name not in obj:
                        print(f"Object ignored, missing field {f.name}")
                        row = {}
                        break
                    row[f.name] = obj[f.name]
                # Empty row marks an object rejected above
                if row == {}:
                    continue
                output.append(row)
            except Exception as e:
                print("row fields missing, ignored", flush=True)

        for row in output:
            print(row)

        print("Send response...", flush=True)

        r = PromptResponse(rows=output, error=None)

        self.producer.send(r, properties={"id": id})

        print("Done.", flush=True)

    except Exception as e:

        print(f"Exception: {e}")

        print("Send error response...", flush=True)

        # Bug fix: PromptResponse has no 'response' field; null out
        # 'rows' (the field this handler populates) on the error path
        r = PromptResponse(
            error=Error(
                type = "llm-error",
                message = str(e),
            ),
            rows=None,
        )

        self.producer.send(r, properties={"id": id})
def handle_kg_prompt(self, id, v):
@ -329,6 +407,12 @@ class Processor(ConsumerProducer):
help=f'Definition extraction template',
)
parser.add_argument(
'--rows-template',
required=True,
help=f'Rows extraction template',
)
parser.add_argument(
'--relationship-template',
required=True,

View file

View file

@ -0,0 +1,72 @@
from dataclasses import dataclass
from enum import Enum
class FieldType(Enum):
    """Supported row-field data types."""

    STRING = 0
    INT = 1
    LONG = 2
    BOOL = 3
    FLOAT = 4
    DOUBLE = 5

    def __str__(self):
        # Render as the lower-case type name, matching parse() input
        return self.name.lower()

@dataclass
class Field:
    """One field of a row schema.

    The serialized form (used by both parse() and str()) is
    name:type:size:pri:description.
    """

    name: str
    size: int = -1
    primary: bool = False
    type: str = "undefined"
    description: str = ""

    @staticmethod
    def parse(defn):
        """Parse a name:type:size:pri:description definition string.

        Trailing components may be omitted; they default to type
        'string', size 0, not-primary and empty description.  The
        description itself may contain colons.  Raises RuntimeError on
        an empty definition or an unknown type name.
        """

        if defn == "" or defn is None:
            raise RuntimeError("Field definition cannot be empty")

        # Split at most 4 times so the description can contain colons
        parts = defn.split(":", 4)

        if len(parts) == 1: parts.append("string")
        if len(parts) == 2: parts.append("0")
        if len(parts) == 3: parts.append("")
        if len(parts) == 4: parts.append("")

        name, type, size, pri, description = parts

        size = int(size)

        try:
            type = FieldType[type.upper()]
        except KeyError:
            raise RuntimeError(f"Field type {type} is not known")

        pri = True if pri == "pri" else False

        return Field(
            name=name, type=type, size=size, primary=pri,
            description=description
        )

    def __str__(self):
        pri = "pri" if self.primary else ""
        return f"{self.name}:{self.type}:{self.size}:{pri}:{self.description}"

    # repr and str are deliberately identical (round-trippable via parse)
    __repr__ = __str__

View file

@ -0,0 +1,8 @@
class Schema:
    """A named collection of row fields with a free-text description."""

    def __init__(self, name, description, fields):
        # Plain attribute holder; no validation is performed
        self.name = name
        self.fields = fields
        self.description = description

View file

@ -1,274 +0,0 @@
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
from enum import Enum
def topic(topic, kind='persistent', tenant='tg', namespace='flow'):
return f"{kind}://{tenant}/{namespace}/{topic}"
############################################################################
class Error(Record):
type = String()
message = String()
############################################################################
class Value(Record):
value = String()
is_uri = Boolean()
type = String()
class Source(Record):
source = String()
id = String()
title = String()
############################################################################
# PDF docs etc.
class Document(Record):
source = Source()
data = Bytes()
document_ingest_queue = topic('document-load')
############################################################################
# Text documents / text from PDF
class TextDocument(Record):
source = Source()
text = Bytes()
text_ingest_queue = topic('text-document-load')
############################################################################
# Chunks of text
class Chunk(Record):
source = Source()
chunk = Bytes()
chunk_ingest_queue = topic('chunk-load')
############################################################################
# Chunk embeddings are an embeddings associated with a text chunk
class ChunkEmbeddings(Record):
source = Source()
vectors = Array(Array(Double()))
chunk = Bytes()
chunk_embeddings_ingest_queue = topic('chunk-embeddings-load')
############################################################################
# Graph embeddings are embeddings associated with a graph entity
class GraphEmbeddings(Record):
source = Source()
vectors = Array(Array(Double()))
entity = Value()
graph_embeddings_store_queue = topic('graph-embeddings-store')
############################################################################
# Graph embeddings query
class GraphEmbeddingsRequest(Record):
vectors = Array(Array(Double()))
limit = Integer()
class GraphEmbeddingsResponse(Record):
error = Error()
entities = Array(Value())
graph_embeddings_request_queue = topic(
'graph-embeddings', kind='non-persistent', namespace='request'
)
graph_embeddings_response_queue = topic(
'graph-embeddings-response', kind='non-persistent', namespace='response',
)
############################################################################
# Doc embeddings query
class DocumentEmbeddingsRequest(Record):
vectors = Array(Array(Double()))
limit = Integer()
class DocumentEmbeddingsResponse(Record):
error = Error()
documents = Array(Bytes())
document_embeddings_request_queue = topic(
'doc-embeddings', kind='non-persistent', namespace='request'
)
document_embeddings_response_queue = topic(
'doc-embeddings-response', kind='non-persistent', namespace='response',
)
############################################################################
# Graph triples
class Triple(Record):
source = Source()
s = Value()
p = Value()
o = Value()
triples_store_queue = topic('triples-store')
############################################################################
# Triples query
class TriplesQueryRequest(Record):
s = Value()
p = Value()
o = Value()
limit = Integer()
class TriplesQueryResponse(Record):
error = Error()
triples = Array(Triple())
triples_request_queue = topic(
'triples', kind='non-persistent', namespace='request'
)
triples_response_queue = topic(
'triples-response', kind='non-persistent', namespace='response',
)
############################################################################
# chunk_embeddings_store_queue = topic('chunk-embeddings-store')
############################################################################
# LLM text completion
class TextCompletionRequest(Record):
prompt = String()
class TextCompletionResponse(Record):
error = Error()
response = String()
text_completion_request_queue = topic(
'text-completion', kind='non-persistent', namespace='request'
)
text_completion_response_queue = topic(
'text-completion-response', kind='non-persistent', namespace='response',
)
############################################################################
# Embeddings
class EmbeddingsRequest(Record):
text = String()
class EmbeddingsResponse(Record):
error = Error()
vectors = Array(Array(Double()))
embeddings_request_queue = topic(
'embeddings', kind='non-persistent', namespace='request'
)
embeddings_response_queue = topic(
'embeddings-response', kind='non-persistent', namespace='response'
)
############################################################################
# Graph RAG text retrieval
class GraphRagQuery(Record):
query = String()
class GraphRagResponse(Record):
error = Error()
response = String()
graph_rag_request_queue = topic(
'graph-rag', kind='non-persistent', namespace='request'
)
graph_rag_response_queue = topic(
'graph-rag-response', kind='non-persistent', namespace='response'
)
############################################################################
# Document RAG text retrieval
class DocumentRagQuery(Record):
query = String()
class DocumentRagResponse(Record):
error = Error()
response = String()
document_rag_request_queue = topic(
'doc-rag', kind='non-persistent', namespace='request'
)
document_rag_response_queue = topic(
'doc-rag-response', kind='non-persistent', namespace='response'
)
############################################################################
# Prompt services, abstract the prompt generation
class Definition(Record):
name = String()
definition = String()
class Relationship(Record):
s = String()
p = String()
o = String()
o_entity = Boolean()
class Fact(Record):
s = String()
p = String()
o = String()
# extract-definitions:
# chunk -> definitions
# extract-relationships:
# chunk -> relationships
# kg-prompt:
# query, triples -> answer
# document-prompt:
# query, documents -> answer
class PromptRequest(Record):
kind = String()
chunk = String()
query = String()
kg = Array(Fact())
documents = Array(Bytes())
class PromptResponse(Record):
error = Error()
answer = String()
definitions = Array(Definition())
relationships = Array(Relationship())
prompt_request_queue = topic(
'prompt', kind='non-persistent', namespace='request'
)
prompt_response_queue = topic(
'prompt-response', kind='non-persistent', namespace='response'
)
############################################################################

View file

@ -0,0 +1,12 @@
from . types import *
from . prompt import *
from . documents import *
from . models import *
from . object import *
from . topic import *
from . graph import *
from . retrieval import *

View file

@ -0,0 +1,68 @@
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
from . topic import topic
from . types import Error
class Source(Record):
source = String()
id = String()
title = String()
############################################################################
# PDF docs etc.
class Document(Record):
source = Source()
data = Bytes()
document_ingest_queue = topic('document-load')
############################################################################
# Text documents / text from PDF
class TextDocument(Record):
source = Source()
text = Bytes()
text_ingest_queue = topic('text-document-load')
############################################################################
# Chunks of text
class Chunk(Record):
source = Source()
chunk = Bytes()
chunk_ingest_queue = topic('chunk-load')
############################################################################
# Chunk embeddings are an embeddings associated with a text chunk
class ChunkEmbeddings(Record):
source = Source()
vectors = Array(Array(Double()))
chunk = Bytes()
chunk_embeddings_ingest_queue = topic('chunk-embeddings-load')
############################################################################
# Doc embeddings query
class DocumentEmbeddingsRequest(Record):
vectors = Array(Array(Double()))
limit = Integer()
class DocumentEmbeddingsResponse(Record):
error = Error()
documents = Array(Bytes())
document_embeddings_request_queue = topic(
'doc-embeddings', kind='non-persistent', namespace='request'
)
document_embeddings_response_queue = topic(
'doc-embeddings-response', kind='non-persistent', namespace='response',
)

View file

@ -0,0 +1,69 @@
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
from . documents import Source
from . types import Error, Value
from . topic import topic
############################################################################
# Graph embeddings are embeddings associated with a graph entity
class GraphEmbeddings(Record):
source = Source()
vectors = Array(Array(Double()))
entity = Value()
graph_embeddings_store_queue = topic('graph-embeddings-store')
############################################################################
# Graph embeddings query
class GraphEmbeddingsRequest(Record):
vectors = Array(Array(Double()))
limit = Integer()
class GraphEmbeddingsResponse(Record):
error = Error()
entities = Array(Value())
graph_embeddings_request_queue = topic(
'graph-embeddings', kind='non-persistent', namespace='request'
)
graph_embeddings_response_queue = topic(
'graph-embeddings-response', kind='non-persistent', namespace='response',
)
############################################################################
# Graph triples
class Triple(Record):
source = Source()
s = Value()
p = Value()
o = Value()
triples_store_queue = topic('triples-store')
############################################################################
# Triples query
class TriplesQueryRequest(Record):
s = Value()
p = Value()
o = Value()
limit = Integer()
class TriplesQueryResponse(Record):
error = Error()
triples = Array(Triple())
triples_request_queue = topic(
'triples', kind='non-persistent', namespace='request'
)
triples_response_queue = topic(
'triples-response', kind='non-persistent', namespace='response',
)

View file

@ -0,0 +1,41 @@
from pulsar.schema import Record, String, Array, Double
from . topic import topic
from . types import Error
############################################################################
# LLM text completion
class TextCompletionRequest(Record):
prompt = String()
class TextCompletionResponse(Record):
error = Error()
response = String()
text_completion_request_queue = topic(
'text-completion', kind='non-persistent', namespace='request'
)
text_completion_response_queue = topic(
'text-completion-response', kind='non-persistent', namespace='response',
)
############################################################################
# Embeddings
class EmbeddingsRequest(Record):
text = String()
class EmbeddingsResponse(Record):
error = Error()
vectors = Array(Array(Double()))
embeddings_request_queue = topic(
'embeddings', kind='non-persistent', namespace='request'
)
embeddings_response_queue = topic(
'embeddings-response', kind='non-persistent', namespace='response'
)

View file

@ -0,0 +1,33 @@
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array
from pulsar.schema import Double, Map
from . documents import Source
from . types import Value, RowSchema
from . topic import topic
############################################################################
# Object embeddings are embeddings associated with the primary key of an
# object
# Each message carries the embedding vectors for one extracted object,
# identified by the schema/object name plus the object's primary-key value.
class ObjectEmbeddings(Record):
    source = Source()                   # originating document reference
    vectors = Array(Array(Double()))    # embedding vectors
    name = String()                     # schema / object name
    key_name = String()                 # name of the primary-key field
    id = String()                       # primary-key value of the object

object_embeddings_store_queue = topic('object-embeddings-store')

############################################################################

# Stores rows of information
class Rows(Record):
    source = Source()                # originating document reference
    row_schema = RowSchema()         # schema the rows conform to
    rows = Array(Map(String()))      # each row maps field name -> value

rows_store_queue = topic('rows-store')

View file

@ -0,0 +1,60 @@
from pulsar.schema import Record, Bytes, String, Boolean, Array, Map, Integer
from . topic import topic
from . types import Error, RowSchema
############################################################################
# Prompt services, abstract the prompt generation
class Definition(Record):
name = String()
definition = String()
class Relationship(Record):
s = String()
p = String()
o = String()
o_entity = Boolean()
class Fact(Record):
    """A subject/predicate/object triple, used as knowledge-graph
    context in prompt requests (see PromptRequest.kg)."""
    s = String()
    p = String()
    o = String()
# extract-definitions:
# chunk -> definitions
# extract-relationships:
# chunk -> relationships
# kg-prompt:
# query, triples -> answer
# document-prompt:
# query, documents -> answer
# extract-rows
# schema, chunk -> rows
class PromptRequest(Record):
    """Request to the prompt service.

    `kind` selects the prompt type; the comment block above lists which
    of the remaining fields each kind uses.
    """
    kind = String()
    chunk = String()
    query = String()
    kg = Array(Fact())
    documents = Array(Bytes())
    row_schema = RowSchema()
class PromptResponse(Record):
    """Response from the prompt service; which field is populated
    presumably depends on the request kind — confirm against service."""
    error = Error()
    answer = String()
    definitions = Array(Definition())
    relationships = Array(Relationship())
    # Each row maps field name -> string value
    rows = Array(Map(String()))
# Non-persistent Pulsar topics for the prompt request/response exchange
prompt_request_queue = topic(
    'prompt', kind='non-persistent', namespace='request'
)
prompt_response_queue = topic(
    'prompt-response', kind='non-persistent', namespace='response'
)
############################################################################

View file

@ -0,0 +1,40 @@
from pulsar.schema import Record, Bytes, String, Boolean, Integer, Array, Double
from . topic import topic
from . types import Error, Value
############################################################################
# Graph RAG text retrieval
class GraphRagQuery(Record):
    """Request for Graph RAG text retrieval."""
    query = String()
class GraphRagResponse(Record):
    """Response from Graph RAG text retrieval."""
    error = Error()
    response = String()
# Non-persistent Pulsar topics for the Graph RAG request/response exchange
graph_rag_request_queue = topic(
    'graph-rag', kind='non-persistent', namespace='request'
)
graph_rag_response_queue = topic(
    'graph-rag-response', kind='non-persistent', namespace='response'
)
############################################################################
# Document RAG text retrieval
class DocumentRagQuery(Record):
    """Request for document RAG text retrieval."""
    query = String()
class DocumentRagResponse(Record):
    """Response from document RAG text retrieval."""
    error = Error()
    response = String()
# Non-persistent Pulsar topics for the document RAG request/response exchange
document_rag_request_queue = topic(
    'doc-rag', kind='non-persistent', namespace='request'
)
document_rag_response_queue = topic(
    'doc-rag-response', kind='non-persistent', namespace='response'
)

View file

@ -0,0 +1,4 @@
def topic(topic, kind='persistent', tenant='tg', namespace='flow'):
    """Return a fully-qualified Pulsar topic name.

    The result has the form kind://tenant/namespace/topic, e.g.
    persistent://tg/flow/my-topic.
    """
    return "{}://{}/{}/{}".format(kind, tenant, namespace, topic)

View file

@ -0,0 +1,25 @@
from pulsar.schema import Record, String, Boolean, Array, Integer
class Error(Record):
    """Error details carried in service response messages."""
    type = String()
    message = String()
class Value(Record):
    """A value, flagged as URI or literal, with an optional type."""
    value = String()
    is_uri = Boolean()
    type = String()
class Field(Record):
    """Schema for a single field (column) of a row."""
    name = String()
    # int, string, long, bool, float, double
    type = String()
    size = Integer()
    # True if this field is part of the row's primary key
    primary = Boolean()
    description = String()
class RowSchema(Record):
    """Describes a row/table: its name, description and fields."""
    name = String()
    description = String()
    fields = Array(Field())

View file

@ -0,0 +1,3 @@
from . write import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . write import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,61 @@
"""
Accepts entity/vector pairs and writes them to a Milvus store.
"""
from .... schema import ObjectEmbeddings
from .... schema import object_embeddings_store_queue
from .... log_level import LogLevel
from .... direct.milvus_object_embeddings import ObjectVectors
from .... base import Consumer
# Module path relative to the package root; used as the default
# subscriber name below
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = object_embeddings_store_queue
default_subscriber = module
default_store_uri = 'http://localhost:19530'
class Processor(Consumer):
    """Consumes ObjectEmbeddings messages and inserts each vector into
    a Milvus object-vector store, keyed by the object's primary key."""

    def __init__(self, **params):

        queue = params.get("input_queue", default_input_queue)
        subs = params.get("subscriber", default_subscriber)
        uri = params.get("store_uri", default_store_uri)

        super(Processor, self).__init__(
            **params | {
                "input_queue": queue,
                "subscriber": subs,
                "input_schema": ObjectEmbeddings,
                "store_uri": uri,
            }
        )

        self.vecstore = ObjectVectors(uri)

    def handle(self, msg):

        v = msg.value()

        # Nothing to store without a primary-key value
        if v.id is None or v.id == "":
            return

        for vector in v.vectors:
            self.vecstore.insert(vector, v.name, v.key_name, v.id)

    @staticmethod
    def add_args(parser):

        Consumer.add_args(
            parser, default_input_queue, default_subscriber,
        )

        parser.add_argument(
            '-t', '--store-uri',
            default=default_store_uri,
            help=f'Milvus store URI (default: {default_store_uri})'
        )
def run():
    """Module entry point: start the processor service."""
    Processor.start(module, __doc__)

View file

View file

@ -0,0 +1,3 @@
from . write import *

View file

@ -0,0 +1,7 @@
#!/usr/bin/env python3
from . write import run
if __name__ == '__main__':
run()

View file

@ -0,0 +1,127 @@
"""
Rows writer. Input is rows of structured data with a row schema. Writes rows to Cassandra tables.
"""
import pulsar
import base64
import os
import argparse
import time
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from .... schema import Rows
from .... schema import rows_store_queue
from .... log_level import LogLevel
from .... base import Consumer
# Module path relative to the package root; used as the default
# subscriber name below
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = rows_store_queue
default_subscriber = module
default_graph_host='localhost'
class Processor(Consumer):
    """Consumes Rows messages and writes each row to a Cassandra table.

    A table per row-schema name is created on demand in the
    'trustgraph' keyspace; every column is stored as text.
    """

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        subscriber = params.get("subscriber", default_subscriber)
        graph_host = params.get("graph_host", default_graph_host)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "subscriber": subscriber,
                "input_schema": Rows,
                "graph_host": graph_host,
            }
        )

        self.cluster = Cluster(graph_host.split(","))
        self.session = self.cluster.connect()

        # Names of tables known to exist, so DDL runs once per schema
        self.tables = set()

        self.session.execute("""
            create keyspace if not exists trustgraph
                with replication = {
                   'class' : 'SimpleStrategy',
                   'replication_factor' : 1
                };
        """)

        self.session.execute("use trustgraph")

    def handle(self, msg):
        """Write the rows in one message, creating the table if needed.

        On error, the table is forgotten so creation is retried on the
        next message, and the exception is re-raised.
        """

        # Pre-bind so the except clause never sees an unbound name if
        # msg.value() itself raises
        name = None

        try:

            v = msg.value()
            name = v.row_schema.name

            if name not in self.tables:

                # FIXME: SQL injection?  Schema/field names are
                # interpolated directly into the DDL statement.
                pkey = []
                stmt = "create table if not exists " + name + " ( "
                for field in v.row_schema.fields:
                    stmt += field.name + " text, "
                    if field.primary:
                        pkey.append(field.name)
                stmt += "PRIMARY KEY (" + ", ".join(pkey) + "));"

                self.session.execute(stmt)
                self.tables.add(name)

            for row in v.rows:

                field_names = []
                values = []

                for field in v.row_schema.fields:
                    field_names.append(field.name)
                    values.append(row[field.name])

                # FIXME: SQL injection?  Identifiers are interpolated;
                # values themselves are bound as parameters.
                stmt = (
                    "insert into " + name + " (" + ", ".join(field_names) +
                    ") values (" + ",".join(["%s"] * len(values)) + ")"
                )

                self.session.execute(stmt, values)

        except Exception as e:

            print("Exception:", str(e), flush=True)

            # If there's an error make sure to redo table creation etc.
            # discard() (not remove()) avoids a KeyError when the table
            # was never registered
            if name is not None:
                self.tables.discard(name)

            # Bare raise preserves the original traceback
            raise

    @staticmethod
    def add_args(parser):

        Consumer.add_args(
            parser, default_input_queue, default_subscriber,
        )

        parser.add_argument(
            '-g', '--graph-host',
            default=default_graph_host,
            help=f'Graph host (default: {default_graph_host})'
        )
def run():
    """Module entry point: start the processor service."""
    Processor.start(module, __doc__)