mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-25 06:38:06 +02:00
Feature/pkgsplit (#83)
* Starting to spawn base package * More package hacking * Bedrock and VertexAI * Parquet split * Updated templates * Utils
This commit is contained in:
parent
3fb75c617b
commit
9b91d5eee3
262 changed files with 630 additions and 420 deletions
6
trustgraph-flow/scripts/chunker-recursive
Executable file
6
trustgraph-flow/scripts/chunker-recursive
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.chunking.recursive import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/chunker-token
Executable file
6
trustgraph-flow/scripts/chunker-token
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.chunking.token import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/de-query-milvus
Executable file
6
trustgraph-flow/scripts/de-query-milvus
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.query.doc_embeddings.milvus import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/de-query-qdrant
Executable file
6
trustgraph-flow/scripts/de-query-qdrant
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.query.doc_embeddings.qdrant import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/de-write-milvus
Executable file
6
trustgraph-flow/scripts/de-write-milvus
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.storage.doc_embeddings.milvus import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/de-write-qdrant
Executable file
6
trustgraph-flow/scripts/de-write-qdrant
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.storage.doc_embeddings.qdrant import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/document-rag
Executable file
6
trustgraph-flow/scripts/document-rag
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.retrieval.document_rag import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/embeddings-ollama
Executable file
6
trustgraph-flow/scripts/embeddings-ollama
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.embeddings.ollama import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/embeddings-vectorize
Executable file
6
trustgraph-flow/scripts/embeddings-vectorize
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.embeddings.vectorize import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/ge-query-milvus
Executable file
6
trustgraph-flow/scripts/ge-query-milvus
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.query.graph_embeddings.milvus import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/ge-query-qdrant
Executable file
6
trustgraph-flow/scripts/ge-query-qdrant
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.query.graph_embeddings.qdrant import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/ge-write-milvus
Executable file
6
trustgraph-flow/scripts/ge-write-milvus
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.storage.graph_embeddings.milvus import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/ge-write-qdrant
Executable file
6
trustgraph-flow/scripts/ge-write-qdrant
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.storage.graph_embeddings.qdrant import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/graph-rag
Executable file
6
trustgraph-flow/scripts/graph-rag
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.retrieval.graph_rag import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/kg-extract-definitions
Executable file
6
trustgraph-flow/scripts/kg-extract-definitions
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.extract.kg.definitions import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/kg-extract-relationships
Executable file
6
trustgraph-flow/scripts/kg-extract-relationships
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.extract.kg.relationships import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/kg-extract-topics
Executable file
6
trustgraph-flow/scripts/kg-extract-topics
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.extract.kg.topics import run
|
||||
|
||||
run()
|
||||
|
||||
5
trustgraph-flow/scripts/metering
Executable file
5
trustgraph-flow/scripts/metering
Executable file
|
|
@ -0,0 +1,5 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.metering import run
|
||||
|
||||
run()
|
||||
6
trustgraph-flow/scripts/object-extract-row
Executable file
6
trustgraph-flow/scripts/object-extract-row
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.extract.object.row import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/oe-write-milvus
Executable file
6
trustgraph-flow/scripts/oe-write-milvus
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.storage.object_embeddings.milvus import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/pdf-decoder
Executable file
6
trustgraph-flow/scripts/pdf-decoder
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.decoding.pdf import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/prompt-generic
Executable file
6
trustgraph-flow/scripts/prompt-generic
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.prompt.generic import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/prompt-template
Executable file
6
trustgraph-flow/scripts/prompt-template
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.prompt.template import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/rows-write-cassandra
Executable file
6
trustgraph-flow/scripts/rows-write-cassandra
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.storage.rows.cassandra import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/run-processing
Executable file
6
trustgraph-flow/scripts/run-processing
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.processing import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/text-completion-azure
Executable file
6
trustgraph-flow/scripts/text-completion-azure
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.text_completion.azure import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/text-completion-claude
Executable file
6
trustgraph-flow/scripts/text-completion-claude
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.text_completion.claude import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/text-completion-cohere
Executable file
6
trustgraph-flow/scripts/text-completion-cohere
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.text_completion.cohere import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/text-completion-llamafile
Executable file
6
trustgraph-flow/scripts/text-completion-llamafile
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.text_completion.llamafile import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/text-completion-ollama
Executable file
6
trustgraph-flow/scripts/text-completion-ollama
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.text_completion.ollama import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/text-completion-openai
Executable file
6
trustgraph-flow/scripts/text-completion-openai
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.model.text_completion.openai import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/triples-query-cassandra
Executable file
6
trustgraph-flow/scripts/triples-query-cassandra
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.query.triples.cassandra import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/triples-query-neo4j
Executable file
6
trustgraph-flow/scripts/triples-query-neo4j
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.query.triples.neo4j import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/triples-write-cassandra
Executable file
6
trustgraph-flow/scripts/triples-write-cassandra
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.storage.triples.cassandra import run
|
||||
|
||||
run()
|
||||
|
||||
6
trustgraph-flow/scripts/triples-write-neo4j
Executable file
6
trustgraph-flow/scripts/triples-write-neo4j
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from trustgraph.storage.triples.neo4j import run
|
||||
|
||||
run()
|
||||
|
||||
96
trustgraph-flow/setup.py
Normal file
96
trustgraph-flow/setup.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import setuptools
|
||||
import os
|
||||
import importlib
|
||||
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
# Load a version number module
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
'version', 'trustgraph/flow_version.py'
|
||||
)
|
||||
version_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(version_module)
|
||||
|
||||
version = version_module.__version__
|
||||
|
||||
setuptools.setup(
|
||||
name="trustgraph-flow",
|
||||
version=version,
|
||||
author="trustgraph.ai",
|
||||
author_email="security@trustgraph.ai",
|
||||
description="TrustGraph provides a means to run a pipeline of flexible AI processing components in a flexible means to achieve a processing pipeline.",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/trustgraph-ai/trustgraph",
|
||||
packages=setuptools.find_namespace_packages(
|
||||
where='./',
|
||||
),
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
python_requires='>=3.8',
|
||||
download_url = "https://github.com/trustgraph-ai/trustgraph/archive/refs/tags/v" + version + ".tar.gz",
|
||||
install_requires=[
|
||||
"trustgraph-base",
|
||||
"urllib3",
|
||||
"rdflib",
|
||||
"pymilvus",
|
||||
"langchain",
|
||||
"langchain-core",
|
||||
"langchain-text-splitters",
|
||||
"langchain-community",
|
||||
"requests",
|
||||
"cassandra-driver",
|
||||
"pulsar-client",
|
||||
"pypdf",
|
||||
"qdrant-client",
|
||||
"tabulate",
|
||||
"anthropic",
|
||||
"pyyaml",
|
||||
"prometheus-client",
|
||||
"cohere",
|
||||
"openai",
|
||||
"neo4j",
|
||||
"tiktoken",
|
||||
],
|
||||
scripts=[
|
||||
"scripts/chunker-recursive",
|
||||
"scripts/chunker-token",
|
||||
"scripts/de-query-milvus",
|
||||
"scripts/de-query-qdrant",
|
||||
"scripts/de-write-milvus",
|
||||
"scripts/de-write-qdrant",
|
||||
"scripts/document-rag",
|
||||
"scripts/embeddings-ollama",
|
||||
"scripts/embeddings-vectorize",
|
||||
"scripts/ge-query-milvus",
|
||||
"scripts/ge-query-qdrant",
|
||||
"scripts/ge-write-milvus",
|
||||
"scripts/ge-write-qdrant",
|
||||
"scripts/graph-rag",
|
||||
"scripts/kg-extract-definitions",
|
||||
"scripts/kg-extract-topics",
|
||||
"scripts/kg-extract-relationships",
|
||||
"scripts/metering",
|
||||
"scripts/object-extract-row",
|
||||
"scripts/oe-write-milvus",
|
||||
"scripts/pdf-decoder",
|
||||
"scripts/prompt-generic",
|
||||
"scripts/prompt-template",
|
||||
"scripts/rows-write-cassandra",
|
||||
"scripts/run-processing",
|
||||
"scripts/text-completion-azure",
|
||||
"scripts/text-completion-claude",
|
||||
"scripts/text-completion-cohere",
|
||||
"scripts/text-completion-llamafile",
|
||||
"scripts/text-completion-ollama",
|
||||
"scripts/text-completion-openai",
|
||||
"scripts/triples-query-cassandra",
|
||||
"scripts/triples-query-neo4j",
|
||||
"scripts/triples-write-cassandra",
|
||||
"scripts/triples-write-neo4j",
|
||||
]
|
||||
)
|
||||
0
trustgraph-flow/trustgraph/__init__.py
Normal file
0
trustgraph-flow/trustgraph/__init__.py
Normal file
0
trustgraph-flow/trustgraph/chunking/__init__.py
Normal file
0
trustgraph-flow/trustgraph/chunking/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . chunker import *
|
||||
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . chunker import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
108
trustgraph-flow/trustgraph/chunking/recursive/chunker.py
Executable file
108
trustgraph-flow/trustgraph/chunking/recursive/chunker.py
Executable file
|
|
@ -0,0 +1,108 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts text documents on input, outputs chunks from the
|
||||
as text as separate output objects.
|
||||
"""
|
||||
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk, Source
|
||||
from ... schema import text_ingest_queue, chunk_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_ingest_queue
|
||||
default_output_queue = chunk_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
chunk_size = params.get("chunk_size", 2000)
|
||||
chunk_overlap = params.get("chunk_overlap", 100)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextDocument,
|
||||
"output_schema": Chunk,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
'chunk_size', 'Chunk size',
|
||||
buckets=[100, 160, 250, 400, 650, 1000, 1600,
|
||||
2500, 4000, 6400, 10000, 16000]
|
||||
)
|
||||
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len,
|
||||
is_separator_regex=False,
|
||||
)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Chunking {v.source.id}...", flush=True)
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
)
|
||||
|
||||
for ix, chunk in enumerate(texts):
|
||||
|
||||
id = v.source.id + "-c" + str(ix)
|
||||
|
||||
r = Chunk(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
id=id,
|
||||
title=v.source.title
|
||||
),
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
)
|
||||
|
||||
__class__.chunk_metric.observe(len(chunk.page_content))
|
||||
|
||||
self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--chunk-size',
|
||||
type=int,
|
||||
default=2000,
|
||||
help=f'Chunk size (default: 2000)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-v', '--chunk-overlap',
|
||||
type=int,
|
||||
default=100,
|
||||
help=f'Chunk overlap (default: 100)'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
3
trustgraph-flow/trustgraph/chunking/token/__init__.py
Normal file
3
trustgraph-flow/trustgraph/chunking/token/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . chunker import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/chunking/token/__main__.py
Normal file
7
trustgraph-flow/trustgraph/chunking/token/__main__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . chunker import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
107
trustgraph-flow/trustgraph/chunking/token/chunker.py
Executable file
107
trustgraph-flow/trustgraph/chunking/token/chunker.py
Executable file
|
|
@ -0,0 +1,107 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts text documents on input, outputs chunks from the
|
||||
as text as separate output objects.
|
||||
"""
|
||||
|
||||
from langchain_text_splitters import TokenTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk, Source
|
||||
from ... schema import text_ingest_queue, chunk_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_ingest_queue
|
||||
default_output_queue = chunk_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
chunk_size = params.get("chunk_size", 250)
|
||||
chunk_overlap = params.get("chunk_overlap", 15)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextDocument,
|
||||
"output_schema": Chunk,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
'chunk_size', 'Chunk size',
|
||||
buckets=[100, 160, 250, 400, 650, 1000, 1600,
|
||||
2500, 4000, 6400, 10000, 16000]
|
||||
)
|
||||
|
||||
self.text_splitter = TokenTextSplitter(
|
||||
encoding_name="cl100k_base",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Chunking {v.source.id}...", flush=True)
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
)
|
||||
|
||||
for ix, chunk in enumerate(texts):
|
||||
|
||||
id = v.source.id + "-c" + str(ix)
|
||||
|
||||
r = Chunk(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
id=id,
|
||||
title=v.source.title
|
||||
),
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
)
|
||||
|
||||
__class__.chunk_metric.observe(len(chunk.page_content))
|
||||
|
||||
self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--chunk-size',
|
||||
type=int,
|
||||
default=250,
|
||||
help=f'Chunk size (default: 250)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-v', '--chunk-overlap',
|
||||
type=int,
|
||||
default=15,
|
||||
help=f'Chunk overlap (default: 15)'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
0
trustgraph-flow/trustgraph/decoding/__init__.py
Normal file
0
trustgraph-flow/trustgraph/decoding/__init__.py
Normal file
3
trustgraph-flow/trustgraph/decoding/pdf/__init__.py
Normal file
3
trustgraph-flow/trustgraph/decoding/pdf/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . pdf_decoder import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/decoding/pdf/__main__.py
Executable file
7
trustgraph-flow/trustgraph/decoding/pdf/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . pdf_decoder import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
87
trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
Executable file
87
trustgraph-flow/trustgraph/decoding/pdf/pdf_decoder.py
Executable file
|
|
@ -0,0 +1,87 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts PDF documents on input, outputs pages from the
|
||||
PDF document as text as separate output objects.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import base64
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
|
||||
from ... schema import Document, TextDocument, Source
|
||||
from ... schema import document_ingest_queue, text_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_ingest_queue
|
||||
default_output_queue = text_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Document,
|
||||
"output_schema": TextDocument,
|
||||
}
|
||||
)
|
||||
|
||||
print("PDF inited")
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
print("PDF message received")
|
||||
|
||||
v = msg.value()
|
||||
|
||||
print(f"Decoding {v.source.id}...", flush=True)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
|
||||
|
||||
fp.write(base64.b64decode(v.data))
|
||||
fp.close()
|
||||
|
||||
with open(fp.name, mode='rb') as f:
|
||||
|
||||
loader = PyPDFLoader(fp.name)
|
||||
pages = loader.load()
|
||||
|
||||
for ix, page in enumerate(pages):
|
||||
|
||||
id = v.source.id + "-p" + str(ix)
|
||||
r = TextDocument(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
title=v.source.title,
|
||||
id=id,
|
||||
),
|
||||
text=page.page_content.encode("utf-8"),
|
||||
)
|
||||
|
||||
self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
0
trustgraph-flow/trustgraph/direct/__init__.py
Normal file
0
trustgraph-flow/trustgraph/direct/__init__.py
Normal file
108
trustgraph-flow/trustgraph/direct/cassandra.py
Normal file
108
trustgraph-flow/trustgraph/direct/cassandra.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
|
||||
class TrustGraph:
|
||||
|
||||
def __init__(self, hosts=None):
|
||||
|
||||
if hosts is None:
|
||||
hosts = ["localhost"]
|
||||
|
||||
self.cluster = Cluster(hosts)
|
||||
self.session = self.cluster.connect()
|
||||
|
||||
self.init()
|
||||
|
||||
def clear(self):
|
||||
|
||||
self.session.execute("""
|
||||
drop keyspace if exists trustgraph;
|
||||
""");
|
||||
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
|
||||
self.session.execute("""
|
||||
create keyspace if not exists trustgraph
|
||||
with replication = {
|
||||
'class' : 'SimpleStrategy',
|
||||
'replication_factor' : 1
|
||||
};
|
||||
""");
|
||||
|
||||
self.session.set_keyspace('trustgraph')
|
||||
|
||||
self.session.execute("""
|
||||
create table if not exists triples (
|
||||
s text,
|
||||
p text,
|
||||
o text,
|
||||
PRIMARY KEY (s, p, o)
|
||||
);
|
||||
""");
|
||||
|
||||
self.session.execute("""
|
||||
create index if not exists triples_p
|
||||
ON triples (p);
|
||||
""");
|
||||
|
||||
self.session.execute("""
|
||||
create index if not exists triples_o
|
||||
ON triples (o);
|
||||
""");
|
||||
|
||||
def insert(self, s, p, o):
|
||||
|
||||
self.session.execute(
|
||||
"insert into triples (s, p, o) values (%s, %s, %s)",
|
||||
(s, p, o)
|
||||
)
|
||||
|
||||
def get_all(self, limit=50):
|
||||
return self.session.execute(
|
||||
f"select s, p, o from triples limit {limit}"
|
||||
)
|
||||
|
||||
def get_s(self, s, limit=10):
|
||||
return self.session.execute(
|
||||
f"select p, o from triples where s = %s limit {limit}",
|
||||
(s,)
|
||||
)
|
||||
|
||||
def get_p(self, p, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s, o from triples where p = %s limit {limit}",
|
||||
(p,)
|
||||
)
|
||||
|
||||
def get_o(self, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s, p from triples where o = %s limit {limit}",
|
||||
(o,)
|
||||
)
|
||||
|
||||
def get_sp(self, s, p, limit=10):
|
||||
return self.session.execute(
|
||||
f"select o from triples where s = %s and p = %s limit {limit}",
|
||||
(s, p)
|
||||
)
|
||||
|
||||
def get_po(self, p, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s from triples where p = %s and o = %s allow filtering limit {limit}",
|
||||
(p, o)
|
||||
)
|
||||
|
||||
def get_os(self, o, s, limit=10):
|
||||
return self.session.execute(
|
||||
f"select p from triples where o = %s and s = %s limit {limit}",
|
||||
(o, s)
|
||||
)
|
||||
|
||||
def get_spo(self, s, p, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"""select s as x from triples where s = %s and p = %s and o = %s limit {limit}""",
|
||||
(s, p, o)
|
||||
)
|
||||
138
trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py
Normal file
138
trustgraph-flow/trustgraph/direct/milvus_doc_embeddings.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
|
||||
class DocVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='doc'):
|
||||
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
# Strategy is to create collections per dimension. Probably only
|
||||
# going to be using 1 anyway, but that means we don't need to
|
||||
# hard-code the dimension anywhere, and no big deal if more than
|
||||
# one are created.
|
||||
self.collections = {}
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
# Time between reloads
|
||||
self.reload_time = 90
|
||||
|
||||
# Next time to reload - this forces a reload at next window
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
print("Reload at", self.next_reload)
|
||||
|
||||
def init_collection(self, dimension):
|
||||
|
||||
collection_name = self.prefix + "_" + str(dimension)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
is_primary=True,
|
||||
auto_id=True,
|
||||
)
|
||||
|
||||
vec_field = FieldSchema(
|
||||
name="vector",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
)
|
||||
|
||||
doc_field = FieldSchema(
|
||||
name="doc",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [pkey_field, vec_field, doc_field],
|
||||
description = "Document embedding schema",
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
metric_type="COSINE",
|
||||
index_type="IVF_SQ8",
|
||||
index_name="vector_index",
|
||||
params={ "nlist": 128 }
|
||||
)
|
||||
|
||||
self.client.create_index(
|
||||
collection_name=collection_name,
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[dimension] = collection_name
|
||||
|
||||
def insert(self, embeds, doc):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"doc": doc,
|
||||
}
|
||||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[dim],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, fields=["doc"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
|
||||
coll = self.collections[dim]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
print("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
|
||||
print("Searching...")
|
||||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
)[0]
|
||||
|
||||
|
||||
# If reload time has passed, unload collection
|
||||
if time.time() > self.next_reload:
|
||||
print("Unloading, reload at", self.next_reload)
|
||||
self.client.release_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
|
||||
return res
|
||||
|
||||
138
trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py
Normal file
138
trustgraph-flow/trustgraph/direct/milvus_graph_embeddings.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
|
||||
class EntityVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='entity'):
|
||||
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
# Strategy is to create collections per dimension. Probably only
|
||||
# going to be using 1 anyway, but that means we don't need to
|
||||
# hard-code the dimension anywhere, and no big deal if more than
|
||||
# one are created.
|
||||
self.collections = {}
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
# Time between reloads
|
||||
self.reload_time = 90
|
||||
|
||||
# Next time to reload - this forces a reload at next window
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
print("Reload at", self.next_reload)
|
||||
|
||||
def init_collection(self, dimension):
|
||||
|
||||
collection_name = self.prefix + "_" + str(dimension)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
is_primary=True,
|
||||
auto_id=True,
|
||||
)
|
||||
|
||||
vec_field = FieldSchema(
|
||||
name="vector",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
)
|
||||
|
||||
entity_field = FieldSchema(
|
||||
name="entity",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [pkey_field, vec_field, entity_field],
|
||||
description = "Graph embedding schema",
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
metric_type="COSINE",
|
||||
index_type="IVF_SQ8",
|
||||
index_name="vector_index",
|
||||
params={ "nlist": 128 }
|
||||
)
|
||||
|
||||
self.client.create_index(
|
||||
collection_name=collection_name,
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[dimension] = collection_name
|
||||
|
||||
def insert(self, embeds, entity):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"entity": entity,
|
||||
}
|
||||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[dim],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, fields=["entity"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
|
||||
coll = self.collections[dim]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
print("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
|
||||
print("Searching...")
|
||||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
)[0]
|
||||
|
||||
|
||||
# If reload time has passed, unload collection
|
||||
if time.time() > self.next_reload:
|
||||
print("Unloading, reload at", self.next_reload)
|
||||
self.client.release_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
|
||||
return res
|
||||
|
||||
154
trustgraph-flow/trustgraph/direct/milvus_object_embeddings.py
Normal file
154
trustgraph-flow/trustgraph/direct/milvus_object_embeddings.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
|
||||
class ObjectVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='obj'):
|
||||
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
# Strategy is to create collections per dimension. Probably only
|
||||
# going to be using 1 anyway, but that means we don't need to
|
||||
# hard-code the dimension anywhere, and no big deal if more than
|
||||
# one are created.
|
||||
self.collections = {}
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
# Time between reloads
|
||||
self.reload_time = 90
|
||||
|
||||
# Next time to reload - this forces a reload at next window
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
print("Reload at", self.next_reload)
|
||||
|
||||
def init_collection(self, dimension, name):
|
||||
|
||||
collection_name = self.prefix + "_" + name + "_" + str(dimension)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
is_primary=True,
|
||||
auto_id=True,
|
||||
)
|
||||
|
||||
vec_field = FieldSchema(
|
||||
name="vector",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
)
|
||||
|
||||
name_field = FieldSchema(
|
||||
name="name",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
key_name_field = FieldSchema(
|
||||
name="key_name",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
key_field = FieldSchema(
|
||||
name="key",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [
|
||||
pkey_field, vec_field, name_field, key_name_field, key_field
|
||||
],
|
||||
description = "Object embedding schema",
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
metric_type="COSINE",
|
||||
index_type="IVF_SQ8",
|
||||
index_name="vector_index",
|
||||
params={ "nlist": 128 }
|
||||
)
|
||||
|
||||
self.client.create_index(
|
||||
collection_name=collection_name,
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[(dimension, name)] = collection_name
|
||||
|
||||
def insert(self, embeds, name, key_name, key):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if (dim, name) not in self.collections:
|
||||
self.init_collection(dim, name)
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"name": name,
|
||||
"key_name": key_name,
|
||||
"key": key,
|
||||
}
|
||||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[(dim, name)],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, name, fields=["key_name", "name"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim, name)
|
||||
|
||||
coll = self.collections[(dim, name)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
print("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
|
||||
print("Searching...")
|
||||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
)[0]
|
||||
|
||||
|
||||
# If reload time has passed, unload collection
|
||||
if time.time() > self.next_reload:
|
||||
print("Unloading, reload at", self.next_reload)
|
||||
self.client.release_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
|
||||
return res
|
||||
|
||||
132
trustgraph-flow/trustgraph/document_rag.py
Normal file
132
trustgraph-flow/trustgraph/document_rag.py
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
|
||||
from . clients.document_embeddings_client import DocumentEmbeddingsClient
|
||||
from . clients.triples_query_client import TriplesQueryClient
|
||||
from . clients.embeddings_client import EmbeddingsClient
|
||||
from . clients.prompt_client import PromptClient
|
||||
|
||||
from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from . schema import TriplesQueryRequest, TriplesQueryResponse
|
||||
from . schema import prompt_request_queue
|
||||
from . schema import prompt_response_queue
|
||||
from . schema import embeddings_request_queue
|
||||
from . schema import embeddings_response_queue
|
||||
from . schema import document_embeddings_request_queue
|
||||
from . schema import document_embeddings_response_queue
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
||||
|
||||
class DocumentRag:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
pr_request_queue=None,
|
||||
pr_response_queue=None,
|
||||
emb_request_queue=None,
|
||||
emb_response_queue=None,
|
||||
de_request_queue=None,
|
||||
de_response_queue=None,
|
||||
verbose=False,
|
||||
module="test",
|
||||
):
|
||||
|
||||
self.verbose=verbose
|
||||
|
||||
if pr_request_queue is None:
|
||||
pr_request_queue = prompt_request_queue
|
||||
|
||||
if pr_response_queue is None:
|
||||
pr_response_queue = prompt_response_queue
|
||||
|
||||
if emb_request_queue is None:
|
||||
emb_request_queue = embeddings_request_queue
|
||||
|
||||
if emb_response_queue is None:
|
||||
emb_response_queue = embeddings_response_queue
|
||||
|
||||
if de_request_queue is None:
|
||||
de_request_queue = document_embeddings_request_queue
|
||||
|
||||
if de_response_queue is None:
|
||||
de_response_queue = document_embeddings_response_queue
|
||||
|
||||
if self.verbose:
|
||||
print("Initialising...", flush=True)
|
||||
|
||||
# FIXME: Configurable
|
||||
self.entity_limit = 20
|
||||
|
||||
self.de_client = DocumentEmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
subscriber=module + "-de",
|
||||
input_queue=de_request_queue,
|
||||
output_queue=de_response_queue,
|
||||
)
|
||||
|
||||
self.embeddings = EmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=emb_request_queue,
|
||||
output_queue=emb_response_queue,
|
||||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
self.lang = PromptClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber=module + "-de-prompt",
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = self.embeddings.request(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
def get_docs(self, query):
|
||||
|
||||
vectors = self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get entities...", flush=True)
|
||||
|
||||
docs = self.de_client.request(
|
||||
vectors, self.entity_limit
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Docs:", flush=True)
|
||||
for doc in docs:
|
||||
print(doc, flush=True)
|
||||
|
||||
return docs
|
||||
|
||||
def query(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
docs = self.get_docs(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Invoke LLM...", flush=True)
|
||||
print(docs)
|
||||
print(query)
|
||||
|
||||
resp = self.lang.request_document_prompt(query, docs)
|
||||
|
||||
if self.verbose:
|
||||
print("Done", flush=True)
|
||||
|
||||
return resp
|
||||
|
||||
0
trustgraph-flow/trustgraph/embeddings/__init__.py
Normal file
0
trustgraph-flow/trustgraph/embeddings/__init__.py
Normal file
3
trustgraph-flow/trustgraph/embeddings/ollama/__init__.py
Normal file
3
trustgraph-flow/trustgraph/embeddings/ollama/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . processor import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/embeddings/ollama/__main__.py
Executable file
7
trustgraph-flow/trustgraph/embeddings/ollama/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . processor import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
84
trustgraph-flow/trustgraph/embeddings/ollama/processor.py
Executable file
84
trustgraph-flow/trustgraph/embeddings/ollama/processor.py
Executable file
|
|
@ -0,0 +1,84 @@
|
|||
|
||||
"""
|
||||
Embeddings service, applies an embeddings model selected from HuggingFace.
|
||||
Input is text, output is embeddings vector.
|
||||
"""
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
|
||||
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from ... schema import embeddings_request_queue, embeddings_response_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = embeddings_request_queue
|
||||
default_output_queue = embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_model="mxbai-embed-large"
|
||||
default_ollama = 'http://localhost:11434'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": EmbeddingsRequest,
|
||||
"output_schema": EmbeddingsResponse,
|
||||
}
|
||||
)
|
||||
|
||||
self.embeddings = OllamaEmbeddings(base_url=ollama, model=model)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
text = v.text
|
||||
embeds = self.embeddings.embed_query([text])
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = EmbeddingsResponse(vectors=[embeds])
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default=default_model,
|
||||
help=f'Embeddings model (default: {default_model})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-r', '--ollama',
|
||||
default=default_ollama,
|
||||
help=f'ollama (default: {default_ollama})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . vectorize import *
|
||||
|
||||
6
trustgraph-flow/trustgraph/embeddings/vectorize/__main__.py
Executable file
6
trustgraph-flow/trustgraph/embeddings/vectorize/__main__.py
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
|
||||
from . vectorize import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
103
trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py
Executable file
103
trustgraph-flow/trustgraph/embeddings/vectorize/vectorize.py
Executable file
|
|
@ -0,0 +1,103 @@
|
|||
|
||||
"""
|
||||
Vectorizer, calls the embeddings service to get embeddings for a chunk.
|
||||
Input is text chunk, output is chunk and vectors.
|
||||
"""
|
||||
|
||||
from ... schema import Chunk, ChunkEmbeddings
|
||||
from ... schema import chunk_ingest_queue, chunk_embeddings_ingest_queue
|
||||
from ... schema import embeddings_request_queue, embeddings_response_queue
|
||||
from ... clients.embeddings_client import EmbeddingsClient
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_output_queue = chunk_embeddings_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
emb_request_queue = params.get(
|
||||
"embeddings_request_queue", embeddings_request_queue
|
||||
)
|
||||
emb_response_queue = params.get(
|
||||
"embeddings_response_queue", embeddings_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"embeddings_request_queue": emb_request_queue,
|
||||
"embeddings_response_queue": emb_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Chunk,
|
||||
"output_schema": ChunkEmbeddings,
|
||||
}
|
||||
)
|
||||
|
||||
self.embeddings = EmbeddingsClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=emb_request_queue,
|
||||
output_queue=emb_response_queue,
|
||||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
def emit(self, source, chunk, vectors):
|
||||
|
||||
r = ChunkEmbeddings(source=source, chunk=chunk, vectors=vectors)
|
||||
self.producer.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
vectors = self.embeddings.request(chunk)
|
||||
|
||||
self.emit(
|
||||
source=v.source,
|
||||
chunk=chunk.encode("utf-8"),
|
||||
vectors=vectors
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-request-queue',
|
||||
default=embeddings_request_queue,
|
||||
help=f'Embeddings request queue (default: {embeddings_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-response-queue',
|
||||
default=embeddings_response_queue,
|
||||
help=f'Embeddings request queue (default: {embeddings_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
0
trustgraph-flow/trustgraph/extract/__init__.py
Normal file
0
trustgraph-flow/trustgraph/extract/__init__.py
Normal file
0
trustgraph-flow/trustgraph/extract/kg/__init__.py
Normal file
0
trustgraph-flow/trustgraph/extract/kg/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . extract import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/extract/kg/definitions/__main__.py
Executable file
7
trustgraph-flow/trustgraph/extract/kg/definitions/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . extract import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
134
trustgraph-flow/trustgraph/extract/kg/definitions/extract.py
Executable file
134
trustgraph-flow/trustgraph/extract/kg/definitions/extract.py
Executable file
|
|
@ -0,0 +1,134 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
|
||||
get entity definitions which are output as graph edges.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import json
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Source, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"output_schema": Triple,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
def to_uri(self, text):
|
||||
|
||||
part = text.replace(" ", "-").lower().encode("utf-8")
|
||||
quoted = urllib.parse.quote(part)
|
||||
uri = TRUSTGRAPH_ENTITIES + quoted
|
||||
|
||||
return uri
|
||||
|
||||
def get_definitions(self, chunk):
|
||||
|
||||
return self.prompt.request_definitions(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
defs = self.get_definitions(chunk)
|
||||
|
||||
for defn in defs:
|
||||
|
||||
s = defn.name
|
||||
o = defn.definition
|
||||
|
||||
if s == "": continue
|
||||
if o == "": continue
|
||||
|
||||
if s is None: continue
|
||||
if o is None: continue
|
||||
|
||||
s_uri = self.to_uri(s)
|
||||
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-completion-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . extract import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/extract/kg/relationships/__main__.py
Executable file
7
trustgraph-flow/trustgraph/extract/kg/relationships/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . extract import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
208
trustgraph-flow/trustgraph/extract/kg/relationships/extract.py
Executable file
208
trustgraph-flow/trustgraph/extract/kg/relationships/extract.py
Executable file
|
|
@ -0,0 +1,208 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts vector+text chunks input, applies entity
|
||||
relationship analysis to get entity relationship edges which are output as
|
||||
graph edges.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import graph_embeddings_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_vector_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
vector_queue = params.get("vector_queue", default_vector_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"output_schema": Triple,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.vec_prod = self.client.create_producer(
|
||||
topic=vector_queue,
|
||||
schema=JsonSchema(GraphEmbeddings),
|
||||
)
|
||||
|
||||
__class__.pubsub_metric.info({
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"vector_queue": vector_queue,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings.__name__,
|
||||
"output_schema": Triple.__name__,
|
||||
"vector_schema": GraphEmbeddings.__name__,
|
||||
})
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
def to_uri(self, text):
|
||||
|
||||
part = text.replace(" ", "-").lower().encode("utf-8")
|
||||
quoted = urllib.parse.quote(part)
|
||||
uri = TRUSTGRAPH_ENTITIES + quoted
|
||||
|
||||
return uri
|
||||
|
||||
def get_relationships(self, chunk):
|
||||
|
||||
return self.prompt.request_relationships(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_vec(self, ent, vec):
|
||||
|
||||
r = GraphEmbeddings(entity=ent, vectors=vec)
|
||||
self.vec_prod.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
rels = self.get_relationships(chunk)
|
||||
|
||||
for rel in rels:
|
||||
|
||||
s = rel.s
|
||||
p = rel.p
|
||||
o = rel.o
|
||||
|
||||
if s == "": continue
|
||||
if p == "": continue
|
||||
if o == "": continue
|
||||
|
||||
if s is None: continue
|
||||
if p is None: continue
|
||||
if o is None: continue
|
||||
|
||||
s_uri = self.to_uri(s)
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
|
||||
p_uri = self.to_uri(p)
|
||||
p_value = Value(value=str(p_uri), is_uri=True)
|
||||
|
||||
if rel.o_entity:
|
||||
o_uri = self.to_uri(o)
|
||||
o_value = Value(value=str(o_uri), is_uri=True)
|
||||
else:
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(
|
||||
s_value,
|
||||
p_value,
|
||||
o_value
|
||||
)
|
||||
|
||||
# Label for s
|
||||
self.emit_edge(
|
||||
s_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(s), is_uri=False)
|
||||
)
|
||||
|
||||
# Label for p
|
||||
self.emit_edge(
|
||||
p_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(p), is_uri=False)
|
||||
)
|
||||
|
||||
if rel.o_entity:
|
||||
# Label for o
|
||||
self.emit_edge(
|
||||
o_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(o), is_uri=False)
|
||||
)
|
||||
|
||||
self.emit_vec(s_value, v.vectors)
|
||||
self.emit_vec(p_value, v.vectors)
|
||||
if rel.o_entity:
|
||||
self.emit_vec(o_value, v.vectors)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--vector-queue',
|
||||
default=default_vector_queue,
|
||||
help=f'Vector output queue (default: {default_vector_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
3
trustgraph-flow/trustgraph/extract/kg/topics/__init__.py
Normal file
3
trustgraph-flow/trustgraph/extract/kg/topics/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . extract import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/extract/kg/topics/__main__.py
Executable file
7
trustgraph-flow/trustgraph/extract/kg/topics/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . extract import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
134
trustgraph-flow/trustgraph/extract/kg/topics/extract.py
Executable file
134
trustgraph-flow/trustgraph/extract/kg/topics/extract.py
Executable file
|
|
@ -0,0 +1,134 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
|
||||
get entity definitions which are output as graph edges.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import json
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Source, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"output_schema": Triple,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
def to_uri(self, text):
|
||||
|
||||
part = text.replace(" ", "-").lower().encode("utf-8")
|
||||
quoted = urllib.parse.quote(part)
|
||||
uri = TRUSTGRAPH_ENTITIES + quoted
|
||||
|
||||
return uri
|
||||
|
||||
def get_topics(self, chunk):
|
||||
|
||||
return self.prompt.request_topics(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
defs = self.get_topics(chunk)
|
||||
|
||||
for defn in defs:
|
||||
|
||||
s = defn.name
|
||||
o = defn.definition
|
||||
|
||||
if s == "": continue
|
||||
if o == "": continue
|
||||
|
||||
if s is None: continue
|
||||
if o is None: continue
|
||||
|
||||
s_uri = self.to_uri(s)
|
||||
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-completion-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
0
trustgraph-flow/trustgraph/extract/object/__init__.py
Normal file
0
trustgraph-flow/trustgraph/extract/object/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . extract import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/extract/object/row/__main__.py
Executable file
7
trustgraph-flow/trustgraph/extract/object/row/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . extract import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
220
trustgraph-flow/trustgraph/extract/object/row/extract.py
Executable file
220
trustgraph-flow/trustgraph/extract/object/row/extract.py
Executable file
|
|
@ -0,0 +1,220 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts vector+text chunks input, applies analysis to pull
|
||||
out a row of fields. Output as a vector plus object.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source
|
||||
from .... schema import RowSchema, Field
|
||||
from .... schema import chunk_embeddings_ingest_queue, rows_store_queue
|
||||
from .... schema import object_embeddings_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
from .... objects.field import Field as FieldParser
|
||||
from .... objects.object import Schema
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_output_queue = rows_store_queue
|
||||
default_vector_queue = object_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
vector_queue = params.get("vector_queue", default_vector_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"output_schema": Rows,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.vec_prod = self.client.create_producer(
|
||||
topic=vector_queue,
|
||||
schema=JsonSchema(ObjectEmbeddings),
|
||||
)
|
||||
|
||||
__class__.pubsub_metric.info({
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"vector_queue": vector_queue,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings.__name__,
|
||||
"output_schema": Rows.__name__,
|
||||
"vector_schema": ObjectEmbeddings.__name__,
|
||||
})
|
||||
|
||||
flds = __class__.parse_fields(params["field"])
|
||||
|
||||
for fld in flds:
|
||||
print(fld)
|
||||
|
||||
self.primary = None
|
||||
|
||||
for f in flds:
|
||||
if f.primary:
|
||||
if self.primary:
|
||||
raise RuntimeError(
|
||||
"Only one primary key field is supported"
|
||||
)
|
||||
self.primary = f
|
||||
|
||||
if self.primary == None:
|
||||
raise RuntimeError(
|
||||
"Must have exactly one primary key field"
|
||||
)
|
||||
|
||||
self.schema = Schema(
|
||||
name = params["name"],
|
||||
description = params["description"],
|
||||
fields = flds
|
||||
)
|
||||
|
||||
self.row_schema=RowSchema(
|
||||
name=self.schema.name,
|
||||
description=self.schema.description,
|
||||
fields=[
|
||||
Field(
|
||||
name=f.name, type=str(f.type), size=f.size,
|
||||
primary=f.primary, description=f.description,
|
||||
)
|
||||
for f in self.schema.fields
|
||||
]
|
||||
)
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_fields(fields):
|
||||
return [ FieldParser.parse(f) for f in fields ]
|
||||
|
||||
def get_rows(self, chunk):
|
||||
return self.prompt.request_rows(self.schema, chunk)
|
||||
|
||||
def emit_rows(self, source, rows):
|
||||
|
||||
t = Rows(
|
||||
source=source, row_schema=self.row_schema, rows=rows
|
||||
)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_vec(self, source, name, vec, key_name, key):
|
||||
|
||||
r = ObjectEmbeddings(
|
||||
source=source, vectors=vec, name=name, key_name=key_name, id=key
|
||||
)
|
||||
self.vec_prod.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
rows = self.get_rows(chunk)
|
||||
|
||||
self.emit_rows(
|
||||
source=v.source,
|
||||
rows=rows
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
self.emit_vec(
|
||||
source=v.source, vec=v.vectors,
|
||||
name=self.schema.name, key_name=self.primary.name,
|
||||
key=row[self.primary.name]
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
print(row)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--vector-queue',
|
||||
default=default_vector_queue,
|
||||
help=f'Vector output queue (default: {default_vector_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-f', '--field',
|
||||
required=True,
|
||||
action='append',
|
||||
help=f'Field definition, format name:type:size:pri:descriptionn',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-n', '--name',
|
||||
required=True,
|
||||
help=f'Name of row object',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-d', '--description',
|
||||
required=True,
|
||||
help=f'Description of object',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
250
trustgraph-flow/trustgraph/graph_rag.py
Normal file
250
trustgraph-flow/trustgraph/graph_rag.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
|
||||
from . clients.graph_embeddings_client import GraphEmbeddingsClient
|
||||
from . clients.triples_query_client import TriplesQueryClient
|
||||
from . clients.embeddings_client import EmbeddingsClient
|
||||
from . clients.prompt_client import PromptClient
|
||||
|
||||
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from . schema import TriplesQueryRequest, TriplesQueryResponse
|
||||
from . schema import prompt_request_queue
|
||||
from . schema import prompt_response_queue
|
||||
from . schema import embeddings_request_queue
|
||||
from . schema import embeddings_response_queue
|
||||
from . schema import graph_embeddings_request_queue
|
||||
from . schema import graph_embeddings_response_queue
|
||||
from . schema import triples_request_queue
|
||||
from . schema import triples_response_queue
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
||||
|
||||
class GraphRag:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
pr_request_queue=None,
|
||||
pr_response_queue=None,
|
||||
emb_request_queue=None,
|
||||
emb_response_queue=None,
|
||||
ge_request_queue=None,
|
||||
ge_response_queue=None,
|
||||
tpl_request_queue=None,
|
||||
tpl_response_queue=None,
|
||||
verbose=False,
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
max_subgraph_size=3000,
|
||||
module="test",
|
||||
):
|
||||
|
||||
self.verbose=verbose
|
||||
|
||||
if pr_request_queue is None:
|
||||
pr_request_queue = prompt_request_queue
|
||||
|
||||
if pr_response_queue is None:
|
||||
pr_response_queue = prompt_response_queue
|
||||
|
||||
if emb_request_queue is None:
|
||||
emb_request_queue = embeddings_request_queue
|
||||
|
||||
if emb_response_queue is None:
|
||||
emb_response_queue = embeddings_response_queue
|
||||
|
||||
if ge_request_queue is None:
|
||||
ge_request_queue = graph_embeddings_request_queue
|
||||
|
||||
if ge_response_queue is None:
|
||||
ge_response_queue = graph_embeddings_response_queue
|
||||
|
||||
if tpl_request_queue is None:
|
||||
tpl_request_queue = triples_request_queue
|
||||
|
||||
if tpl_response_queue is None:
|
||||
tpl_response_queue = triples_response_queue
|
||||
|
||||
if self.verbose:
|
||||
print("Initialising...", flush=True)
|
||||
|
||||
self.ge_client = GraphEmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
subscriber=module + "-ge",
|
||||
input_queue=ge_request_queue,
|
||||
output_queue=ge_response_queue,
|
||||
)
|
||||
|
||||
self.triples_client = TriplesQueryClient(
|
||||
pulsar_host=pulsar_host,
|
||||
subscriber=module + "-tpl",
|
||||
input_queue=tpl_request_queue,
|
||||
output_queue=tpl_response_queue
|
||||
)
|
||||
|
||||
self.embeddings = EmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=emb_request_queue,
|
||||
output_queue=emb_response_queue,
|
||||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
self.entity_limit=entity_limit
|
||||
self.query_limit=triple_limit
|
||||
self.max_subgraph_size=max_subgraph_size
|
||||
|
||||
self.label_cache = {}
|
||||
|
||||
self.lang = PromptClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber=module + "-prompt",
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = self.embeddings.request(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
def get_entities(self, query):
|
||||
|
||||
vectors = self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get entities...", flush=True)
|
||||
|
||||
entities = self.ge_client.request(
|
||||
vectors, self.entity_limit
|
||||
)
|
||||
|
||||
entities = [
|
||||
e.value
|
||||
for e in entities
|
||||
]
|
||||
|
||||
if self.verbose:
|
||||
print("Entities:", flush=True)
|
||||
for ent in entities:
|
||||
print(" ", ent, flush=True)
|
||||
|
||||
return entities
|
||||
|
||||
def maybe_label(self, e):
|
||||
|
||||
if e in self.label_cache:
|
||||
return self.label_cache[e]
|
||||
|
||||
res = self.triples_client.request(
|
||||
e, LABEL, None, limit=1
|
||||
)
|
||||
|
||||
if len(res) == 0:
|
||||
self.label_cache[e] = e
|
||||
return e
|
||||
|
||||
self.label_cache[e] = res[0].o.value
|
||||
return self.label_cache[e]
|
||||
|
||||
def get_subgraph(self, query):
|
||||
|
||||
entities = self.get_entities(query)
|
||||
|
||||
subgraph = set()
|
||||
|
||||
if self.verbose:
|
||||
print("Get subgraph...", flush=True)
|
||||
|
||||
for e in entities:
|
||||
|
||||
res = self.triples_client.request(
|
||||
e, None, None,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.triples_client.request(
|
||||
None, e, None,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.triples_client.request(
|
||||
None, None, e,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
subgraph = list(subgraph)
|
||||
|
||||
subgraph = subgraph[0:self.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
print("Subgraph:", flush=True)
|
||||
for edge in subgraph:
|
||||
print(" ", str(edge), flush=True)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return subgraph
|
||||
|
||||
def get_labelgraph(self, query):
|
||||
|
||||
subgraph = self.get_subgraph(query)
|
||||
|
||||
sg2 = []
|
||||
|
||||
for edge in subgraph:
|
||||
|
||||
if edge[1] == LABEL:
|
||||
continue
|
||||
|
||||
s = self.maybe_label(edge[0])
|
||||
p = self.maybe_label(edge[1])
|
||||
o = self.maybe_label(edge[2])
|
||||
|
||||
sg2.append((s, p, o))
|
||||
|
||||
return sg2
|
||||
|
||||
def query(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
kg = self.get_labelgraph(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Invoke LLM...", flush=True)
|
||||
print(kg)
|
||||
print(query)
|
||||
|
||||
resp = self.lang.request_kg_prompt(query, kg)
|
||||
|
||||
if self.verbose:
|
||||
print("Done", flush=True)
|
||||
|
||||
return resp
|
||||
|
||||
3
trustgraph-flow/trustgraph/metering/__init__.py
Normal file
3
trustgraph-flow/trustgraph/metering/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . counter import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/metering/__main__.py
Executable file
7
trustgraph-flow/trustgraph/metering/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . counter import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
75
trustgraph-flow/trustgraph/metering/counter.py
Normal file
75
trustgraph-flow/trustgraph/metering/counter.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""
|
||||
Simple token counter for each LLM response.
|
||||
"""
|
||||
|
||||
from prometheus_client import Histogram, Info
|
||||
from . pricelist import price_list
|
||||
|
||||
from .. schema import TextCompletionResponse, Error
|
||||
from .. schema import text_completion_response_queue
|
||||
from .. log_level import LogLevel
|
||||
from .. base import Consumer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
|
||||
class Processor(Consumer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionResponse,
|
||||
}
|
||||
)
|
||||
|
||||
def get_prices(self, prices, modelname):
|
||||
for model in prices["price_list"]:
|
||||
if model["model_name"] == modelname:
|
||||
return model["input_price"], model["output_price"]
|
||||
return None, None # Return None if model is not found
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
modelname = v.model
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling response {id}...", flush=True)
|
||||
|
||||
num_in = v.in_token
|
||||
num_out = v.out_token
|
||||
|
||||
model_input_price, model_output_price = self.get_prices(price_list, modelname)
|
||||
|
||||
if model_input_price == None:
|
||||
cost_per_call = f"Model Not Found in Price list"
|
||||
else:
|
||||
cost_in = num_in * model_input_price
|
||||
cost_out = num_out * model_output_price
|
||||
cost_per_call = round(cost_in + cost_out, 6)
|
||||
|
||||
print(f"Input Tokens: {num_in}", flush=True)
|
||||
print(f"Output Tokens: {num_out}", flush=True)
|
||||
print(f"Cost for call: ${cost_per_call}", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
104
trustgraph-flow/trustgraph/metering/pricelist.py
Normal file
104
trustgraph-flow/trustgraph/metering/pricelist.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
price_list = {
|
||||
"price_list": [
|
||||
{
|
||||
"model_name": "mistral.mistral-large-2407-v1:0",
|
||||
"input_price": 0.000004,
|
||||
"output_price": 0.000012
|
||||
},
|
||||
{
|
||||
"model_name": "meta.llama3-1-405b-instruct-v1:0",
|
||||
"input_price": 0.00000532,
|
||||
"output_price": 0.000016
|
||||
},
|
||||
{
|
||||
"model_name": "mistral.mixtral-8x7b-instruct-v0:1",
|
||||
"input_price": 0.00000045,
|
||||
"output_price": 0.0000007
|
||||
},
|
||||
{
|
||||
"model_name": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"input_price": 0.00000099,
|
||||
"output_price": 0.00000099
|
||||
},
|
||||
{
|
||||
"model_name": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"input_price": 0.00000022,
|
||||
"output_price": 0.00000022
|
||||
},
|
||||
{
|
||||
"model_name": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"input_price": 0.00000025,
|
||||
"output_price": 0.00000125
|
||||
},
|
||||
{
|
||||
"model_name": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"input_price": 0.000003,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "cohere.command-r-plus-v1:0",
|
||||
"input_price": 0.0000030,
|
||||
"output_price": 0.0000150
|
||||
},
|
||||
{
|
||||
"model_name": "ollama",
|
||||
"input_price": 0,
|
||||
"output_price": 0
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-haiku-20240307",
|
||||
"input_price": 0.00000025,
|
||||
"output_price": 0.00000125
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-5-sonnet-20240620",
|
||||
"input_price": 0.000003,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-opus-20240229",
|
||||
"input_price": 0.000015,
|
||||
"output_price": 0.000075
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-sonnet-20240229",
|
||||
"input_price": 0.000003,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "command-r-08-202",
|
||||
"input_price": 0.0000025,
|
||||
"output_price": 0.000010
|
||||
},
|
||||
{
|
||||
"model_name": "c4ai-aya-23-8b",
|
||||
"input_price": 0,
|
||||
"output_price": 0
|
||||
},
|
||||
{
|
||||
"model_name": "llama.cpp",
|
||||
"input_price": 0,
|
||||
"output_price": 0
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o",
|
||||
"input_price": 0.000005,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o-2024-08-06",
|
||||
"input_price": 0.0000025,
|
||||
"output_price": 0.000010
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o-2024-05-13",
|
||||
"input_price": 0.000005,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"input_price": 0.00000015,
|
||||
"output_price": 0.0000006
|
||||
},
|
||||
]
|
||||
}
|
||||
0
trustgraph-flow/trustgraph/model/__init__.py
Normal file
0
trustgraph-flow/trustgraph/model/__init__.py
Normal file
0
trustgraph-flow/trustgraph/model/prompt/__init__.py
Normal file
0
trustgraph-flow/trustgraph/model/prompt/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/prompt/generic/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/prompt/generic/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
176
trustgraph-flow/trustgraph/model/prompt/generic/prompts.py
Normal file
176
trustgraph-flow/trustgraph/model/prompt/generic/prompts.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
|
||||
def to_relationships(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.
|
||||
|
||||
Read the provided text. You will model the text as an information network for a RDF knowledge graph in JSON.
|
||||
|
||||
Information Network Rules:
|
||||
- An information network has subjects connected by predicates to objects.
|
||||
- A subject is a named-entity or a conceptual topic.
|
||||
- One subject can have many predicates and objects.
|
||||
- An object is a property or attribute of a subject.
|
||||
- A subject can be connected by a predicate to another subject.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Obey the information network rules.
|
||||
- Do not return special characters.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of JSON objects with keys "subject", "predicate", "object", and "object-entity".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"subject": string, "predicate": string, "object": string, "object-entity": boolean}}]
|
||||
```
|
||||
|
||||
- The key "object-entity" is TRUE only if the "object" is a subject.
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_topics(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify topics and their definitions in JSON.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Do not respond with special characters.
|
||||
- Return only topics that are concepts and unique to the provided text.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of objects with keys "topic" and "definition".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"topic": string, "definition": string}}]
|
||||
```
|
||||
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_definitions(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify entities and their definitions in JSON.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Do not respond with special characters.
|
||||
- Return only entities that are named-entities such as: people, organizations, physical objects, locations, animals, products, commodotities, or substances.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of objects with keys "entity" and "definition".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"entity": string, "definition": string}}]
|
||||
```
|
||||
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_rows(schema, text):
|
||||
|
||||
field_schema = [
|
||||
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
|
||||
for f in schema.fields
|
||||
]
|
||||
|
||||
field_schema = "\n".join(field_schema)
|
||||
|
||||
schema = f"""Object name: {schema.name}
|
||||
Description: {schema.description}
|
||||
|
||||
Fields:
|
||||
{field_schema}"""
|
||||
|
||||
prompt = f"""<instructions>
|
||||
Study the following text and derive objects which match the schema provided.
|
||||
|
||||
You must output an array of JSON objects for each object you discover
|
||||
which matches the schema. For each object, output a JSON object whose fields
|
||||
carry the name field specified in the schema.
|
||||
</instructions>
|
||||
|
||||
<schema>
|
||||
{schema}
|
||||
</schema>
|
||||
|
||||
<text>
|
||||
{text}
|
||||
</text>
|
||||
|
||||
<requirements>
|
||||
You will respond only with raw JSON format data. Do not provide
|
||||
explanations. Do not add markdown formatting or headers or prefixes.
|
||||
</requirements>"""
|
||||
|
||||
return prompt
|
||||
|
||||
def get_cypher(kg):
|
||||
|
||||
sg2 = []
|
||||
|
||||
for f in kg:
|
||||
|
||||
print(f)
|
||||
|
||||
sg2.append(f"({f.s})-[{f.p}]->({f.o})")
|
||||
|
||||
print(sg2)
|
||||
|
||||
kg = "\n".join(sg2)
|
||||
kg = kg.replace("\\", "-")
|
||||
|
||||
return kg
|
||||
|
||||
def to_kg_query(query, kg):
|
||||
|
||||
cypher = get_cypher(kg)
|
||||
|
||||
prompt=f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
|
||||
|
||||
Here's the knowledge statements:
|
||||
{cypher}
|
||||
|
||||
Use only the provided knowledge statements to respond to the following:
|
||||
{query}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_document_query(query, documents):
|
||||
|
||||
documents = "\n\n".join(documents)
|
||||
|
||||
prompt=f"""Study the following context. Use only the information provided in the context in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
|
||||
|
||||
Here is the context:
|
||||
{documents}
|
||||
|
||||
Use only the provided knowledge statements to respond to the following:
|
||||
{query}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
473
trustgraph-flow/trustgraph/model/prompt/generic/service.py
Executable file
473
trustgraph-flow/trustgraph/model/prompt/generic/service.py
Executable file
|
|
@ -0,0 +1,473 @@
|
|||
"""
|
||||
Language service abstracts prompt engineering from LLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .... schema import Definition, Relationship, Triple
|
||||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... schema import prompt_request_queue, prompt_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships, to_topics
|
||||
from . prompts import to_kg_query, to_document_query, to_rows
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
tc_request_queue = params.get(
|
||||
"text_completion_request_queue", text_completion_request_queue
|
||||
)
|
||||
tc_response_queue = params.get(
|
||||
"text_completion_response_queue", text_completion_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": PromptRequest,
|
||||
"output_schema": PromptResponse,
|
||||
"text_completion_request_queue": tc_request_queue,
|
||||
"text_completion_response_queue": tc_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.llm = LlmClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=tc_request_queue,
|
||||
output_queue=tc_response_queue,
|
||||
pulsar_host = self.pulsar_host
|
||||
)
|
||||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
# If no delimiters, assume the entire output is JSON
|
||||
json_str = text.strip()
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
kind = v.kind
|
||||
|
||||
print(f"Handling kind {kind}...", flush=True)
|
||||
|
||||
if kind == "extract-definitions":
|
||||
|
||||
self.handle_extract_definitions(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-topics":
|
||||
|
||||
self.handle_extract_topics(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-relationships":
|
||||
|
||||
self.handle_extract_relationships(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-rows":
|
||||
|
||||
self.handle_extract_rows(id, v)
|
||||
return
|
||||
|
||||
elif kind == "kg-prompt":
|
||||
|
||||
self.handle_kg_prompt(id, v)
|
||||
return
|
||||
|
||||
elif kind == "document-prompt":
|
||||
|
||||
self.handle_document_prompt(id, v)
|
||||
return
|
||||
|
||||
else:
|
||||
|
||||
print("Invalid kind.", flush=True)
|
||||
return
|
||||
|
||||
def handle_extract_definitions(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_definitions(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["entity"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Definition(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(definitions=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_topics(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_topics(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["topic"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Topic(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(topics=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_relationships(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_relationships(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
|
||||
s = defn["subject"]
|
||||
p = defn["predicate"]
|
||||
o = defn["object"]
|
||||
o_entity = defn["object-entity"]
|
||||
|
||||
if s == "": continue
|
||||
if s is None: continue
|
||||
|
||||
if p == "": continue
|
||||
if p is None: continue
|
||||
|
||||
if o == "": continue
|
||||
if o is None: continue
|
||||
|
||||
if o_entity == "" or o_entity is None:
|
||||
o_entity = False
|
||||
|
||||
output.append(
|
||||
Relationship(
|
||||
s = s,
|
||||
p = p,
|
||||
o = o,
|
||||
o_entity = o_entity,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("relationship fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(relationships=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_rows(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
fields = v.row_schema.fields
|
||||
|
||||
prompt = to_rows(v.row_schema, v.chunk)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
objs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
objs = []
|
||||
|
||||
output = []
|
||||
|
||||
for obj in objs:
|
||||
|
||||
try:
|
||||
|
||||
row = {}
|
||||
|
||||
for f in fields:
|
||||
|
||||
if f.name not in obj:
|
||||
print(f"Object ignored, missing field {f.name}")
|
||||
row = {}
|
||||
break
|
||||
|
||||
row[f.name] = obj[f.name]
|
||||
|
||||
if row == {}:
|
||||
continue
|
||||
|
||||
output.append(row)
|
||||
|
||||
except Exception as e:
|
||||
print("row fields missing, ignored", flush=True)
|
||||
|
||||
for row in output:
|
||||
print(row)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(rows=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_kg_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_kg_query(v.query, v.kg)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_document_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_document_query(v.query, v.documents)
|
||||
|
||||
print("prompt")
|
||||
print(prompt)
|
||||
|
||||
print("Call LLM...")
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-request-queue',
|
||||
default=text_completion_request_queue,
|
||||
help=f'Text completion request queue (default: {text_completion_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-response-queue',
|
||||
default=text_completion_response_queue,
|
||||
help=f'Text completion response queue (default: {text_completion_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/prompt/template/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/prompt/template/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
47
trustgraph-flow/trustgraph/model/prompt/template/prompts.py
Normal file
47
trustgraph-flow/trustgraph/model/prompt/template/prompts.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
|
||||
def to_relationships(template, text):
|
||||
return template.format(text=text)
|
||||
|
||||
def to_definitions(template, text):
|
||||
return template.format(text=text)
|
||||
|
||||
def to_topics(template, text):
|
||||
return template.format(text=text)
|
||||
|
||||
def to_rows(template, schema, text):
|
||||
|
||||
field_schema = [
|
||||
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
|
||||
for f in schema.fields
|
||||
]
|
||||
|
||||
field_schema = "\n".join(field_schema)
|
||||
|
||||
return template.format(schema=schema, text=text)
|
||||
|
||||
schema = f"""Object name: {schema.name}
|
||||
Description: {schema.description}
|
||||
|
||||
Fields:
|
||||
{schema}"""
|
||||
|
||||
prompt = f""""""
|
||||
|
||||
return prompt
|
||||
|
||||
def get_cypher(kg):
|
||||
sg2 = []
|
||||
for f in kg:
|
||||
sg2.append(f"({f.s})-[{f.p}]->({f.o})")
|
||||
kg = "\n".join(sg2)
|
||||
kg = kg.replace("\\", "-")
|
||||
return kg
|
||||
|
||||
def to_kg_query(template, query, kg):
|
||||
cypher = get_cypher(kg)
|
||||
return template.format(query=query, graph=cypher)
|
||||
|
||||
def to_document_query(template, query, docs):
|
||||
docs = "\n\n".join(docs)
|
||||
return template.format(query=query, documents=docs)
|
||||
|
||||
523
trustgraph-flow/trustgraph/model/prompt/template/service.py
Executable file
523
trustgraph-flow/trustgraph/model/prompt/template/service.py
Executable file
|
|
@ -0,0 +1,523 @@
|
|||
|
||||
"""
|
||||
Language service abstracts prompt engineering from LLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .... schema import Definition, Relationship, Triple
|
||||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... schema import prompt_request_queue, prompt_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships, to_rows
|
||||
from . prompts import to_kg_query, to_document_query, to_topics
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
tc_request_queue = params.get(
|
||||
"text_completion_request_queue", text_completion_request_queue
|
||||
)
|
||||
tc_response_queue = params.get(
|
||||
"text_completion_response_queue", text_completion_response_queue
|
||||
)
|
||||
definition_template = params.get("definition_template")
|
||||
relationship_template = params.get("relationship_template")
|
||||
topic_template = params.get("topic_template")
|
||||
rows_template = params.get("rows_template")
|
||||
knowledge_query_template = params.get("knowledge_query_template")
|
||||
document_query_template = params.get("document_query_template")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": PromptRequest,
|
||||
"output_schema": PromptResponse,
|
||||
"text_completion_request_queue": tc_request_queue,
|
||||
"text_completion_response_queue": tc_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.llm = LlmClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=tc_request_queue,
|
||||
output_queue=tc_response_queue,
|
||||
pulsar_host = self.pulsar_host
|
||||
)
|
||||
|
||||
self.definition_template = definition_template
|
||||
self.topic_template = topic_template
|
||||
self.relationship_template = relationship_template
|
||||
self.rows_template = rows_template
|
||||
self.knowledge_query_template = knowledge_query_template
|
||||
self.document_query_template = document_query_template
|
||||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
# If no delimiters, assume the entire output is JSON
|
||||
json_str = text.strip()
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
kind = v.kind
|
||||
|
||||
print(f"Handling kind {kind}...", flush=True)
|
||||
|
||||
if kind == "extract-definitions":
|
||||
|
||||
self.handle_extract_definitions(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-topics":
|
||||
|
||||
self.handle_extract_topics(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-relationships":
|
||||
|
||||
self.handle_extract_relationships(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-rows":
|
||||
|
||||
self.handle_extract_rows(id, v)
|
||||
return
|
||||
|
||||
elif kind == "kg-prompt":
|
||||
|
||||
self.handle_kg_prompt(id, v)
|
||||
return
|
||||
|
||||
elif kind == "document-prompt":
|
||||
|
||||
self.handle_document_prompt(id, v)
|
||||
return
|
||||
|
||||
else:
|
||||
|
||||
print("Invalid kind.", flush=True)
|
||||
return
|
||||
|
||||
def handle_extract_definitions(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_definitions(self.definition_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["entity"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Definition(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(definitions=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_topics(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_topics(self.topic_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["topic"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Topic(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(topics=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
|
||||
def handle_extract_relationships(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_relationships(self.relationship_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
|
||||
s = defn["subject"]
|
||||
p = defn["predicate"]
|
||||
o = defn["object"]
|
||||
o_entity = defn["object-entity"]
|
||||
|
||||
if s == "": continue
|
||||
if s is None: continue
|
||||
|
||||
if p == "": continue
|
||||
if p is None: continue
|
||||
|
||||
if o == "": continue
|
||||
if o is None: continue
|
||||
|
||||
if o_entity == "" or o_entity is None:
|
||||
o_entity = False
|
||||
|
||||
output.append(
|
||||
Relationship(
|
||||
s = s,
|
||||
p = p,
|
||||
o = o,
|
||||
o_entity = o_entity,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("relationship fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(relationships=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_rows(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
fields = v.row_schema.fields
|
||||
|
||||
prompt = to_rows(self.rows_template, v.row_schema, v.chunk)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
objs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
objs = []
|
||||
|
||||
output = []
|
||||
|
||||
for obj in objs:
|
||||
|
||||
try:
|
||||
|
||||
row = {}
|
||||
|
||||
for f in fields:
|
||||
|
||||
if f.name not in obj:
|
||||
print(f"Object ignored, missing field {f.name}")
|
||||
row = {}
|
||||
break
|
||||
|
||||
row[f.name] = obj[f.name]
|
||||
|
||||
if row == {}:
|
||||
continue
|
||||
|
||||
output.append(row)
|
||||
|
||||
except Exception as e:
|
||||
print("row fields missing, ignored", flush=True)
|
||||
|
||||
for row in output:
|
||||
print(row)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(rows=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_kg_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_kg_query(self.knowledge_query_template, v.query, v.kg)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_document_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_document_query(
|
||||
self.document_query_template, v.query, v.documents
|
||||
)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-request-queue',
|
||||
default=text_completion_request_queue,
|
||||
help=f'Text completion request queue (default: {text_completion_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-response-queue',
|
||||
default=text_completion_response_queue,
|
||||
help=f'Text completion response queue (default: {text_completion_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--definition-template',
|
||||
required=True,
|
||||
help=f'Definition extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--topic-template',
|
||||
required=True,
|
||||
help=f'Topic extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rows-template',
|
||||
required=True,
|
||||
help=f'Rows extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--relationship-template',
|
||||
required=True,
|
||||
help=f'Relationship extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--knowledge-query-template',
|
||||
required=True,
|
||||
help=f'Knowledge query template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--document-query-template',
|
||||
required=True,
|
||||
help=f'Document query template',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/azure/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/azure/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
226
trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
Executable file
226
trustgraph-flow/trustgraph/model/text_completion/azure/llm.py
Executable file
|
|
@ -0,0 +1,226 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using the Azure
|
||||
serverless endpoint service. Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_temperature = 0.0
|
||||
default_max_output = 4192
|
||||
default_model = "AzureAI"
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
endpoint = params.get("endpoint")
|
||||
token = params.get("token")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
model = default_model
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.token = token
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
|
||||
def build_prompt(self, system, content):
|
||||
|
||||
data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system", "content": system
|
||||
},
|
||||
{
|
||||
"role": "user", "content": content
|
||||
}
|
||||
],
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 1
|
||||
}
|
||||
|
||||
body = json.dumps(data)
|
||||
|
||||
return body
|
||||
|
||||
def call_llm(self, body):
|
||||
|
||||
url = self.endpoint
|
||||
|
||||
# Replace this with the primary/secondary key, AMLToken, or
|
||||
# Microsoft Entra ID token for the endpoint
|
||||
api_key = self.token
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {api_key}'
|
||||
}
|
||||
|
||||
resp = requests.post(url, data=body, headers=headers)
|
||||
|
||||
if resp.status_code == 429:
|
||||
raise TooManyRequests()
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError("LLM failure")
|
||||
|
||||
result = resp.json()
|
||||
|
||||
return result
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
try:
|
||||
|
||||
prompt = self.build_prompt(
|
||||
"You are a helpful chatbot",
|
||||
v.prompt
|
||||
)
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
response = self.call_llm(prompt)
|
||||
|
||||
resp = response['choices'][0]['message']['content']
|
||||
inputtokens = response['usage']['prompt_tokens']
|
||||
outputtokens = response['usage']['completion_tokens']
|
||||
|
||||
print(resp, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
except TooManyRequests:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-e', '--endpoint',
|
||||
help=f'LLM model endpoint'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--token',
|
||||
help=f'LLM model token'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/claude/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/claude/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
199
trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
Executable file
199
trustgraph-flow/trustgraph/model/text_completion/claude/llm.py
Executable file
|
|
@ -0,0 +1,199 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using Claude.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import anthropic
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'claude-3-5-sonnet-20240620'
|
||||
default_temperature = 0.0
|
||||
default_max_output = 8192
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.claude = anthropic.Anthropic(api_key=api_key)
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
prompt = v.prompt
|
||||
|
||||
try:
|
||||
|
||||
# FIXME: Rate limits?
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
|
||||
response = message = self.claude.messages.create(
|
||||
model=self.model,
|
||||
max_tokens=self.max_output,
|
||||
temperature=self.temperature,
|
||||
system = "You are a helpful chatbot.",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
resp = response.content[0].text
|
||||
inputtokens = response.usage.input_tokens
|
||||
outputtokens = response.usage.output_tokens
|
||||
print(resp, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
|
||||
self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
# FIXME: Wrong exception, don't know what this LLM throws
|
||||
# for a rate limit
|
||||
except TooManyRequests:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default="claude-3-5-sonnet-20240620",
|
||||
help=f'LLM model (default: claude-3-5-sonnet-20240620)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
help=f'Claude API key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
7
trustgraph-flow/trustgraph/model/text_completion/cohere/__main__.py
Executable file
7
trustgraph-flow/trustgraph/model/text_completion/cohere/__main__.py
Executable file
|
|
@ -0,0 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue