mirror of https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-16 19:05:14 +02:00
Feature/pkgsplit (#83)
* Starting to spawn base package
* More package hacking
* Bedrock and VertexAI
* Parquet split
* Updated templates
* Utils
This commit is contained in:
parent 3fb75c617b
commit 9b91d5eee3
262 changed files with 630 additions and 420 deletions
0  trustgraph-flow/trustgraph/__init__.py  Normal file
0  trustgraph-flow/trustgraph/chunking/__init__.py  Normal file
3  trustgraph-flow/trustgraph/chunking/recursive/__init__.py

@@ -0,0 +1,3 @@

from . chunker import *
7  trustgraph-flow/trustgraph/chunking/recursive/__main__.py

@@ -0,0 +1,7 @@
#!/usr/bin/env python3

from . chunker import run

if __name__ == '__main__':
    run()
108  trustgraph-flow/trustgraph/chunking/recursive/chunker.py  Executable file
@@ -0,0 +1,108 @@

"""
Simple decoder, accepts text documents on input, outputs chunks from the
document as text as separate output objects.
"""

from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram

from ... schema import TextDocument, Chunk, Source
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = text_ingest_queue
default_output_queue = chunk_ingest_queue
default_subscriber = module

class Processor(ConsumerProducer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        chunk_size = params.get("chunk_size", 2000)
        chunk_overlap = params.get("chunk_overlap", 100)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": TextDocument,
                "output_schema": Chunk,
            }
        )

        if not hasattr(__class__, "chunk_metric"):
            __class__.chunk_metric = Histogram(
                'chunk_size', 'Chunk size',
                buckets=[100, 160, 250, 400, 650, 1000, 1600,
                         2500, 4000, 6400, 10000, 16000]
            )

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            is_separator_regex=False,
        )

    def handle(self, msg):

        v = msg.value()
        print(f"Chunking {v.source.id}...", flush=True)

        texts = self.text_splitter.create_documents(
            [v.text.decode("utf-8")]
        )

        for ix, chunk in enumerate(texts):

            id = v.source.id + "-c" + str(ix)

            r = Chunk(
                source=Source(
                    source=v.source.source,
                    id=id,
                    title=v.source.title
                ),
                chunk=chunk.page_content.encode("utf-8"),
            )

            __class__.chunk_metric.observe(len(chunk.page_content))

            self.send(r)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-z', '--chunk-size',
            type=int,
            default=2000,
            help=f'Chunk size (default: 2000)'
        )

        parser.add_argument(
            '-v', '--chunk-overlap',
            type=int,
            default=100,
            help=f'Chunk overlap (default: 100)'
        )

def run():

    Processor.start(module, __doc__)
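For orientation, a minimal standalone sketch of the splitter configuration the recursive chunker above relies on; the sample text is made up, and the parameters mirror the processor's defaults (illustrative, not part of this commit):

from langchain_text_splitters import RecursiveCharacterTextSplitter

# Same settings as the Processor defaults above
splitter = RecursiveCharacterTextSplitter(
    chunk_size=2000,
    chunk_overlap=100,
    length_function=len,
    is_separator_regex=False,
)

# create_documents returns LangChain Document objects; page_content holds the chunk text
docs = splitter.create_documents(["... a long text document ..."])
for d in docs:
    print(len(d.page_content))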
3  trustgraph-flow/trustgraph/chunking/token/__init__.py  Normal file
@@ -0,0 +1,3 @@

from . chunker import *
7  trustgraph-flow/trustgraph/chunking/token/__main__.py  Normal file
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

from . chunker import run

if __name__ == '__main__':
    run()
107  trustgraph-flow/trustgraph/chunking/token/chunker.py  Executable file
@@ -0,0 +1,107 @@

"""
Simple decoder, accepts text documents on input, outputs chunks from the
document as text as separate output objects.
"""

from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram

from ... schema import TextDocument, Chunk, Source
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = text_ingest_queue
default_output_queue = chunk_ingest_queue
default_subscriber = module

class Processor(ConsumerProducer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        chunk_size = params.get("chunk_size", 250)
        chunk_overlap = params.get("chunk_overlap", 15)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": TextDocument,
                "output_schema": Chunk,
            }
        )

        if not hasattr(__class__, "chunk_metric"):
            __class__.chunk_metric = Histogram(
                'chunk_size', 'Chunk size',
                buckets=[100, 160, 250, 400, 650, 1000, 1600,
                         2500, 4000, 6400, 10000, 16000]
            )

        self.text_splitter = TokenTextSplitter(
            encoding_name="cl100k_base",
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

    def handle(self, msg):

        v = msg.value()
        print(f"Chunking {v.source.id}...", flush=True)

        texts = self.text_splitter.create_documents(
            [v.text.decode("utf-8")]
        )

        for ix, chunk in enumerate(texts):

            id = v.source.id + "-c" + str(ix)

            r = Chunk(
                source=Source(
                    source=v.source.source,
                    id=id,
                    title=v.source.title
                ),
                chunk=chunk.page_content.encode("utf-8"),
            )

            __class__.chunk_metric.observe(len(chunk.page_content))

            self.send(r)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-z', '--chunk-size',
            type=int,
            default=250,
            help=f'Chunk size (default: 250)'
        )

        parser.add_argument(
            '-v', '--chunk-overlap',
            type=int,
            default=15,
            help=f'Chunk overlap (default: 15)'
        )

def run():

    Processor.start(module, __doc__)
0  trustgraph-flow/trustgraph/decoding/__init__.py  Normal file
3  trustgraph-flow/trustgraph/decoding/pdf/__init__.py  Normal file
@@ -0,0 +1,3 @@

from . pdf_decoder import *
7  trustgraph-flow/trustgraph/decoding/pdf/__main__.py  Executable file
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

from . pdf_decoder import run

if __name__ == '__main__':
    run()
87  trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py  Executable file
@@ -0,0 +1,87 @@

"""
Simple decoder, accepts PDF documents on input, outputs pages from the
PDF document as text as separate output objects.
"""

import tempfile
import base64
from langchain_community.document_loaders import PyPDFLoader

from ... schema import Document, TextDocument, Source
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = document_ingest_queue
default_output_queue = text_ingest_queue
default_subscriber = module

class Processor(ConsumerProducer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": Document,
                "output_schema": TextDocument,
            }
        )

        print("PDF inited")

    def handle(self, msg):

        print("PDF message received")

        v = msg.value()

        print(f"Decoding {v.source.id}...", flush=True)

        with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:

            fp.write(base64.b64decode(v.data))
            fp.close()

            with open(fp.name, mode='rb') as f:

                loader = PyPDFLoader(fp.name)
                pages = loader.load()

                for ix, page in enumerate(pages):

                    id = v.source.id + "-p" + str(ix)
                    r = TextDocument(
                        source=Source(
                            source=v.source.source,
                            title=v.source.title,
                            id=id,
                        ),
                        text=page.page_content.encode("utf-8"),
                    )

                    self.send(r)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

def run():

    Processor.start(module, __doc__)
0  trustgraph-flow/trustgraph/direct/__init__.py  Normal file
108  trustgraph-flow/trustgraph/direct/cassandra.py  Normal file
@@ -0,0 +1,108 @@

from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider

class TrustGraph:

    def __init__(self, hosts=None):

        if hosts is None:
            hosts = ["localhost"]

        self.cluster = Cluster(hosts)
        self.session = self.cluster.connect()

        self.init()

    def clear(self):

        self.session.execute("""
            drop keyspace if exists trustgraph;
        """);

        self.init()

    def init(self):

        self.session.execute("""
            create keyspace if not exists trustgraph
                with replication = {
                    'class' : 'SimpleStrategy',
                    'replication_factor' : 1
                };
        """);

        self.session.set_keyspace('trustgraph')

        self.session.execute("""
            create table if not exists triples (
                s text,
                p text,
                o text,
                PRIMARY KEY (s, p, o)
            );
        """);

        self.session.execute("""
            create index if not exists triples_p
                ON triples (p);
        """);

        self.session.execute("""
            create index if not exists triples_o
                ON triples (o);
        """);

    def insert(self, s, p, o):

        self.session.execute(
            "insert into triples (s, p, o) values (%s, %s, %s)",
            (s, p, o)
        )

    def get_all(self, limit=50):
        return self.session.execute(
            f"select s, p, o from triples limit {limit}"
        )

    def get_s(self, s, limit=10):
        return self.session.execute(
            f"select p, o from triples where s = %s limit {limit}",
            (s,)
        )

    def get_p(self, p, limit=10):
        return self.session.execute(
            f"select s, o from triples where p = %s limit {limit}",
            (p,)
        )

    def get_o(self, o, limit=10):
        return self.session.execute(
            f"select s, p from triples where o = %s limit {limit}",
            (o,)
        )

    def get_sp(self, s, p, limit=10):
        return self.session.execute(
            f"select o from triples where s = %s and p = %s limit {limit}",
            (s, p)
        )

    def get_po(self, p, o, limit=10):
        return self.session.execute(
            f"select s from triples where p = %s and o = %s allow filtering limit {limit}",
            (p, o)
        )

    def get_os(self, o, s, limit=10):
        return self.session.execute(
            f"select p from triples where o = %s and s = %s limit {limit}",
            (o, s)
        )

    def get_spo(self, s, p, o, limit=10):
        return self.session.execute(
            f"""select s as x from triples where s = %s and p = %s and o = %s limit {limit}""",
            (s, p, o)
        )
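A minimal usage sketch of the TrustGraph Cassandra store added above, assuming a Cassandra node on localhost and the trustgraph-flow package installed; the triple values are made up (illustrative, not part of this commit):

from trustgraph.direct.cassandra import TrustGraph

# connect() and keyspace/table setup happen in __init__
tg = TrustGraph(hosts=["localhost"])
tg.insert("urn:ex:alice", "urn:ex:knows", "urn:ex:bob")

# Query predicate/object pairs for a subject
for row in tg.get_s("urn:ex:alice"):
    print(row.p, row.o)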
138  trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py  Normal file
@@ -0,0 +1,138 @@

from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time

class DocVectors:

    def __init__(self, uri="http://localhost:19530", prefix='doc'):

        self.client = MilvusClient(uri=uri)

        # Strategy is to create collections per dimension. Probably only
        # going to be using 1 anyway, but that means we don't need to
        # hard-code the dimension anywhere, and no big deal if more than
        # one are created.
        self.collections = {}

        self.prefix = prefix

        # Time between reloads
        self.reload_time = 90

        # Next time to reload - this forces a reload at next window
        self.next_reload = time.time() + self.reload_time
        print("Reload at", self.next_reload)

    def init_collection(self, dimension):

        collection_name = self.prefix + "_" + str(dimension)

        pkey_field = FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )

        vec_field = FieldSchema(
            name="vector",
            dtype=DataType.FLOAT_VECTOR,
            dim=dimension,
        )

        doc_field = FieldSchema(
            name="doc",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        schema = CollectionSchema(
            fields = [pkey_field, vec_field, doc_field],
            description = "Document embedding schema",
        )

        self.client.create_collection(
            collection_name=collection_name,
            schema=schema,
            metric_type="COSINE",
        )

        index_params = MilvusClient.prepare_index_params()

        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="IVF_SQ8",
            index_name="vector_index",
            params={ "nlist": 128 }
        )

        self.client.create_index(
            collection_name=collection_name,
            index_params=index_params
        )

        self.collections[dimension] = collection_name

    def insert(self, embeds, doc):

        dim = len(embeds)

        if dim not in self.collections:
            self.init_collection(dim)

        data = [
            {
                "vector": embeds,
                "doc": doc,
            }
        ]

        self.client.insert(
            collection_name=self.collections[dim],
            data=data
        )

    def search(self, embeds, fields=["doc"], limit=10):

        dim = len(embeds)

        if dim not in self.collections:
            self.init_collection(dim)

        coll = self.collections[dim]

        search_params = {
            "metric_type": "COSINE",
            "params": {
                "radius": 0.1,
                "range_filter": 0.8
            }
        }

        print("Loading...")
        self.client.load_collection(
            collection_name=coll,
        )

        print("Searching...")

        res = self.client.search(
            collection_name=coll,
            data=[embeds],
            limit=limit,
            output_fields=fields,
            search_params=search_params,
        )[0]

        # If reload time has passed, unload collection
        if time.time() > self.next_reload:
            print("Unloading, reload at", self.next_reload)
            self.client.release_collection(
                collection_name=coll,
            )
            self.next_reload = time.time() + self.reload_time

        return res
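Likewise, a hedged sketch of how the DocVectors store above might be exercised, assuming Milvus is reachable at the default URI; the four-element vector is a stand-in for a real embedding (illustrative, not part of this commit):

from trustgraph.direct.milvus_doc_embeddings import DocVectors

store = DocVectors(uri="http://localhost:19530")

vec = [0.1, 0.2, 0.3, 0.4]                      # placeholder embedding
store.insert(vec, "An example document chunk")

# search() returns Milvus hit dicts; the stored text is in the "doc" output field
for hit in store.search(vec, limit=5):
    print(hit["distance"], hit["entity"]["doc"])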
138  trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py  Normal file
@@ -0,0 +1,138 @@

from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time

class EntityVectors:

    def __init__(self, uri="http://localhost:19530", prefix='entity'):

        self.client = MilvusClient(uri=uri)

        # Strategy is to create collections per dimension. Probably only
        # going to be using 1 anyway, but that means we don't need to
        # hard-code the dimension anywhere, and no big deal if more than
        # one are created.
        self.collections = {}

        self.prefix = prefix

        # Time between reloads
        self.reload_time = 90

        # Next time to reload - this forces a reload at next window
        self.next_reload = time.time() + self.reload_time
        print("Reload at", self.next_reload)

    def init_collection(self, dimension):

        collection_name = self.prefix + "_" + str(dimension)

        pkey_field = FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )

        vec_field = FieldSchema(
            name="vector",
            dtype=DataType.FLOAT_VECTOR,
            dim=dimension,
        )

        entity_field = FieldSchema(
            name="entity",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        schema = CollectionSchema(
            fields = [pkey_field, vec_field, entity_field],
            description = "Graph embedding schema",
        )

        self.client.create_collection(
            collection_name=collection_name,
            schema=schema,
            metric_type="COSINE",
        )

        index_params = MilvusClient.prepare_index_params()

        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="IVF_SQ8",
            index_name="vector_index",
            params={ "nlist": 128 }
        )

        self.client.create_index(
            collection_name=collection_name,
            index_params=index_params
        )

        self.collections[dimension] = collection_name

    def insert(self, embeds, entity):

        dim = len(embeds)

        if dim not in self.collections:
            self.init_collection(dim)

        data = [
            {
                "vector": embeds,
                "entity": entity,
            }
        ]

        self.client.insert(
            collection_name=self.collections[dim],
            data=data
        )

    def search(self, embeds, fields=["entity"], limit=10):

        dim = len(embeds)

        if dim not in self.collections:
            self.init_collection(dim)

        coll = self.collections[dim]

        search_params = {
            "metric_type": "COSINE",
            "params": {
                "radius": 0.1,
                "range_filter": 0.8
            }
        }

        print("Loading...")
        self.client.load_collection(
            collection_name=coll,
        )

        print("Searching...")

        res = self.client.search(
            collection_name=coll,
            data=[embeds],
            limit=limit,
            output_fields=fields,
            search_params=search_params,
        )[0]

        # If reload time has passed, unload collection
        if time.time() > self.next_reload:
            print("Unloading, reload at", self.next_reload)
            self.client.release_collection(
                collection_name=coll,
            )
            self.next_reload = time.time() + self.reload_time

        return res
154  trustgraph-flow/trustgraph/direct/milvus_object_embeddings.py  Normal file
@@ -0,0 +1,154 @@

from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time

class ObjectVectors:

    def __init__(self, uri="http://localhost:19530", prefix='obj'):

        self.client = MilvusClient(uri=uri)

        # Strategy is to create collections per dimension. Probably only
        # going to be using 1 anyway, but that means we don't need to
        # hard-code the dimension anywhere, and no big deal if more than
        # one are created.
        self.collections = {}

        self.prefix = prefix

        # Time between reloads
        self.reload_time = 90

        # Next time to reload - this forces a reload at next window
        self.next_reload = time.time() + self.reload_time
        print("Reload at", self.next_reload)

    def init_collection(self, dimension, name):

        collection_name = self.prefix + "_" + name + "_" + str(dimension)

        pkey_field = FieldSchema(
            name="id",
            dtype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )

        vec_field = FieldSchema(
            name="vector",
            dtype=DataType.FLOAT_VECTOR,
            dim=dimension,
        )

        name_field = FieldSchema(
            name="name",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        key_name_field = FieldSchema(
            name="key_name",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        key_field = FieldSchema(
            name="key",
            dtype=DataType.VARCHAR,
            max_length=65535,
        )

        schema = CollectionSchema(
            fields = [
                pkey_field, vec_field, name_field, key_name_field, key_field
            ],
            description = "Object embedding schema",
        )

        self.client.create_collection(
            collection_name=collection_name,
            schema=schema,
            metric_type="COSINE",
        )

        index_params = MilvusClient.prepare_index_params()

        index_params.add_index(
            field_name="vector",
            metric_type="COSINE",
            index_type="IVF_SQ8",
            index_name="vector_index",
            params={ "nlist": 128 }
        )

        self.client.create_index(
            collection_name=collection_name,
            index_params=index_params
        )

        self.collections[(dimension, name)] = collection_name

    def insert(self, embeds, name, key_name, key):

        dim = len(embeds)

        if (dim, name) not in self.collections:
            self.init_collection(dim, name)

        data = [
            {
                "vector": embeds,
                "name": name,
                "key_name": key_name,
                "key": key,
            }
        ]

        self.client.insert(
            collection_name=self.collections[(dim, name)],
            data=data
        )

    def search(self, embeds, name, fields=["key_name", "name"], limit=10):

        dim = len(embeds)

        if (dim, name) not in self.collections:
            self.init_collection(dim, name)

        coll = self.collections[(dim, name)]

        search_params = {
            "metric_type": "COSINE",
            "params": {
                "radius": 0.1,
                "range_filter": 0.8
            }
        }

        print("Loading...")
        self.client.load_collection(
            collection_name=coll,
        )

        print("Searching...")

        res = self.client.search(
            collection_name=coll,
            data=[embeds],
            limit=limit,
            output_fields=fields,
            search_params=search_params,
        )[0]

        # If reload time has passed, unload collection
        if time.time() > self.next_reload:
            print("Unloading, reload at", self.next_reload)
            self.client.release_collection(
                collection_name=coll,
            )
            self.next_reload = time.time() + self.reload_time

        return res
132  trustgraph-flow/trustgraph/document_rag.py  Normal file
@@ -0,0 +1,132 @@

from . clients.document_embeddings_client import DocumentEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient

from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import document_embeddings_request_queue
from . schema import document_embeddings_response_queue

LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"

class DocumentRag:

    def __init__(
        self,
        pulsar_host="pulsar://pulsar:6650",
        pr_request_queue=None,
        pr_response_queue=None,
        emb_request_queue=None,
        emb_response_queue=None,
        de_request_queue=None,
        de_response_queue=None,
        verbose=False,
        module="test",
    ):

        self.verbose=verbose

        if pr_request_queue is None:
            pr_request_queue = prompt_request_queue

        if pr_response_queue is None:
            pr_response_queue = prompt_response_queue

        if emb_request_queue is None:
            emb_request_queue = embeddings_request_queue

        if emb_response_queue is None:
            emb_response_queue = embeddings_response_queue

        if de_request_queue is None:
            de_request_queue = document_embeddings_request_queue

        if de_response_queue is None:
            de_response_queue = document_embeddings_response_queue

        if self.verbose:
            print("Initialising...", flush=True)

        # FIXME: Configurable
        self.entity_limit = 20

        self.de_client = DocumentEmbeddingsClient(
            pulsar_host=pulsar_host,
            subscriber=module + "-de",
            input_queue=de_request_queue,
            output_queue=de_response_queue,
        )

        self.embeddings = EmbeddingsClient(
            pulsar_host=pulsar_host,
            input_queue=emb_request_queue,
            output_queue=emb_response_queue,
            subscriber=module + "-emb",
        )

        self.lang = PromptClient(
            pulsar_host=pulsar_host,
            input_queue=pr_request_queue,
            output_queue=pr_response_queue,
            subscriber=module + "-de-prompt",
        )

        if self.verbose:
            print("Initialised", flush=True)

    def get_vector(self, query):

        if self.verbose:
            print("Compute embeddings...", flush=True)

        qembeds = self.embeddings.request(query)

        if self.verbose:
            print("Done.", flush=True)

        return qembeds

    def get_docs(self, query):

        vectors = self.get_vector(query)

        if self.verbose:
            print("Get entities...", flush=True)

        docs = self.de_client.request(
            vectors, self.entity_limit
        )

        if self.verbose:
            print("Docs:", flush=True)
            for doc in docs:
                print(doc, flush=True)

        return docs

    def query(self, query):

        if self.verbose:
            print("Construct prompt...", flush=True)

        docs = self.get_docs(query)

        if self.verbose:
            print("Invoke LLM...", flush=True)
            print(docs)
            print(query)

        resp = self.lang.request_document_prompt(query, docs)

        if self.verbose:
            print("Done", flush=True)

        return resp
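A short usage sketch of the DocumentRag class above, assuming a Pulsar broker and the supporting embeddings, document-embeddings and prompt services are running; the question text is illustrative (not part of this commit):

from trustgraph.document_rag import DocumentRag

rag = DocumentRag(
    pulsar_host="pulsar://localhost:6650",    # assumed local broker
    verbose=True,
)

answer = rag.query("What topics does the ingested document cover?")
print(answer)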
0  trustgraph-flow/trustgraph/embeddings/__init__.py  Normal file
3  trustgraph-flow/trustgraph/embeddings/ollama/__init__.py  Normal file
@@ -0,0 +1,3 @@

from . processor import *
7  trustgraph-flow/trustgraph/embeddings/ollama/__main__.py  Executable file
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

from . processor import run

if __name__ == '__main__':
    run()
84  trustgraph-flow/trustgraph/embeddings/ollama/processor.py  Executable file
@@ -0,0 +1,84 @@

"""
Embeddings service, applies an embeddings model served by Ollama.
Input is text, output is embeddings vector.
"""
from langchain_community.embeddings import OllamaEmbeddings

from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue
default_subscriber = module
default_model="mxbai-embed-large"
default_ollama = 'http://localhost:11434'

class Processor(ConsumerProducer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        model = params.get("model", default_model)
        ollama = params.get("ollama", default_ollama)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": EmbeddingsRequest,
                "output_schema": EmbeddingsResponse,
            }
        )

        self.embeddings = OllamaEmbeddings(base_url=ollama, model=model)

    def handle(self, msg):

        v = msg.value()

        # Sender-produced ID

        id = msg.properties()["id"]

        print(f"Handling input {id}...", flush=True)

        text = v.text
        embeds = self.embeddings.embed_query([text])

        print("Send response...", flush=True)
        r = EmbeddingsResponse(vectors=[embeds])

        self.producer.send(r, properties={"id": id})

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-m', '--model',
            default=default_model,
            help=f'Embeddings model (default: {default_model})'
        )

        parser.add_argument(
            '-r', '--ollama',
            default=default_ollama,
            help=f'ollama (default: {default_ollama})'
        )

def run():

    Processor.start(module, __doc__)
3  trustgraph-flow/trustgraph/embeddings/vectorize/__init__.py

@@ -0,0 +1,3 @@

from . vectorize import *
6  trustgraph-flow/trustgraph/embeddings/vectorize/__main__.py  Executable file
@@ -0,0 +1,6 @@

from . vectorize import run

if __name__ == '__main__':
    run()
103  trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py  Executable file
@@ -0,0 +1,103 @@

"""
Vectorizer, calls the embeddings service to get embeddings for a chunk.
Input is text chunk, output is chunk and vectors.
"""

from ... schema import Chunk, ChunkEmbeddings
from ... schema import chunk_ingest_queue, chunk_embeddings_ingest_queue
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel
from ... base import ConsumerProducer

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = chunk_ingest_queue
default_output_queue = chunk_embeddings_ingest_queue
default_subscriber = module

class Processor(ConsumerProducer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        emb_request_queue = params.get(
            "embeddings_request_queue", embeddings_request_queue
        )
        emb_response_queue = params.get(
            "embeddings_response_queue", embeddings_response_queue
        )

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "embeddings_request_queue": emb_request_queue,
                "embeddings_response_queue": emb_response_queue,
                "subscriber": subscriber,
                "input_schema": Chunk,
                "output_schema": ChunkEmbeddings,
            }
        )

        self.embeddings = EmbeddingsClient(
            pulsar_host=self.pulsar_host,
            input_queue=emb_request_queue,
            output_queue=emb_response_queue,
            subscriber=module + "-emb",
        )

    def emit(self, source, chunk, vectors):

        r = ChunkEmbeddings(source=source, chunk=chunk, vectors=vectors)
        self.producer.send(r)

    def handle(self, msg):

        v = msg.value()
        print(f"Indexing {v.source.id}...", flush=True)

        chunk = v.chunk.decode("utf-8")

        try:

            vectors = self.embeddings.request(chunk)

            self.emit(
                source=v.source,
                chunk=chunk.encode("utf-8"),
                vectors=vectors
            )

        except Exception as e:
            print("Exception:", e, flush=True)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '--embeddings-request-queue',
            default=embeddings_request_queue,
            help=f'Embeddings request queue (default: {embeddings_request_queue})',
        )

        parser.add_argument(
            '--embeddings-response-queue',
            default=embeddings_response_queue,
            help=f'Embeddings response queue (default: {embeddings_response_queue})',
        )

def run():

    Processor.start(module, __doc__)
0  trustgraph-flow/trustgraph/extract/__init__.py  Normal file
0  trustgraph-flow/trustgraph/extract/kg/__init__.py  Normal file
3  trustgraph-flow/trustgraph/extract/kg/definitions/__init__.py

@@ -0,0 +1,3 @@

from . extract import *
7  trustgraph-flow/trustgraph/extract/kg/definitions/__main__.py  Executable file
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

from . extract import run

if __name__ == '__main__':
    run()
134  trustgraph-flow/trustgraph/extract/kg/definitions/extract.py  Executable file
@@ -0,0 +1,134 @@

"""
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
get entity definitions which are output as graph edges.
"""

import urllib.parse
import json

from .... schema import ChunkEmbeddings, Triple, Source, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
from .... base import ConsumerProducer

DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = chunk_embeddings_ingest_queue
default_output_queue = triples_store_queue
default_subscriber = module

class Processor(ConsumerProducer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        pr_request_queue = params.get(
            "prompt_request_queue", prompt_request_queue
        )
        pr_response_queue = params.get(
            "prompt_response_queue", prompt_response_queue
        )

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": ChunkEmbeddings,
                "output_schema": Triple,
                "prompt_request_queue": pr_request_queue,
                "prompt_response_queue": pr_response_queue,
            }
        )

        self.prompt = PromptClient(
            pulsar_host=self.pulsar_host,
            input_queue=pr_request_queue,
            output_queue=pr_response_queue,
            subscriber = module + "-prompt",
        )

    def to_uri(self, text):

        part = text.replace(" ", "-").lower().encode("utf-8")
        quoted = urllib.parse.quote(part)
        uri = TRUSTGRAPH_ENTITIES + quoted

        return uri

    def get_definitions(self, chunk):

        return self.prompt.request_definitions(chunk)

    def emit_edge(self, s, p, o):

        t = Triple(s=s, p=p, o=o)
        self.producer.send(t)

    def handle(self, msg):

        v = msg.value()
        print(f"Indexing {v.source.id}...", flush=True)

        chunk = v.chunk.decode("utf-8")

        try:

            defs = self.get_definitions(chunk)

            for defn in defs:

                s = defn.name
                o = defn.definition

                if s == "": continue
                if o == "": continue

                if s is None: continue
                if o is None: continue

                s_uri = self.to_uri(s)

                s_value = Value(value=str(s_uri), is_uri=True)
                o_value = Value(value=str(o), is_uri=False)

                self.emit_edge(s_value, DEFINITION_VALUE, o_value)

        except Exception as e:
            print("Exception: ", e, flush=True)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '--prompt-request-queue',
            default=prompt_request_queue,
            help=f'Prompt request queue (default: {prompt_request_queue})',
        )

        parser.add_argument(
            '--prompt-completion-response-queue',
            default=prompt_response_queue,
            help=f'Prompt response queue (default: {prompt_response_queue})',
        )

def run():

    Processor.start(module, __doc__)
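As a small illustration of the to_uri scheme used by the extractor above: spaces become dashes, the text is lower-cased and percent-encoded, and the result is appended to the TRUSTGRAPH_ENTITIES prefix. The prefix value below is assumed for the example only; the real one comes from the rdf module (not part of this commit):

import urllib.parse

TRUSTGRAPH_ENTITIES = "http://trustgraph.ai/e/"    # assumed value, for illustration

def to_uri(text):
    # Mirrors Processor.to_uri above: dash-separated, lower-cased, percent-encoded
    part = text.replace(" ", "-").lower().encode("utf-8")
    return TRUSTGRAPH_ENTITIES + urllib.parse.quote(part)

print(to_uri("New York City"))    # -> http://trustgraph.ai/e/new-york-city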
3  trustgraph-flow/trustgraph/extract/kg/relationships/__init__.py

@@ -0,0 +1,3 @@

from . extract import *
7  trustgraph-flow/trustgraph/extract/kg/relationships/__main__.py  Executable file
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

from . extract import run

if __name__ == '__main__':
    run()
208  trustgraph-flow/trustgraph/extract/kg/relationships/extract.py  Executable file
@@ -0,0 +1,208 @@

"""
Simple decoder, accepts vector+text chunks input, applies entity
relationship analysis to get entity relationship edges which are output as
graph edges.
"""

import urllib.parse
import os
from pulsar.schema import JsonSchema

from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import graph_embeddings_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
from .... base import ConsumerProducer

RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = chunk_embeddings_ingest_queue
default_output_queue = triples_store_queue
default_vector_queue = graph_embeddings_store_queue
default_subscriber = module

class Processor(ConsumerProducer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        vector_queue = params.get("vector_queue", default_vector_queue)
        subscriber = params.get("subscriber", default_subscriber)
        pr_request_queue = params.get(
            "prompt_request_queue", prompt_request_queue
        )
        pr_response_queue = params.get(
            "prompt_response_queue", prompt_response_queue
        )

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": ChunkEmbeddings,
                "output_schema": Triple,
                "prompt_request_queue": pr_request_queue,
                "prompt_response_queue": pr_response_queue,
            }
        )

        self.vec_prod = self.client.create_producer(
            topic=vector_queue,
            schema=JsonSchema(GraphEmbeddings),
        )

        __class__.pubsub_metric.info({
            "input_queue": input_queue,
            "output_queue": output_queue,
            "vector_queue": vector_queue,
            "prompt_request_queue": pr_request_queue,
            "prompt_response_queue": pr_response_queue,
            "subscriber": subscriber,
            "input_schema": ChunkEmbeddings.__name__,
            "output_schema": Triple.__name__,
            "vector_schema": GraphEmbeddings.__name__,
        })

        self.prompt = PromptClient(
            pulsar_host=self.pulsar_host,
            input_queue=pr_request_queue,
            output_queue=pr_response_queue,
            subscriber = module + "-prompt",
        )

    def to_uri(self, text):

        part = text.replace(" ", "-").lower().encode("utf-8")
        quoted = urllib.parse.quote(part)
        uri = TRUSTGRAPH_ENTITIES + quoted

        return uri

    def get_relationships(self, chunk):

        return self.prompt.request_relationships(chunk)

    def emit_edge(self, s, p, o):

        t = Triple(s=s, p=p, o=o)
        self.producer.send(t)

    def emit_vec(self, ent, vec):

        r = GraphEmbeddings(entity=ent, vectors=vec)
        self.vec_prod.send(r)

    def handle(self, msg):

        v = msg.value()
        print(f"Indexing {v.source.id}...", flush=True)

        chunk = v.chunk.decode("utf-8")

        try:

            rels = self.get_relationships(chunk)

            for rel in rels:

                s = rel.s
                p = rel.p
                o = rel.o

                if s == "": continue
                if p == "": continue
                if o == "": continue

                if s is None: continue
                if p is None: continue
                if o is None: continue

                s_uri = self.to_uri(s)
                s_value = Value(value=str(s_uri), is_uri=True)

                p_uri = self.to_uri(p)
                p_value = Value(value=str(p_uri), is_uri=True)

                if rel.o_entity:
                    o_uri = self.to_uri(o)
                    o_value = Value(value=str(o_uri), is_uri=True)
                else:
                    o_value = Value(value=str(o), is_uri=False)

                self.emit_edge(
                    s_value,
                    p_value,
                    o_value
                )

                # Label for s
                self.emit_edge(
                    s_value,
                    RDF_LABEL_VALUE,
                    Value(value=str(s), is_uri=False)
                )

                # Label for p
                self.emit_edge(
                    p_value,
                    RDF_LABEL_VALUE,
                    Value(value=str(p), is_uri=False)
                )

                if rel.o_entity:
                    # Label for o
                    self.emit_edge(
                        o_value,
                        RDF_LABEL_VALUE,
                        Value(value=str(o), is_uri=False)
                    )

                self.emit_vec(s_value, v.vectors)
                self.emit_vec(p_value, v.vectors)
                if rel.o_entity:
                    self.emit_vec(o_value, v.vectors)

        except Exception as e:
            print("Exception: ", e, flush=True)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-c', '--vector-queue',
            default=default_vector_queue,
            help=f'Vector output queue (default: {default_vector_queue})'
        )

        parser.add_argument(
            '--prompt-request-queue',
            default=prompt_request_queue,
            help=f'Prompt request queue (default: {prompt_request_queue})',
        )

        parser.add_argument(
            '--prompt-response-queue',
            default=prompt_response_queue,
            help=f'Prompt response queue (default: {prompt_response_queue})',
        )

def run():

    Processor.start(module, __doc__)
3  trustgraph-flow/trustgraph/extract/kg/topics/__init__.py  Normal file
@@ -0,0 +1,3 @@

from . extract import *
7  trustgraph-flow/trustgraph/extract/kg/topics/__main__.py  Executable file
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

from . extract import run

if __name__ == '__main__':
    run()
134  trustgraph-flow/trustgraph/extract/kg/topics/extract.py  Executable file
@@ -0,0 +1,134 @@

"""
Simple decoder, accepts embeddings+text chunks input, applies topic analysis to
get topic definitions which are output as graph edges.
"""

import urllib.parse
import json

from .... schema import ChunkEmbeddings, Triple, Source, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
from .... base import ConsumerProducer

DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = chunk_embeddings_ingest_queue
default_output_queue = triples_store_queue
default_subscriber = module

class Processor(ConsumerProducer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        pr_request_queue = params.get(
            "prompt_request_queue", prompt_request_queue
        )
        pr_response_queue = params.get(
            "prompt_response_queue", prompt_response_queue
        )

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": ChunkEmbeddings,
                "output_schema": Triple,
                "prompt_request_queue": pr_request_queue,
                "prompt_response_queue": pr_response_queue,
            }
        )

        self.prompt = PromptClient(
            pulsar_host=self.pulsar_host,
            input_queue=pr_request_queue,
            output_queue=pr_response_queue,
            subscriber = module + "-prompt",
        )

    def to_uri(self, text):

        part = text.replace(" ", "-").lower().encode("utf-8")
        quoted = urllib.parse.quote(part)
        uri = TRUSTGRAPH_ENTITIES + quoted

        return uri

    def get_topics(self, chunk):

        return self.prompt.request_topics(chunk)

    def emit_edge(self, s, p, o):

        t = Triple(s=s, p=p, o=o)
        self.producer.send(t)

    def handle(self, msg):

        v = msg.value()
        print(f"Indexing {v.source.id}...", flush=True)

        chunk = v.chunk.decode("utf-8")

        try:

            defs = self.get_topics(chunk)

            for defn in defs:

                s = defn.name
                o = defn.definition

                if s == "": continue
                if o == "": continue

                if s is None: continue
                if o is None: continue

                s_uri = self.to_uri(s)

                s_value = Value(value=str(s_uri), is_uri=True)
                o_value = Value(value=str(o), is_uri=False)

                self.emit_edge(s_value, DEFINITION_VALUE, o_value)

        except Exception as e:
            print("Exception: ", e, flush=True)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '--prompt-request-queue',
            default=prompt_request_queue,
            help=f'Prompt request queue (default: {prompt_request_queue})',
        )

        parser.add_argument(
            '--prompt-completion-response-queue',
            default=prompt_response_queue,
            help=f'Prompt response queue (default: {prompt_response_queue})',
        )

def run():

    Processor.start(module, __doc__)
0  trustgraph-flow/trustgraph/extract/object/__init__.py  Normal file
3  trustgraph-flow/trustgraph/extract/object/row/__init__.py

@@ -0,0 +1,3 @@

from . extract import *
7  trustgraph-flow/trustgraph/extract/object/row/__main__.py  Executable file
@@ -0,0 +1,7 @@
#!/usr/bin/env python3

from . extract import run

if __name__ == '__main__':
    run()
220  trustgraph-flow/trustgraph/extract/object/row/extract.py  Executable file
@@ -0,0 +1,220 @@

"""
Simple decoder, accepts vector+text chunks input, applies analysis to pull
out a row of fields. Output as a vector plus object.
"""

import urllib.parse
import os
from pulsar.schema import JsonSchema

from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source
from .... schema import RowSchema, Field
from .... schema import chunk_embeddings_ingest_queue, rows_store_queue
from .... schema import object_embeddings_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... base import ConsumerProducer

from .... objects.field import Field as FieldParser
from .... objects.object import Schema

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = chunk_embeddings_ingest_queue
default_output_queue = rows_store_queue
default_vector_queue = object_embeddings_store_queue
default_subscriber = module

class Processor(ConsumerProducer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        vector_queue = params.get("vector_queue", default_vector_queue)
        subscriber = params.get("subscriber", default_subscriber)
        pr_request_queue = params.get(
            "prompt_request_queue", prompt_request_queue
        )
        pr_response_queue = params.get(
            "prompt_response_queue", prompt_response_queue
        )

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": ChunkEmbeddings,
                "output_schema": Rows,
                "prompt_request_queue": pr_request_queue,
                "prompt_response_queue": pr_response_queue,
            }
        )

        self.vec_prod = self.client.create_producer(
            topic=vector_queue,
            schema=JsonSchema(ObjectEmbeddings),
        )

        __class__.pubsub_metric.info({
            "input_queue": input_queue,
            "output_queue": output_queue,
            "vector_queue": vector_queue,
            "prompt_request_queue": pr_request_queue,
            "prompt_response_queue": pr_response_queue,
            "subscriber": subscriber,
            "input_schema": ChunkEmbeddings.__name__,
            "output_schema": Rows.__name__,
            "vector_schema": ObjectEmbeddings.__name__,
        })

        flds = __class__.parse_fields(params["field"])

        for fld in flds:
            print(fld)

        self.primary = None

        for f in flds:
            if f.primary:
                if self.primary:
                    raise RuntimeError(
                        "Only one primary key field is supported"
                    )
                self.primary = f

        if self.primary is None:
            raise RuntimeError(
                "Must have exactly one primary key field"
            )

        self.schema = Schema(
            name = params["name"],
            description = params["description"],
            fields = flds
        )

        self.row_schema=RowSchema(
            name=self.schema.name,
            description=self.schema.description,
            fields=[
                Field(
                    name=f.name, type=str(f.type), size=f.size,
                    primary=f.primary, description=f.description,
                )
                for f in self.schema.fields
            ]
        )

        self.prompt = PromptClient(
            pulsar_host=self.pulsar_host,
            input_queue=pr_request_queue,
            output_queue=pr_response_queue,
            subscriber = module + "-prompt",
        )

    @staticmethod
    def parse_fields(fields):
        return [ FieldParser.parse(f) for f in fields ]

    def get_rows(self, chunk):
        return self.prompt.request_rows(self.schema, chunk)

    def emit_rows(self, source, rows):

        t = Rows(
            source=source, row_schema=self.row_schema, rows=rows
        )
        self.producer.send(t)

    def emit_vec(self, source, name, vec, key_name, key):

        r = ObjectEmbeddings(
            source=source, vectors=vec, name=name, key_name=key_name, id=key
        )
        self.vec_prod.send(r)

    def handle(self, msg):

        v = msg.value()
        print(f"Indexing {v.source.id}...", flush=True)

        chunk = v.chunk.decode("utf-8")

        try:

            rows = self.get_rows(chunk)

            self.emit_rows(
                source=v.source,
                rows=rows
            )

            for row in rows:
                self.emit_vec(
                    source=v.source, vec=v.vectors,
                    name=self.schema.name, key_name=self.primary.name,
                    key=row[self.primary.name]
                )

            for row in rows:
                print(row)

        except Exception as e:
            print("Exception: ", e, flush=True)

        print("Done.", flush=True)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-c', '--vector-queue',
            default=default_vector_queue,
            help=f'Vector output queue (default: {default_vector_queue})'
        )

        parser.add_argument(
            '--prompt-request-queue',
            default=prompt_request_queue,
            help=f'Prompt request queue (default: {prompt_request_queue})',
        )

        parser.add_argument(
            '--prompt-response-queue',
            default=prompt_response_queue,
            help=f'Prompt response queue (default: {prompt_response_queue})',
        )

        parser.add_argument(
            '-f', '--field',
            required=True,
            action='append',
            help=f'Field definition, format name:type:size:pri:description',
        )

        parser.add_argument(
            '-n', '--name',
            required=True,
            help=f'Name of row object',
        )

        parser.add_argument(
            '-d', '--description',
            required=True,
            help=f'Description of object',
        )

def run():

    Processor.start(module, __doc__)
250
trustgraph-flow/trustgraph/graph_rag.py
Normal file
250
trustgraph-flow/trustgraph/graph_rag.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
|
||||
from . clients.graph_embeddings_client import GraphEmbeddingsClient
|
||||
from . clients.triples_query_client import TriplesQueryClient
|
||||
from . clients.embeddings_client import EmbeddingsClient
|
||||
from . clients.prompt_client import PromptClient
|
||||
|
||||
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from . schema import TriplesQueryRequest, TriplesQueryResponse
|
||||
from . schema import prompt_request_queue
|
||||
from . schema import prompt_response_queue
|
||||
from . schema import embeddings_request_queue
|
||||
from . schema import embeddings_response_queue
|
||||
from . schema import graph_embeddings_request_queue
|
||||
from . schema import graph_embeddings_response_queue
|
||||
from . schema import triples_request_queue
|
||||
from . schema import triples_response_queue
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
||||
|
||||
class GraphRag:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
pr_request_queue=None,
|
||||
pr_response_queue=None,
|
||||
emb_request_queue=None,
|
||||
emb_response_queue=None,
|
||||
ge_request_queue=None,
|
||||
ge_response_queue=None,
|
||||
tpl_request_queue=None,
|
||||
tpl_response_queue=None,
|
||||
verbose=False,
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
max_subgraph_size=3000,
|
||||
module="test",
|
||||
):
|
||||
|
||||
self.verbose=verbose
|
||||
|
||||
if pr_request_queue is None:
|
||||
pr_request_queue = prompt_request_queue
|
||||
|
||||
if pr_response_queue is None:
|
||||
pr_response_queue = prompt_response_queue
|
||||
|
||||
if emb_request_queue is None:
|
||||
emb_request_queue = embeddings_request_queue
|
||||
|
||||
if emb_response_queue is None:
|
||||
emb_response_queue = embeddings_response_queue
|
||||
|
||||
if ge_request_queue is None:
|
||||
ge_request_queue = graph_embeddings_request_queue
|
||||
|
||||
if ge_response_queue is None:
|
||||
ge_response_queue = graph_embeddings_response_queue
|
||||
|
||||
if tpl_request_queue is None:
|
||||
tpl_request_queue = triples_request_queue
|
||||
|
||||
if tpl_response_queue is None:
|
||||
tpl_response_queue = triples_response_queue
|
||||
|
||||
if self.verbose:
|
||||
print("Initialising...", flush=True)
|
||||
|
||||
self.ge_client = GraphEmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
subscriber=module + "-ge",
|
||||
input_queue=ge_request_queue,
|
||||
output_queue=ge_response_queue,
|
||||
)
|
||||
|
||||
self.triples_client = TriplesQueryClient(
|
||||
pulsar_host=pulsar_host,
|
||||
subscriber=module + "-tpl",
|
||||
input_queue=tpl_request_queue,
|
||||
output_queue=tpl_response_queue
|
||||
)
|
||||
|
||||
self.embeddings = EmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=emb_request_queue,
|
||||
output_queue=emb_response_queue,
|
||||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
self.entity_limit=entity_limit
|
||||
self.query_limit=triple_limit
|
||||
self.max_subgraph_size=max_subgraph_size
|
||||
|
||||
self.label_cache = {}
|
||||
|
||||
self.lang = PromptClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber=module + "-prompt",
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = self.embeddings.request(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
def get_entities(self, query):
|
||||
|
||||
vectors = self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get entities...", flush=True)
|
||||
|
||||
entities = self.ge_client.request(
|
||||
vectors, self.entity_limit
|
||||
)
|
||||
|
||||
entities = [
|
||||
e.value
|
||||
for e in entities
|
||||
]
|
||||
|
||||
if self.verbose:
|
||||
print("Entities:", flush=True)
|
||||
for ent in entities:
|
||||
print(" ", ent, flush=True)
|
||||
|
||||
return entities
|
||||
|
||||
def maybe_label(self, e):
|
||||
|
||||
if e in self.label_cache:
|
||||
return self.label_cache[e]
|
||||
|
||||
res = self.triples_client.request(
|
||||
e, LABEL, None, limit=1
|
||||
)
|
||||
|
||||
if len(res) == 0:
|
||||
self.label_cache[e] = e
|
||||
return e
|
||||
|
||||
self.label_cache[e] = res[0].o.value
|
||||
return self.label_cache[e]
|
||||
|
||||
def get_subgraph(self, query):
|
||||
|
||||
entities = self.get_entities(query)
|
||||
|
||||
subgraph = set()
|
||||
|
||||
if self.verbose:
|
||||
print("Get subgraph...", flush=True)
|
||||
|
||||
for e in entities:
|
||||
|
||||
res = self.triples_client.request(
|
||||
e, None, None,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.triples_client.request(
|
||||
None, e, None,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.triples_client.request(
|
||||
None, None, e,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
subgraph = list(subgraph)
|
||||
|
||||
subgraph = subgraph[0:self.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
print("Subgraph:", flush=True)
|
||||
for edge in subgraph:
|
||||
print(" ", str(edge), flush=True)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return subgraph
|
||||
|
||||
def get_labelgraph(self, query):
|
||||
|
||||
subgraph = self.get_subgraph(query)
|
||||
|
||||
sg2 = []
|
||||
|
||||
for edge in subgraph:
|
||||
|
||||
if edge[1] == LABEL:
|
||||
continue
|
||||
|
||||
s = self.maybe_label(edge[0])
|
||||
p = self.maybe_label(edge[1])
|
||||
o = self.maybe_label(edge[2])
|
||||
|
||||
sg2.append((s, p, o))
|
||||
|
||||
return sg2
|
||||
|
||||
def query(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
kg = self.get_labelgraph(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Invoke LLM...", flush=True)
|
||||
print(kg)
|
||||
print(query)
|
||||
|
||||
resp = self.lang.request_kg_prompt(query, kg)
|
||||
|
||||
if self.verbose:
|
||||
print("Done", flush=True)
|
||||
|
||||
return resp
|
||||
|
||||
3
trustgraph-flow/trustgraph/metering/__init__.py
Normal file
3
trustgraph-flow/trustgraph/metering/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . counter import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/metering/__main__.py
Executable file
7
trustgraph-flow/trustgraph/metering/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . counter import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
75
trustgraph-flow/trustgraph/metering/counter.py
Normal file
75
trustgraph-flow/trustgraph/metering/counter.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""
|
||||
Simple token counter for each LLM response.
|
||||
"""
|
||||
|
||||
from prometheus_client import Histogram, Info
|
||||
from . pricelist import price_list
|
||||
|
||||
from .. schema import TextCompletionResponse, Error
|
||||
from .. schema import text_completion_response_queue
|
||||
from .. log_level import LogLevel
|
||||
from .. base import Consumer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
|
||||
class Processor(Consumer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionResponse,
|
||||
}
|
||||
)
|
||||
|
||||
def get_prices(self, prices, modelname):
|
||||
for model in prices["price_list"]:
|
||||
if model["model_name"] == modelname:
|
||||
return model["input_price"], model["output_price"]
|
||||
return None, None # Return None if model is not found
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
modelname = v.model
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling response {id}...", flush=True)
|
||||
|
||||
num_in = v.in_token
|
||||
num_out = v.out_token
|
||||
|
||||
model_input_price, model_output_price = self.get_prices(price_list, modelname)
|
||||
|
||||
if model_input_price == None:
|
||||
cost_per_call = f"Model Not Found in Price list"
|
||||
else:
|
||||
cost_in = num_in * model_input_price
|
||||
cost_out = num_out * model_output_price
|
||||
cost_per_call = round(cost_in + cost_out, 6)
|
||||
|
||||
print(f"Input Tokens: {num_in}", flush=True)
|
||||
print(f"Output Tokens: {num_out}", flush=True)
|
||||
print(f"Cost for call: ${cost_per_call}", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
104
trustgraph-flow/trustgraph/metering/pricelist.py
Normal file
104
trustgraph-flow/trustgraph/metering/pricelist.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
price_list = {
|
||||
"price_list": [
|
||||
{
|
||||
"model_name": "mistral.mistral-large-2407-v1:0",
|
||||
"input_price": 0.000004,
|
||||
"output_price": 0.000012
|
||||
},
|
||||
{
|
||||
"model_name": "meta.llama3-1-405b-instruct-v1:0",
|
||||
"input_price": 0.00000532,
|
||||
"output_price": 0.000016
|
||||
},
|
||||
{
|
||||
"model_name": "mistral.mixtral-8x7b-instruct-v0:1",
|
||||
"input_price": 0.00000045,
|
||||
"output_price": 0.0000007
|
||||
},
|
||||
{
|
||||
"model_name": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"input_price": 0.00000099,
|
||||
"output_price": 0.00000099
|
||||
},
|
||||
{
|
||||
"model_name": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"input_price": 0.00000022,
|
||||
"output_price": 0.00000022
|
||||
},
|
||||
{
|
||||
"model_name": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"input_price": 0.00000025,
|
||||
"output_price": 0.00000125
|
||||
},
|
||||
{
|
||||
"model_name": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"input_price": 0.000003,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "cohere.command-r-plus-v1:0",
|
||||
"input_price": 0.0000030,
|
||||
"output_price": 0.0000150
|
||||
},
|
||||
{
|
||||
"model_name": "ollama",
|
||||
"input_price": 0,
|
||||
"output_price": 0
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-haiku-20240307",
|
||||
"input_price": 0.00000025,
|
||||
"output_price": 0.00000125
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-5-sonnet-20240620",
|
||||
"input_price": 0.000003,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-opus-20240229",
|
||||
"input_price": 0.000015,
|
||||
"output_price": 0.000075
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-sonnet-20240229",
|
||||
"input_price": 0.000003,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "command-r-08-202",
|
||||
"input_price": 0.0000025,
|
||||
"output_price": 0.000010
|
||||
},
|
||||
{
|
||||
"model_name": "c4ai-aya-23-8b",
|
||||
"input_price": 0,
|
||||
"output_price": 0
|
||||
},
|
||||
{
|
||||
"model_name": "llama.cpp",
|
||||
"input_price": 0,
|
||||
"output_price": 0
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o",
|
||||
"input_price": 0.000005,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o-2024-08-06",
|
||||
"input_price": 0.0000025,
|
||||
"output_price": 0.000010
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o-2024-05-13",
|
||||
"input_price": 0.000005,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"input_price": 0.00000015,
|
||||
"output_price": 0.0000006
|
||||
},
|
||||
]
|
||||
}
|
||||
0
trustgraph-flow/trustgraph/model/__init__.py
Normal file
0
trustgraph-flow/trustgraph/model/__init__.py
Normal file
0
trustgraph-flow/trustgraph/model/prompt/__init__.py
Normal file
0
trustgraph-flow/trustgraph/model/prompt/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/prompt/generic/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/prompt/generic/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
176
trustgraph-flow/trustgraph/model/prompt/generic/prompts.py
Normal file
176
trustgraph-flow/trustgraph/model/prompt/generic/prompts.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
|
||||
def to_relationships(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.
|
||||
|
||||
Read the provided text. You will model the text as an information network for a RDF knowledge graph in JSON.
|
||||
|
||||
Information Network Rules:
|
||||
- An information network has subjects connected by predicates to objects.
|
||||
- A subject is a named-entity or a conceptual topic.
|
||||
- One subject can have many predicates and objects.
|
||||
- An object is a property or attribute of a subject.
|
||||
- A subject can be connected by a predicate to another subject.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Obey the information network rules.
|
||||
- Do not return special characters.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of JSON objects with keys "subject", "predicate", "object", and "object-entity".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"subject": string, "predicate": string, "object": string, "object-entity": boolean}}]
|
||||
```
|
||||
|
||||
- The key "object-entity" is TRUE only if the "object" is a subject.
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_topics(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify topics and their definitions in JSON.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Do not respond with special characters.
|
||||
- Return only topics that are concepts and unique to the provided text.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of objects with keys "topic" and "definition".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"topic": string, "definition": string}}]
|
||||
```
|
||||
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_definitions(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify entities and their definitions in JSON.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Do not respond with special characters.
|
||||
- Return only entities that are named-entities such as: people, organizations, physical objects, locations, animals, products, commodotities, or substances.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of objects with keys "entity" and "definition".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"entity": string, "definition": string}}]
|
||||
```
|
||||
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_rows(schema, text):
|
||||
|
||||
field_schema = [
|
||||
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
|
||||
for f in schema.fields
|
||||
]
|
||||
|
||||
field_schema = "\n".join(field_schema)
|
||||
|
||||
schema = f"""Object name: {schema.name}
|
||||
Description: {schema.description}
|
||||
|
||||
Fields:
|
||||
{field_schema}"""
|
||||
|
||||
prompt = f"""<instructions>
|
||||
Study the following text and derive objects which match the schema provided.
|
||||
|
||||
You must output an array of JSON objects for each object you discover
|
||||
which matches the schema. For each object, output a JSON object whose fields
|
||||
carry the name field specified in the schema.
|
||||
</instructions>
|
||||
|
||||
<schema>
|
||||
{schema}
|
||||
</schema>
|
||||
|
||||
<text>
|
||||
{text}
|
||||
</text>
|
||||
|
||||
<requirements>
|
||||
You will respond only with raw JSON format data. Do not provide
|
||||
explanations. Do not add markdown formatting or headers or prefixes.
|
||||
</requirements>"""
|
||||
|
||||
return prompt
|
||||
|
||||
def get_cypher(kg):
|
||||
|
||||
sg2 = []
|
||||
|
||||
for f in kg:
|
||||
|
||||
print(f)
|
||||
|
||||
sg2.append(f"({f.s})-[{f.p}]->({f.o})")
|
||||
|
||||
print(sg2)
|
||||
|
||||
kg = "\n".join(sg2)
|
||||
kg = kg.replace("\\", "-")
|
||||
|
||||
return kg
|
||||
|
||||
def to_kg_query(query, kg):
|
||||
|
||||
cypher = get_cypher(kg)
|
||||
|
||||
prompt=f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
|
||||
|
||||
Here's the knowledge statements:
|
||||
{cypher}
|
||||
|
||||
Use only the provided knowledge statements to respond to the following:
|
||||
{query}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_document_query(query, documents):
|
||||
|
||||
documents = "\n\n".join(documents)
|
||||
|
||||
prompt=f"""Study the following context. Use only the information provided in the context in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
|
||||
|
||||
Here is the context:
|
||||
{documents}
|
||||
|
||||
Use only the provided knowledge statements to respond to the following:
|
||||
{query}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
473
trustgraph-flow/trustgraph/model/prompt/generic/service.py
Executable file
473
trustgraph-flow/trustgraph/model/prompt/generic/service.py
Executable file
|
|
@ -0,0 +1,473 @@
|
|||
"""
|
||||
Language service abstracts prompt engineering from LLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .... schema import Definition, Relationship, Triple
|
||||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... schema import prompt_request_queue, prompt_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships, to_topics
|
||||
from . prompts import to_kg_query, to_document_query, to_rows
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
tc_request_queue = params.get(
|
||||
"text_completion_request_queue", text_completion_request_queue
|
||||
)
|
||||
tc_response_queue = params.get(
|
||||
"text_completion_response_queue", text_completion_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": PromptRequest,
|
||||
"output_schema": PromptResponse,
|
||||
"text_completion_request_queue": tc_request_queue,
|
||||
"text_completion_response_queue": tc_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.llm = LlmClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=tc_request_queue,
|
||||
output_queue=tc_response_queue,
|
||||
pulsar_host = self.pulsar_host
|
||||
)
|
||||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
# If no delimiters, assume the entire output is JSON
|
||||
json_str = text.strip()
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
kind = v.kind
|
||||
|
||||
print(f"Handling kind {kind}...", flush=True)
|
||||
|
||||
if kind == "extract-definitions":
|
||||
|
||||
self.handle_extract_definitions(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-topics":
|
||||
|
||||
self.handle_extract_topics(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-relationships":
|
||||
|
||||
self.handle_extract_relationships(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-rows":
|
||||
|
||||
self.handle_extract_rows(id, v)
|
||||
return
|
||||
|
||||
elif kind == "kg-prompt":
|
||||
|
||||
self.handle_kg_prompt(id, v)
|
||||
return
|
||||
|
||||
elif kind == "document-prompt":
|
||||
|
||||
self.handle_document_prompt(id, v)
|
||||
return
|
||||
|
||||
else:
|
||||
|
||||
print("Invalid kind.", flush=True)
|
||||
return
|
||||
|
||||
def handle_extract_definitions(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_definitions(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["entity"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Definition(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(definitions=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_topics(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_topics(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["topic"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Topic(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(topics=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_relationships(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_relationships(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
|
||||
s = defn["subject"]
|
||||
p = defn["predicate"]
|
||||
o = defn["object"]
|
||||
o_entity = defn["object-entity"]
|
||||
|
||||
if s == "": continue
|
||||
if s is None: continue
|
||||
|
||||
if p == "": continue
|
||||
if p is None: continue
|
||||
|
||||
if o == "": continue
|
||||
if o is None: continue
|
||||
|
||||
if o_entity == "" or o_entity is None:
|
||||
o_entity = False
|
||||
|
||||
output.append(
|
||||
Relationship(
|
||||
s = s,
|
||||
p = p,
|
||||
o = o,
|
||||
o_entity = o_entity,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("relationship fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(relationships=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_rows(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
fields = v.row_schema.fields
|
||||
|
||||
prompt = to_rows(v.row_schema, v.chunk)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
objs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
objs = []
|
||||
|
||||
output = []
|
||||
|
||||
for obj in objs:
|
||||
|
||||
try:
|
||||
|
||||
row = {}
|
||||
|
||||
for f in fields:
|
||||
|
||||
if f.name not in obj:
|
||||
print(f"Object ignored, missing field {f.name}")
|
||||
row = {}
|
||||
break
|
||||
|
||||
row[f.name] = obj[f.name]
|
||||
|
||||
if row == {}:
|
||||
continue
|
||||
|
||||
output.append(row)
|
||||
|
||||
except Exception as e:
|
||||
print("row fields missing, ignored", flush=True)
|
||||
|
||||
for row in output:
|
||||
print(row)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(rows=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_kg_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_kg_query(v.query, v.kg)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_document_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_document_query(v.query, v.documents)
|
||||
|
||||
print("prompt")
|
||||
print(prompt)
|
||||
|
||||
print("Call LLM...")
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-request-queue',
|
||||
default=text_completion_request_queue,
|
||||
help=f'Text completion request queue (default: {text_completion_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-response-queue',
|
||||
default=text_completion_response_queue,
|
||||
help=f'Text completion response queue (default: {text_completion_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/prompt/template/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/prompt/template/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
47
trustgraph-flow/trustgraph/model/prompt/template/prompts.py
Normal file
47
trustgraph-flow/trustgraph/model/prompt/template/prompts.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
|
||||
def to_relationships(template, text):
|
||||
return template.format(text=text)
|
||||
|
||||
def to_definitions(template, text):
|
||||
return template.format(text=text)
|
||||
|
||||
def to_topics(template, text):
|
||||
return template.format(text=text)
|
||||
|
||||
def to_rows(template, schema, text):
|
||||
|
||||
field_schema = [
|
||||
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
|
||||
for f in schema.fields
|
||||
]
|
||||
|
||||
field_schema = "\n".join(field_schema)
|
||||
|
||||
return template.format(schema=schema, text=text)
|
||||
|
||||
schema = f"""Object name: {schema.name}
|
||||
Description: {schema.description}
|
||||
|
||||
Fields:
|
||||
{schema}"""
|
||||
|
||||
prompt = f""""""
|
||||
|
||||
return prompt
|
||||
|
||||
def get_cypher(kg):
|
||||
sg2 = []
|
||||
for f in kg:
|
||||
sg2.append(f"({f.s})-[{f.p}]->({f.o})")
|
||||
kg = "\n".join(sg2)
|
||||
kg = kg.replace("\\", "-")
|
||||
return kg
|
||||
|
||||
def to_kg_query(template, query, kg):
|
||||
cypher = get_cypher(kg)
|
||||
return template.format(query=query, graph=cypher)
|
||||
|
||||
def to_document_query(template, query, docs):
|
||||
docs = "\n\n".join(docs)
|
||||
return template.format(query=query, documents=docs)
|
||||
|
||||
523
trustgraph-flow/trustgraph/model/prompt/template/service.py
Executable file
523
trustgraph-flow/trustgraph/model/prompt/template/service.py
Executable file
|
|
@ -0,0 +1,523 @@
|
|||
|
||||
"""
|
||||
Language service abstracts prompt engineering from LLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .... schema import Definition, Relationship, Triple
|
||||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... schema import prompt_request_queue, prompt_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships, to_rows
|
||||
from . prompts import to_kg_query, to_document_query, to_topics
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
tc_request_queue = params.get(
|
||||
"text_completion_request_queue", text_completion_request_queue
|
||||
)
|
||||
tc_response_queue = params.get(
|
||||
"text_completion_response_queue", text_completion_response_queue
|
||||
)
|
||||
definition_template = params.get("definition_template")
|
||||
relationship_template = params.get("relationship_template")
|
||||
topic_template = params.get("topic_template")
|
||||
rows_template = params.get("rows_template")
|
||||
knowledge_query_template = params.get("knowledge_query_template")
|
||||
document_query_template = params.get("document_query_template")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": PromptRequest,
|
||||
"output_schema": PromptResponse,
|
||||
"text_completion_request_queue": tc_request_queue,
|
||||
"text_completion_response_queue": tc_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.llm = LlmClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=tc_request_queue,
|
||||
output_queue=tc_response_queue,
|
||||
pulsar_host = self.pulsar_host
|
||||
)
|
||||
|
||||
self.definition_template = definition_template
|
||||
self.topic_template = topic_template
|
||||
self.relationship_template = relationship_template
|
||||
self.rows_template = rows_template
|
||||
self.knowledge_query_template = knowledge_query_template
|
||||
self.document_query_template = document_query_template
|
||||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
# If no delimiters, assume the entire output is JSON
|
||||
json_str = text.strip()
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
kind = v.kind
|
||||
|
||||
print(f"Handling kind {kind}...", flush=True)
|
||||
|
||||
if kind == "extract-definitions":
|
||||
|
||||
self.handle_extract_definitions(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-topics":
|
||||
|
||||
self.handle_extract_topics(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-relationships":
|
||||
|
||||
self.handle_extract_relationships(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-rows":
|
||||
|
||||
self.handle_extract_rows(id, v)
|
||||
return
|
||||
|
||||
elif kind == "kg-prompt":
|
||||
|
||||
self.handle_kg_prompt(id, v)
|
||||
return
|
||||
|
||||
elif kind == "document-prompt":
|
||||
|
||||
self.handle_document_prompt(id, v)
|
||||
return
|
||||
|
||||
else:
|
||||
|
||||
print("Invalid kind.", flush=True)
|
||||
return
|
||||
|
||||
def handle_extract_definitions(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_definitions(self.definition_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["entity"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Definition(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(definitions=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_topics(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_topics(self.topic_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["topic"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Topic(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(topics=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
|
||||
def handle_extract_relationships(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_relationships(self.relationship_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
|
||||
s = defn["subject"]
|
||||
p = defn["predicate"]
|
||||
o = defn["object"]
|
||||
o_entity = defn["object-entity"]
|
||||
|
||||
if s == "": continue
|
||||
if s is None: continue
|
||||
|
||||
if p == "": continue
|
||||
if p is None: continue
|
||||
|
||||
if o == "": continue
|
||||
if o is None: continue
|
||||
|
||||
if o_entity == "" or o_entity is None:
|
||||
o_entity = False
|
||||
|
||||
output.append(
|
||||
Relationship(
|
||||
s = s,
|
||||
p = p,
|
||||
o = o,
|
||||
o_entity = o_entity,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("relationship fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(relationships=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_rows(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
fields = v.row_schema.fields
|
||||
|
||||
prompt = to_rows(self.rows_template, v.row_schema, v.chunk)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
objs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
objs = []
|
||||
|
||||
output = []
|
||||
|
||||
for obj in objs:
|
||||
|
||||
try:
|
||||
|
||||
row = {}
|
||||
|
||||
for f in fields:
|
||||
|
||||
if f.name not in obj:
|
||||
print(f"Object ignored, missing field {f.name}")
|
||||
row = {}
|
||||
break
|
||||
|
||||
row[f.name] = obj[f.name]
|
||||
|
||||
if row == {}:
|
||||
continue
|
||||
|
||||
output.append(row)
|
||||
|
||||
except Exception as e:
|
||||
print("row fields missing, ignored", flush=True)
|
||||
|
||||
for row in output:
|
||||
print(row)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(rows=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_kg_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_kg_query(self.knowledge_query_template, v.query, v.kg)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_document_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_document_query(
|
||||
self.document_query_template, v.query, v.documents
|
||||
)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-request-queue',
|
||||
default=text_completion_request_queue,
|
||||
help=f'Text completion request queue (default: {text_completion_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-response-queue',
|
||||
default=text_completion_response_queue,
|
||||
help=f'Text completion response queue (default: {text_completion_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--definition-template',
|
||||
required=True,
|
||||
help=f'Definition extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--topic-template',
|
||||
required=True,
|
||||
help=f'Topic extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rows-template',
|
||||
required=True,
|
||||
help=f'Rows extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--relationship-template',
|
||||
required=True,
|
||||
help=f'Relationship extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--knowledge-query-template',
|
||||
required=True,
|
||||
help=f'Knowledge query template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--document-query-template',
|
||||
required=True,
|
||||
help=f'Document query template',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/azure/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/azure/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
226
trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
Executable file
226
trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
Executable file
|
|
@ -0,0 +1,226 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using the Azure
|
||||
serverless endpoint service. Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_temperature = 0.0
|
||||
default_max_output = 4192
|
||||
default_model = "AzureAI"
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
endpoint = params.get("endpoint")
|
||||
token = params.get("token")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
model = default_model
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.token = token
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
|
||||
def build_prompt(self, system, content):
|
||||
|
||||
data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system", "content": system
|
||||
},
|
||||
{
|
||||
"role": "user", "content": content
|
||||
}
|
||||
],
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 1
|
||||
}
|
||||
|
||||
body = json.dumps(data)
|
||||
|
||||
return body
|
||||
|
||||
def call_llm(self, body):
|
||||
|
||||
url = self.endpoint
|
||||
|
||||
# Replace this with the primary/secondary key, AMLToken, or
|
||||
# Microsoft Entra ID token for the endpoint
|
||||
api_key = self.token
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {api_key}'
|
||||
}
|
||||
|
||||
resp = requests.post(url, data=body, headers=headers)
|
||||
|
||||
if resp.status_code == 429:
|
||||
raise TooManyRequests()
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError("LLM failure")
|
||||
|
||||
result = resp.json()
|
||||
|
||||
return result
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
try:
|
||||
|
||||
prompt = self.build_prompt(
|
||||
"You are a helpful chatbot",
|
||||
v.prompt
|
||||
)
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
response = self.call_llm(prompt)
|
||||
|
||||
resp = response['choices'][0]['message']['content']
|
||||
inputtokens = response['usage']['prompt_tokens']
|
||||
outputtokens = response['usage']['completion_tokens']
|
||||
|
||||
print(resp, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
except TooManyRequests:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-e', '--endpoint',
|
||||
help=f'LLM model endpoint'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--token',
|
||||
help=f'LLM model token'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/claude/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/claude/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
199
trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
Executable file

@@ -0,0 +1,199 @@

"""
Simple LLM service, performs text prompt completion using Claude.
Input is prompt, output is response.
"""

import anthropic
from prometheus_client import Histogram

from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests

module = ".".join(__name__.split(".")[1:-1])

default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
default_subscriber = module
default_model = 'claude-3-5-sonnet-20240620'
default_temperature = 0.0
default_max_output = 8192

class Processor(ConsumerProducer):

    def __init__(self, **params):

        input_queue = params.get("input_queue", default_input_queue)
        output_queue = params.get("output_queue", default_output_queue)
        subscriber = params.get("subscriber", default_subscriber)
        model = params.get("model", default_model)
        api_key = params.get("api_key")
        temperature = params.get("temperature", default_temperature)
        max_output = params.get("max_output", default_max_output)

        super(Processor, self).__init__(
            **params | {
                "input_queue": input_queue,
                "output_queue": output_queue,
                "subscriber": subscriber,
                "input_schema": TextCompletionRequest,
                "output_schema": TextCompletionResponse,
                "model": model,
                "temperature": temperature,
                "max_output": max_output,
            }
        )

        if not hasattr(__class__, "text_completion_metric"):
            __class__.text_completion_metric = Histogram(
                'text_completion_duration',
                'Text completion duration (seconds)',
                buckets=[
                    0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
                    8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
                    17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
                    30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
                    120.0
                ]
            )

        self.model = model
        self.claude = anthropic.Anthropic(api_key=api_key)
        self.temperature = temperature
        self.max_output = max_output

        print("Initialised", flush=True)

    def handle(self, msg):

        v = msg.value()

        # Sender-produced ID

        id = msg.properties()["id"]

        print(f"Handling prompt {id}...", flush=True)

        prompt = v.prompt

        try:

            # FIXME: Rate limits?

            with __class__.text_completion_metric.time():

                response = message = self.claude.messages.create(
                    model=self.model,
                    max_tokens=self.max_output,
                    temperature=self.temperature,
                    system="You are a helpful chatbot.",
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": prompt
                                }
                            ]
                        }
                    ]
                )

            resp = response.content[0].text
            inputtokens = response.usage.input_tokens
            outputtokens = response.usage.output_tokens
            print(resp, flush=True)
            print(f"Input Tokens: {inputtokens}", flush=True)
            print(f"Output Tokens: {outputtokens}", flush=True)

            print("Send response...", flush=True)
            r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
            self.send(r, properties={"id": id})

            print("Done.", flush=True)

        # FIXME: Wrong exception, don't know what this LLM throws
        # for a rate limit
        except TooManyRequests as e:

            print("Send rate limit response...", flush=True)

            r = TextCompletionResponse(
                error=Error(
                    type="rate-limit",
                    message=str(e),
                ),
                response=None,
                in_token=None,
                out_token=None,
                model=None,
            )

            self.producer.send(r, properties={"id": id})

            self.consumer.acknowledge(msg)

        except Exception as e:

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            r = TextCompletionResponse(
                error=Error(
                    type="llm-error",
                    message=str(e),
                ),
                response=None,
                in_token=None,
                out_token=None,
                model=None,
            )

            self.producer.send(r, properties={"id": id})

            self.consumer.acknowledge(msg)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-m', '--model',
            default="claude-3-5-sonnet-20240620",
            help='LLM model (default: claude-3-5-sonnet-20240620)'
        )

        parser.add_argument(
            '-k', '--api-key',
            help='Claude API key'
        )

        parser.add_argument(
            '-t', '--temperature',
            type=float,
            default=default_temperature,
            help=f'LLM temperature parameter (default: {default_temperature})'
        )

        parser.add_argument(
            '-x', '--max-output',
            type=int,
            default=default_max_output,
            help=f'LLM max output tokens (default: {default_max_output})'
        )

def run():

    Processor.start(module, __doc__)

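
For reference, a minimal standalone sketch of the Anthropic call pattern the handler above wraps, stripped of the Pulsar scaffolding. The API key, model name and prompt are placeholders, and this sketch is not part of the commit.

import anthropic

client = anthropic.Anthropic(api_key="sk-ant-...")  # placeholder key

message = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=8192,
    temperature=0.0,
    system="You are a helpful chatbot.",
    messages=[
        {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
    ],
)

# The same fields the processor reads back into its response message.
print(message.content[0].text)
print(message.usage.input_tokens, message.usage.output_tokens)
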
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/cohere/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/cohere/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
179
trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
Executable file
179
trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
Executable file
|
|
@ -0,0 +1,179 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using Cohere.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import cohere
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'c4ai-aya-23-8b'
|
||||
default_temperature = 0.0
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.cohere = cohere.Client(api_key=api_key)
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
prompt = v.prompt
|
||||
|
||||
try:
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
|
||||
output = self.cohere.chat(
|
||||
model=self.model,
|
||||
message=prompt,
|
||||
preamble = "You are a helpful AI-assistant.",
|
||||
temperature=self.temperature,
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
resp = output.text
|
||||
inputtokens = int(output.meta.billed_units.input_tokens)
|
||||
outputtokens = int(output.meta.billed_units.output_tokens)
|
||||
|
||||
print(resp, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
|
||||
self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
# FIXME: Wrong exception, don't know what this LLM throws
|
||||
# for a rate limit
|
||||
except TooManyRequests as e:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default="c4ai-aya-23-8b",
|
||||
help=f'Cohere model (default: c4ai-aya-23-8b)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
help=f'Cohere API key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
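
A minimal sketch of the cohere.Client.chat call used above, outside the service wrapper; the API key and prompt are placeholders, not part of the commit.

import cohere

co = cohere.Client(api_key="...")  # placeholder key

output = co.chat(
    model="c4ai-aya-23-8b",
    message="Hello",
    preamble="You are a helpful AI-assistant.",
    temperature=0.0,
    chat_history=[],
    prompt_truncation="auto",
    connectors=[],
)

print(output.text)
print(output.meta.billed_units.input_tokens, output.meta.billed_units.output_tokens)
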
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/llamafile/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/llamafile/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
209
trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
Executable file
209
trustgraph-flow/trustgraph/model/text_completion/llamafile/llm.py
Executable file
|
|
@ -0,0 +1,209 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using OpenAI.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
from openai import OpenAI
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'LLaMA_CPP'
|
||||
default_llamafile = 'http://localhost:8080/v1'
|
||||
default_temperature = 0.0
|
||||
default_max_output = 4096
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
llamafile = params.get("llamafile", default_llamafile)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
"llamafile" : llamafile,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.llamafile=llamafile
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.openai = OpenAI(
|
||||
base_url=self.llamafile,
|
||||
api_key = "sk-no-key-required",
|
||||
)
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
prompt = v.prompt
|
||||
|
||||
try:
|
||||
|
||||
# FIXME: Rate limits
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
#temperature=self.temperature,
|
||||
#max_tokens=self.max_output,
|
||||
#top_p=1,
|
||||
#frequency_penalty=0,
|
||||
#presence_penalty=0,
|
||||
#response_format={
|
||||
# "type": "text"
|
||||
#}
|
||||
)
|
||||
|
||||
inputtokens = resp.usage.prompt_tokens
|
||||
outputtokens = resp.usage.completion_tokens
|
||||
|
||||
print(resp.choices[0].message.content, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TextCompletionResponse(
|
||||
response=resp.choices[0].message.content,
|
||||
error=None,
|
||||
in_token=inputtokens,
|
||||
out_token=outputtokens,
|
||||
model="llama.cpp"
|
||||
)
|
||||
self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
# FIXME: Wrong exception, don't know what this LLM throws
|
||||
# for a rate limit
|
||||
except TooManyRequests as e:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default=default_model,
|
||||
help=f'LLM model (default: LLaMA_CPP)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-r', '--llamafile',
|
||||
default=default_llamafile,
|
||||
help=f'llamafile URL (default: {default_llamafile})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
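
The llamafile variant reuses the OpenAI client against a local llamafile endpoint. A minimal sketch of that pattern, with a placeholder endpoint and prompt; not part of the commit.

from openai import OpenAI

# llamafile exposes an OpenAI-compatible API; no real key is required.
client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")

resp = client.chat.completions.create(
    model="LLaMA_CPP",
    messages=[{"role": "user", "content": "Hello"}],
)

print(resp.choices[0].message.content)
print(resp.usage.prompt_tokens, resp.usage.completion_tokens)
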
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/ollama/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/ollama/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
168
trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
Executable file
168
trustgraph-flow/trustgraph/model/text_completion/ollama/llm.py
Executable file
|
|
@ -0,0 +1,168 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using an Ollama service.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
from ollama import Client
|
||||
from prometheus_client import Histogram, Info
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'gemma2'
|
||||
default_ollama = 'http://localhost:11434'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
ollama = params.get("ollama", default_ollama)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"model": model,
|
||||
"ollama": ollama,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "model_metric"):
|
||||
__class__.model_metric = Info(
|
||||
'model', 'Model information'
|
||||
)
|
||||
|
||||
__class__.model_metric.info({
|
||||
"model": model,
|
||||
"ollama": ollama,
|
||||
})
|
||||
|
||||
self.model = model
|
||||
self.llm = Client(host=ollama)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
prompt = v.prompt
|
||||
|
||||
try:
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
response = self.llm.generate(self.model, prompt)
|
||||
|
||||
response_text = response['response']
|
||||
print("Send response...", flush=True)
|
||||
print(response_text, flush=True)
|
||||
|
||||
inputtokens = int(response['prompt_eval_count'])
|
||||
outputtokens = int(response['eval_count'])
|
||||
|
||||
r = TextCompletionResponse(response=response_text, error=None, in_token=inputtokens, out_token=outputtokens, model="ollama")
|
||||
|
||||
self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
# FIXME: Wrong exception, don't know what this LLM throws
|
||||
# for a rate limit
|
||||
except TooManyRequests as e:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default="gemma2",
|
||||
help=f'LLM model (default: gemma2)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-r', '--ollama',
|
||||
default=default_ollama,
|
||||
help=f'ollama (default: {default_ollama})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
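
A minimal sketch of the Ollama client call used above; the host, model and prompt are placeholders, not part of the commit.

from ollama import Client

client = Client(host="http://localhost:11434")

response = client.generate("gemma2", "Hello")

# Same keys the processor reads for the reply text and token accounting.
print(response["response"])
print(response["prompt_eval_count"], response["eval_count"])
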
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/openai/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/openai/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
209
trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
Executable file
209
trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
Executable file
|
|
@ -0,0 +1,209 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using OpenAI.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
from openai import OpenAI
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'gpt-3.5-turbo'
|
||||
default_temperature = 0.0
|
||||
default_max_output = 4096
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.openai = OpenAI(api_key=api_key)
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
prompt = v.prompt
|
||||
|
||||
try:
|
||||
|
||||
# FIXME: Rate limits
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
|
||||
resp = self.openai.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_output,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
response_format={
|
||||
"type": "text"
|
||||
}
|
||||
)
|
||||
|
||||
inputtokens = resp.usage.prompt_tokens
|
||||
outputtokens = resp.usage.completion_tokens
|
||||
print(resp.choices[0].message.content, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TextCompletionResponse(
|
||||
response=resp.choices[0].message.content,
|
||||
error=None,
|
||||
in_token=inputtokens,
|
||||
out_token=outputtokens,
|
||||
model=self.model
|
||||
)
|
||||
self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
# FIXME: Wrong exception, don't know what this LLM throws
|
||||
# for a rate limit
|
||||
except TooManyRequests as e:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default="gpt-3.5-turbo",
|
||||
help=f'LLM model (default: GPT-3.5-Turbo)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
help=f'OpenAI API key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
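
The hosted OpenAI variant differs from the llamafile one mainly in passing temperature and max_tokens and in reporting the real model name back. A minimal sketch, with a placeholder key; not part of the commit.

from openai import OpenAI

client = OpenAI(api_key="sk-...")  # placeholder key

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
    temperature=0.0,
    max_tokens=4096,
)

print(resp.choices[0].message.content)
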
3
trustgraph-flow/trustgraph/processing/__init__.py
Normal file
3
trustgraph-flow/trustgraph/processing/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . processing import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/processing/__main__.py
Executable file
7
trustgraph-flow/trustgraph/processing/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . processing import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
171
trustgraph-flow/trustgraph/processing/processing.py
Normal file
171
trustgraph-flow/trustgraph/processing/processing.py
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
|
||||
import argparse
|
||||
import time
|
||||
import os
|
||||
from yaml import load, Loader
|
||||
import json
|
||||
import multiprocessing
|
||||
from multiprocessing.connection import wait
|
||||
|
||||
import importlib
|
||||
|
||||
from .. log_level import LogLevel
|
||||
|
||||
def fn(module_name, class_name, params, w):
|
||||
|
||||
print(f"Starting {module_name}...")
|
||||
|
||||
if "log_level" in params:
|
||||
params["log_level"] = LogLevel(params["log_level"])
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
|
||||
print(f"Starting {class_name} using {module_name}...")
|
||||
|
||||
module = importlib.import_module(module_name)
|
||||
class_object = getattr(module, class_name)
|
||||
|
||||
processor = class_object(**params)
|
||||
|
||||
processor.run()
|
||||
print(f"{module_name} stopped.")
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e)
|
||||
|
||||
print("Restarting in 10...")
|
||||
|
||||
time.sleep(10)
|
||||
|
||||
print("Closing")
|
||||
w.close()
|
||||
|
||||
class Processing:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host,
|
||||
log_level,
|
||||
file,
|
||||
):
|
||||
self.pulsar_host = pulsar_host
|
||||
self.log_level = log_level
|
||||
self.file = file
|
||||
|
||||
self.defs = load(open(file, "r"), Loader=Loader)
|
||||
|
||||
def run(self):
|
||||
|
||||
procs = []
|
||||
readers = []
|
||||
services = []
|
||||
|
||||
for service in self.defs["services"]:
|
||||
|
||||
sdef = self.defs["services"][service]
|
||||
|
||||
params = {
|
||||
"pulsar_host": self.pulsar_host,
|
||||
"log_level": str(self.log_level),
|
||||
}
|
||||
|
||||
if "parameters" in sdef:
|
||||
for par in sdef["parameters"]:
|
||||
params[par] = sdef["parameters"][par]
|
||||
|
||||
module_name = sdef["module"]
|
||||
class_name = sdef.get("class", "Processor")
|
||||
|
||||
r, w = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
process = multiprocessing.Process(
|
||||
target=fn,
|
||||
args=(module_name, class_name, params, w)
|
||||
)
|
||||
process.start()
|
||||
|
||||
w.close()
|
||||
|
||||
procs.append(process)
|
||||
services.append(service)
|
||||
readers.append(r)
|
||||
|
||||
wait_for = len(readers)
|
||||
|
||||
while wait_for > 0:
|
||||
|
||||
ret = wait(readers)
|
||||
|
||||
for r in ret:
|
||||
|
||||
try:
|
||||
msg = r.recv()
|
||||
except EOFError:
|
||||
readers.remove(r)
|
||||
wait_for -= 1
|
||||
|
||||
print("All processes exited")
|
||||
|
||||
for p in procs:
|
||||
p.join()
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
|
||||
def run():
|
||||
|
||||
# Seems not to work.
|
||||
# multiprocessing.set_start_method('spawn')
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='run-processing',
|
||||
description=__doc__,
|
||||
)
|
||||
|
||||
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
|
||||
|
||||
parser.add_argument(
|
||||
'-p', '--pulsar-host',
|
||||
default=default_pulsar_host,
|
||||
help=f'Pulsar host (default: {default_pulsar_host})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-l', '--log-level',
|
||||
type=LogLevel,
|
||||
default=LogLevel.INFO,
|
||||
choices=list(LogLevel),
|
||||
help='Log level (default: info)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-f', '--file',
|
||||
default="processing.yaml",
|
||||
help=f'Processing definition file (default: processing.yaml)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
p = Processing(
|
||||
pulsar_host=args.pulsar_host,
|
||||
file=args.file,
|
||||
log_level=args.log_level,
|
||||
)
|
||||
|
||||
p.run()
|
||||
|
||||
print("Finished.")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print("Exception:", e, flush=True)
|
||||
print("Will retry...", flush=True)
|
||||
|
||||
time.sleep(10)
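
The processing definition file read above is a YAML document with a top-level services map; each entry names a module, an optional class (defaulting to Processor) and optional parameters. A hypothetical example of that structure, shown here as the equivalent Python dict; the service names and parameters are illustrative, not part of the commit.

defs = {
    "services": {
        "chunker": {
            "module": "trustgraph.chunking.recursive.chunker",
            "class": "Processor",
            "parameters": {"chunk_size": 2000, "chunk_overlap": 100},
        },
        "llm": {
            "module": "trustgraph.model.text_completion.ollama.llm",
            "parameters": {"model": "gemma2"},
        },
    },
}
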
0
trustgraph-flow/trustgraph/query/__init__.py
Normal file
0
trustgraph-flow/trustgraph/query/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/query/doc_embeddings/milvus/__main__.py
Executable file
7
trustgraph-flow/trustgraph/query/doc_embeddings/milvus/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
106
trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py
Executable file
106
trustgraph-flow/trustgraph/query/doc_embeddings/milvus/service.py
Executable file
|
|
@ -0,0 +1,106 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks
|
||||
"""
|
||||
|
||||
from .... direct.milvus_doc_embeddings import DocVectors
|
||||
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import document_embeddings_request_queue
|
||||
from .... schema import document_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_embeddings_request_queue
|
||||
default_output_queue = document_embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": DocumentEmbeddingsRequest,
|
||||
"output_schema": DocumentEmbeddingsResponse,
|
||||
"store_uri": store_uri,
|
||||
}
|
||||
)
|
||||
|
||||
self.vecstore = DocVectors(store_uri)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
chunks = []
|
||||
|
||||
for vec in v.vectors:
|
||||
|
||||
resp = self.vecstore.search(vec, limit=v.limit)
|
||||
|
||||
for r in resp:
|
||||
chunk = r["entity"]["doc"]
|
||||
chunk = chunk.encode("utf-8")
|
||||
chunks.append(chunk)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = DocumentEmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
documents=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
default=default_store_uri,
|
||||
help=f'Milvus store URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/__main__.py
Executable file
7
trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
117
trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
Executable file
117
trustgraph-flow/trustgraph/query/doc_embeddings/qdrant/service.py
Executable file
|
|
@ -0,0 +1,117 @@
|
|||
|
||||
"""
|
||||
Document embeddings query service. Input is vector, output is an array
|
||||
of chunks
|
||||
"""
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import PointStruct
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
import uuid
|
||||
|
||||
from .... schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import document_embeddings_request_queue
|
||||
from .... schema import document_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_embeddings_request_queue
|
||||
default_output_queue = document_embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": DocumentEmbeddingsRequest,
|
||||
"output_schema": DocumentEmbeddingsResponse,
|
||||
"store_uri": store_uri,
|
||||
}
|
||||
)
|
||||
|
||||
self.client = QdrantClient(url=store_uri)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
chunks = []
|
||||
|
||||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = "doc_" + str(dim)
|
||||
|
||||
search_result = self.client.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
limit=v.limit,
|
||||
with_payload=True,
|
||||
).points
|
||||
|
||||
for r in search_result:
|
||||
ent = r.payload["doc"]
|
||||
chunks.append(ent)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = DocumentEmbeddingsResponse(documents=chunks, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = DocumentEmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
documents=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
default=default_store_uri,
|
||||
help=f'Qdrant store URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
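
A minimal sketch of the Qdrant lookup performed above, outside the service wrapper; the URL, vector and limit are placeholders, not part of the commit. Note the collection name is derived from the vector dimension, as in the service.

from qdrant_client import QdrantClient

client = QdrantClient(url="http://localhost:6333")

vec = [0.1] * 384                    # placeholder embedding
collection = "doc_" + str(len(vec))  # collection is named by dimension

points = client.query_points(
    collection_name=collection,
    query=vec,
    limit=10,
    with_payload=True,
).points

chunks = [p.payload["doc"] for p in points]
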
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/query/graph_embeddings/milvus/__main__.py
Executable file
7
trustgraph-flow/trustgraph/query/graph_embeddings/milvus/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
121
trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py
Executable file
121
trustgraph-flow/trustgraph/query/graph_embeddings/milvus/service.py
Executable file
|
|
@ -0,0 +1,121 @@
|
|||
|
||||
"""
|
||||
Graph embeddings query service. Input is vector, output is list of
|
||||
entities
|
||||
"""
|
||||
|
||||
from .... direct.milvus_graph_embeddings import EntityVectors
|
||||
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import graph_embeddings_request_queue
|
||||
from .... schema import graph_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = graph_embeddings_request_queue
|
||||
default_output_queue = graph_embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:19530'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": GraphEmbeddingsRequest,
|
||||
"output_schema": GraphEmbeddingsResponse,
|
||||
"store_uri": store_uri,
|
||||
}
|
||||
)
|
||||
|
||||
self.vecstore = EntityVectors(store_uri)
|
||||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
entities = set()
|
||||
|
||||
for vec in v.vectors:
|
||||
|
||||
resp = self.vecstore.search(vec, limit=v.limit)
|
||||
|
||||
for r in resp:
|
||||
ent = r["entity"]["entity"]
|
||||
entities.add(ent)
|
||||
|
||||
# Convert set to list
|
||||
entities = list(entities)
|
||||
|
||||
ents2 = []
|
||||
|
||||
for ent in entities:
|
||||
ents2.append(self.create_value(ent))
|
||||
|
||||
entities = ents2
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = GraphEmbeddingsResponse(entities=entities, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = GraphEmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
entities=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
default=default_store_uri,
|
||||
help=f'Milvus store URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/__main__.py
Executable file
7
trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
133
trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
Executable file
133
trustgraph-flow/trustgraph/query/graph_embeddings/qdrant/service.py
Executable file
|
|
@ -0,0 +1,133 @@
|
|||
|
||||
"""
|
||||
Graph embeddings query service. Input is vector, output is list of
|
||||
entities
|
||||
"""
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import PointStruct
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
import uuid
|
||||
|
||||
from .... schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from .... schema import Error, Value
|
||||
from .... schema import graph_embeddings_request_queue
|
||||
from .... schema import graph_embeddings_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = graph_embeddings_request_queue
|
||||
default_output_queue = graph_embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_store_uri = 'http://localhost:6333'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
store_uri = params.get("store_uri", default_store_uri)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": GraphEmbeddingsRequest,
|
||||
"output_schema": GraphEmbeddingsResponse,
|
||||
"store_uri": store_uri,
|
||||
}
|
||||
)
|
||||
|
||||
self.client = QdrantClient(url=store_uri)
|
||||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
entities = set()
|
||||
|
||||
for vec in v.vectors:
|
||||
|
||||
dim = len(vec)
|
||||
collection = "triples_" + str(dim)
|
||||
|
||||
search_result = self.client.query_points(
|
||||
collection_name=collection,
|
||||
query=vec,
|
||||
limit=v.limit,
|
||||
with_payload=True,
|
||||
).points
|
||||
|
||||
for r in search_result:
|
||||
ent = r.payload["entity"]
|
||||
entities.add(ent)
|
||||
|
||||
# Convert set to list
|
||||
entities = list(entities)
|
||||
|
||||
ents2 = []
|
||||
|
||||
for ent in entities:
|
||||
ents2.append(self.create_value(ent))
|
||||
|
||||
entities = ents2
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = GraphEmbeddingsResponse(entities=entities, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = GraphEmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
entities=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--store-uri',
|
||||
default=default_store_uri,
|
||||
help=f'Qdrant store URI (default: {default_store_uri})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
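
Both graph-embeddings query services normalise results through create_value, tagging anything with an http or https scheme as a URI and everything else as a literal. A tiny sketch of that rule, with plain dicts standing in for the Value schema objects; not part of the commit.

def create_value(ent):
    # Mirrors the services above: URIs are flagged, other strings are literals.
    if ent.startswith("http://") or ent.startswith("https://"):
        return {"value": ent, "is_uri": True}
    return {"value": ent, "is_uri": False}

assert create_value("http://example.org/Thing")["is_uri"] is True
assert create_value("cat lover")["is_uri"] is False
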
0
trustgraph-flow/trustgraph/query/triples/__init__.py
Normal file
0
trustgraph-flow/trustgraph/query/triples/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/query/triples/cassandra/__main__.py
Executable file
7
trustgraph-flow/trustgraph/query/triples/cassandra/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
173
trustgraph-flow/trustgraph/query/triples/cassandra/service.py
Executable file
173
trustgraph-flow/trustgraph/query/triples/cassandra/service.py
Executable file
|
|
@ -0,0 +1,173 @@
|
|||
|
||||
"""
|
||||
Triples query service. Input is a (s, p, o) triple, some values may be
|
||||
null. Output is a list of triples.
|
||||
"""
|
||||
|
||||
from .... direct.cassandra import TrustGraph
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Value, Triple
|
||||
from .... schema import triples_request_queue
|
||||
from .... schema import triples_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_request_queue
|
||||
default_output_queue = triples_response_queue
|
||||
default_subscriber = module
|
||||
default_graph_host='localhost'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TriplesQueryRequest,
|
||||
"output_schema": TriplesQueryResponse,
|
||||
"graph_host": graph_host,
|
||||
}
|
||||
)
|
||||
|
||||
self.tg = TrustGraph([graph_host])
|
||||
|
||||
def create_value(self, ent):
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
else:
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
try:
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
triples = []
|
||||
|
||||
if v.s is not None:
|
||||
if v.p is not None:
|
||||
if v.o is not None:
|
||||
resp = self.tg.get_spo(
|
||||
v.s.value, v.p.value, v.o.value,
|
||||
limit=v.limit
|
||||
)
|
||||
triples.append((v.s.value, v.p.value, v.o.value))
|
||||
else:
|
||||
resp = self.tg.get_sp(
|
||||
v.s.value, v.p.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((v.s.value, v.p.value, t.o))
|
||||
else:
|
||||
if v.o is not None:
|
||||
resp = self.tg.get_os(
|
||||
v.o.value, v.s.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((v.s.value, t.p, v.o.value))
|
||||
else:
|
||||
resp = self.tg.get_s(
|
||||
v.s.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((v.s.value, t.p, t.o))
|
||||
else:
|
||||
if v.p is not None:
|
||||
if v.o is not None:
|
||||
resp = self.tg.get_po(
|
||||
v.p.value, v.o.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((t.s, v.p.value, v.o.value))
|
||||
else:
|
||||
resp = self.tg.get_p(
|
||||
v.p.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((t.s, v.p.value, t.o))
|
||||
else:
|
||||
if v.o is not None:
|
||||
resp = self.tg.get_o(
|
||||
v.o.value,
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((t.s, t.p, v.o.value))
|
||||
else:
|
||||
resp = self.tg.get_all(
|
||||
limit=v.limit
|
||||
)
|
||||
for t in resp:
|
||||
triples.append((t.s, t.p, t.o))
|
||||
|
||||
triples = [
|
||||
Triple(
|
||||
s=self.create_value(t[0]),
|
||||
p=self.create_value(t[1]),
|
||||
o=self.create_value(t[2])
|
||||
)
|
||||
for t in triples
|
||||
]
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TriplesQueryResponse(triples=triples, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TriplesQueryResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-g', '--graph-host',
|
||||
default="localhost",
|
||||
help=f'Graph host (default: localhost)'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/query/triples/neo4j/__main__.py
Executable file
7
trustgraph-flow/trustgraph/query/triples/neo4j/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
348
trustgraph-flow/trustgraph/query/triples/neo4j/service.py
Executable file
348
trustgraph-flow/trustgraph/query/triples/neo4j/service.py
Executable file
|
|
@ -0,0 +1,348 @@
|
|||
|
||||
"""
|
||||
Triples query service. Input is a (s, p, o) triple, some values may be
|
||||
null. Output is a list of triples.
|
||||
"""
|
||||
|
||||
from neo4j import GraphDatabase
|
||||
|
||||
from .... schema import TriplesQueryRequest, TriplesQueryResponse, Error
|
||||
from .... schema import Value, Triple
|
||||
from .... schema import triples_request_queue
|
||||
from .... schema import triples_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_request_queue
|
||||
default_output_queue = triples_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
default_graph_host = 'bolt://neo4j:7687'
|
||||
default_username = 'neo4j'
|
||||
default_password = 'password'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
graph_host = params.get("graph_host", default_graph_host)
|
||||
username = params.get("username", default_username)
|
||||
password = params.get("passowrd", default_password)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TriplesQueryRequest,
|
||||
"output_schema": TriplesQueryResponse,
|
||||
"graph_host": graph_host,
|
||||
}
|
||||
)
|
||||
|
||||
self.db = "neo4j"
|
||||
|
||||
self.io = GraphDatabase.driver(graph_host, auth=(username, password))
|
||||
|
||||

    def create_value(self, ent):

        if ent.startswith("http://") or ent.startswith("https://"):
            return Value(value=ent, is_uri=True)
        else:
            return Value(value=ent, is_uri=False)
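    # Worked example (values illustrative): create_value("http://example.org/cat")
    # yields a Value with is_uri=True, while create_value("A cat") yields a
    # plain literal Value with is_uri=False.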

    def handle(self, msg):

        try:

            v = msg.value()

            # Sender-produced ID
            id = msg.properties()["id"]

            print(f"Handling input {id}...", flush=True)

            triples = []

            if v.s is not None:
                if v.p is not None:
                    if v.o is not None:

                        # SPO

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal {value: $value}) "
                            "RETURN $src as src",
                            src=v.s.value, rel=v.p.value, value=v.o.value,
                            database_=self.db,
                        )

                        for rec in records:
                            triples.append((v.s.value, v.p.value, v.o.value))

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node {uri: $uri}) "
                            "RETURN $src as src",
                            src=v.s.value, rel=v.p.value, uri=v.o.value,
                            database_=self.db,
                        )

                        for rec in records:
                            triples.append((v.s.value, v.p.value, v.o.value))

                    else:

                        # SP

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Literal) "
                            "RETURN dest.value as dest",
                            src=v.s.value, rel=v.p.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((v.s.value, v.p.value, data["dest"]))

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node {uri: $src})-[rel:Rel {uri: $rel}]->(dest:Node) "
                            "RETURN dest.uri as dest",
                            src=v.s.value, rel=v.p.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((v.s.value, v.p.value, data["dest"]))

                else:

                    if v.o is not None:

                        # SO

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal {value: $value}) "
                            "RETURN rel.uri as rel",
                            src=v.s.value, value=v.o.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((v.s.value, data["rel"], v.o.value))

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node {uri: $uri}) "
                            "RETURN rel.uri as rel",
                            src=v.s.value, uri=v.o.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((v.s.value, data["rel"], v.o.value))

                    else:

                        # S

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Literal) "
                            "RETURN rel.uri as rel, dest.value as dest",
                            src=v.s.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((v.s.value, data["rel"], data["dest"]))

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node {uri: $src})-[rel:Rel]->(dest:Node) "
                            "RETURN rel.uri as rel, dest.uri as dest",
                            src=v.s.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((v.s.value, data["rel"], data["dest"]))

            else:

                if v.p is not None:

                    if v.o is not None:

                        # PO

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal {value: $value}) "
                            "RETURN src.uri as src",
                            uri=v.p.value, value=v.o.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((data["src"], v.p.value, v.o.value))

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node {uri: $dest}) "
                            "RETURN src.uri as src",
                            uri=v.p.value, dest=v.o.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((data["src"], v.p.value, v.o.value))

                    else:

                        # P

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Literal) "
                            "RETURN src.uri as src, dest.value as dest",
                            uri=v.p.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((data["src"], v.p.value, data["dest"]))

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node)-[rel:Rel {uri: $uri}]->(dest:Node) "
                            "RETURN src.uri as src, dest.uri as dest",
                            uri=v.p.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((data["src"], v.p.value, data["dest"]))

                else:

                    if v.o is not None:

                        # O

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node)-[rel:Rel]->(dest:Literal {value: $value}) "
                            "RETURN src.uri as src, rel.uri as rel",
                            value=v.o.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((data["src"], data["rel"], v.o.value))

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node)-[rel:Rel]->(dest:Node {uri: $uri}) "
                            "RETURN src.uri as src, rel.uri as rel",
                            uri=v.o.value,
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((data["src"], data["rel"], v.o.value))

                    else:

                        # *

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node)-[rel:Rel]->(dest:Literal) "
                            "RETURN src.uri as src, rel.uri as rel, dest.value as dest",
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((data["src"], data["rel"], data["dest"]))

                        records, summary, keys = self.io.execute_query(
                            "MATCH (src:Node)-[rel:Rel]->(dest:Node) "
                            "RETURN src.uri as src, rel.uri as rel, dest.uri as dest",
                            database_=self.db,
                        )

                        for rec in records:
                            data = rec.data()
                            triples.append((data["src"], data["rel"], data["dest"]))

            triples = [
                Triple(
                    s=self.create_value(t[0]),
                    p=self.create_value(t[1]),
                    o=self.create_value(t[2])
                )
                for t in triples
            ]

            print("Send response...", flush=True)
            r = TriplesQueryResponse(triples=triples, error=None)
            self.producer.send(r, properties={"id": id})

            print("Done.", flush=True)

        except Exception as e:

            print(f"Exception: {e}")

            print("Send error response...", flush=True)

            r = TriplesQueryResponse(
                error=Error(
                    type="llm-error",
                    message=str(e),
                ),
                triples=None,
            )

            self.producer.send(r, properties={"id": id})

        self.consumer.acknowledge(msg)

    @staticmethod
    def add_args(parser):

        ConsumerProducer.add_args(
            parser, default_input_queue, default_subscriber,
            default_output_queue,
        )

        parser.add_argument(
            '-g', '--graph-host',
            default=default_graph_host,
            help=f'Graph host (default: {default_graph_host})'
        )

        parser.add_argument(
            '--username',
            default=default_username,
            help=f'Neo4j username (default: {default_username})'
        )

        parser.add_argument(
            '--password',
            default=default_password,
            help=f'Neo4j password (default: {default_password})'
        )

def run():

    Processor.start(module, __doc__)
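The MATCH patterns above pin down the storage shape this service expects: subjects and URI objects are (:Node {uri}) vertices, literal objects are (:Literal {value}) vertices, and predicates are [:Rel {uri}] relationships. For orientation, a sketch of a writer that would populate that shape with the same driver; this is illustrative only (the actual writer is a separate storage component in the package), and the example triple values are made up:

from neo4j import GraphDatabase

driver = GraphDatabase.driver("bolt://neo4j:7687", auth=("neo4j", "password"))

def write_triple(s, p, o, o_is_uri, db="neo4j"):
    # MERGE keeps the write idempotent: re-sending a triple adds no duplicates
    dest = "(dest:Node {uri: $o})" if o_is_uri else "(dest:Literal {value: $o})"
    driver.execute_query(
        "MERGE (src:Node {uri: $s}) "
        "MERGE " + dest + " "
        "MERGE (src)-[:Rel {uri: $p}]->(dest)",
        s=s, p=p, o=o,
        database_=db,
    )

# A literal-object triple, matching the SP/SO/O read paths above
write_triple("http://example.org/cat", "http://example.org/colour", "black", False)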

0
trustgraph-flow/trustgraph/retrieval/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff.