Feature/subpackages (#80)

* Renaming what will become the core package

* Tweaking to get  package build working

* Fix metering merge

* Rename to core directory

* Bump version.  Use namespace searching for packaging trustgraph-core

* Change references to trustgraph-core

* Forming embeddings-hf package

* Reference modules in core package.

* Build both packages to one container, bump version

* Update YAMLs
This commit is contained in:
cybermaggedon 2024-09-30 14:00:29 +01:00 committed by GitHub
parent 14d79ef9f1
commit f081933217
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
303 changed files with 681 additions and 624 deletions

View file

View file

@ -1,6 +0,0 @@
from . base_processor import BaseProcessor
from . consumer import Consumer
from . producer import Producer
from . consumer_producer import ConsumerProducer

View file

@ -1,119 +0,0 @@
import os
import argparse
import pulsar
import _pulsar
import time
from prometheus_client import start_http_server, Info
from .. log_level import LogLevel
class BaseProcessor:
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
def __init__(self, **params):
self.client = None
if not hasattr(__class__, "params_metric"):
__class__.params_metric = Info(
'params', 'Parameters configuration'
)
# FIXME: Maybe outputs information it should not
__class__.params_metric.info({
k: str(params[k])
for k in params
})
pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
log_level = params.get("log_level", LogLevel.INFO)
self.pulsar_host = pulsar_host
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
)
def __del__(self):
if self.client:
self.client.close()
@staticmethod
def add_args(parser):
parser.add_argument(
'-p', '--pulsar-host',
default=__class__.default_pulsar_host,
help=f'Pulsar host (default: {__class__.default_pulsar_host})',
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.INFO,
choices=list(LogLevel),
help=f'Output queue (default: info)'
)
parser.add_argument(
'--metrics',
action=argparse.BooleanOptionalAction,
default=True,
help=f'Metrics enabled (default: true)',
)
parser.add_argument(
'-P', '--metrics-port',
type=int,
default=8000,
help=f'Pulsar host (default: 8000)',
)
def run(self):
raise RuntimeError("Something should have implemented the run method")
@classmethod
def start(cls, prog, doc):
parser = argparse.ArgumentParser(
prog=prog,
description=doc
)
cls.add_args(parser)
args = parser.parse_args()
args = vars(args)
print(args)
if args["metrics"]:
start_http_server(args["metrics_port"])
while True:
try:
p = cls(**args)
p.run()
except KeyboardInterrupt:
print("Keyboard interrupt.")
return
except _pulsar.Interrupted:
print("Pulsar Interrupted.")
return
except Exception as e:
print(type(e))
print("Exception:", e, flush=True)
print("Will retry...", flush=True)
time.sleep(4)

View file

@ -1,107 +0,0 @@
from pulsar.schema import JsonSchema
from prometheus_client import Histogram, Info, Counter, Enum
import time
from . base_processor import BaseProcessor
from .. exceptions import TooManyRequests
class Consumer(BaseProcessor):
def __init__(self, **params):
if not hasattr(__class__, "state_metric"):
__class__.state_metric = Enum(
'processor_state', 'Processor state',
states=['starting', 'running', 'stopped']
)
__class__.state_metric.state('starting')
__class__.state_metric.state('starting')
super(Consumer, self).__init__(**params)
input_queue = params.get("input_queue")
subscriber = params.get("subscriber")
input_schema = params.get("input_schema")
if input_schema == None:
raise RuntimeError("input_schema must be specified")
if not hasattr(__class__, "request_metric"):
__class__.request_metric = Histogram(
'request_latency', 'Request latency (seconds)'
)
if not hasattr(__class__, "pubsub_metric"):
__class__.pubsub_metric = Info(
'pubsub', 'Pub/sub configuration'
)
if not hasattr(__class__, "processing_metric"):
__class__.processing_metric = Counter(
'processing_count', 'Processing count', ["status"]
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": input_schema.__name__,
})
self.consumer = self.client.subscribe(
input_queue, subscriber,
schema=JsonSchema(input_schema),
)
def run(self):
__class__.state_metric.state('running')
while True:
msg = self.consumer.receive()
try:
with __class__.request_metric.time():
self.handle(msg)
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
__class__.processing_metric.labels(status="success").inc()
except TooManyRequests:
self.consumer.negative_acknowledge(msg)
print("TooManyRequests: will retry")
__class__.processing_metric.labels(status="rate-limit").inc()
time.sleep(5)
continue
except Exception as e:
print("Exception:", e, flush=True)
# Message failed to be processed
self.consumer.negative_acknowledge(msg)
__class__.processing_metric.labels(status="error").inc()
@staticmethod
def add_args(parser, default_input_queue, default_subscriber):
BaseProcessor.add_args(parser)
parser.add_argument(
'-i', '--input-queue',
default=default_input_queue,
help=f'Input queue (default: {default_input_queue})'
)
parser.add_argument(
'-s', '--subscriber',
default=default_subscriber,
help=f'Queue subscriber name (default: {default_subscriber})'
)

View file

@ -1,139 +0,0 @@
from pulsar.schema import JsonSchema
from prometheus_client import Histogram, Info, Counter, Enum
import time
from . base_processor import BaseProcessor
from .. exceptions import TooManyRequests
# FIXME: Derive from consumer? And producer?
class ConsumerProducer(BaseProcessor):
def __init__(self, **params):
if not hasattr(__class__, "state_metric"):
__class__.state_metric = Enum(
'processor_state', 'Processor state',
states=['starting', 'running', 'stopped']
)
__class__.state_metric.state('starting')
__class__.state_metric.state('starting')
input_queue = params.get("input_queue")
output_queue = params.get("output_queue")
subscriber = params.get("subscriber")
input_schema = params.get("input_schema")
output_schema = params.get("output_schema")
if not hasattr(__class__, "request_metric"):
__class__.request_metric = Histogram(
'request_latency', 'Request latency (seconds)'
)
if not hasattr(__class__, "output_metric"):
__class__.output_metric = Counter(
'output_count', 'Output items created'
)
if not hasattr(__class__, "pubsub_metric"):
__class__.pubsub_metric = Info(
'pubsub', 'Pub/sub configuration'
)
if not hasattr(__class__, "processing_metric"):
__class__.processing_metric = Counter(
'processing_count', 'Processing count', ["status"]
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": input_schema.__name__,
"output_schema": output_schema.__name__,
})
super(ConsumerProducer, self).__init__(**params)
if input_schema == None:
raise RuntimeError("input_schema must be specified")
if output_schema == None:
raise RuntimeError("output_schema must be specified")
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(output_schema),
)
self.consumer = self.client.subscribe(
input_queue, subscriber,
schema=JsonSchema(input_schema),
)
def run(self):
__class__.state_metric.state('running')
while True:
msg = self.consumer.receive()
try:
with __class__.request_metric.time():
resp = self.handle(msg)
# Acknowledge successful processing of the message
self.consumer.acknowledge(msg)
__class__.processing_metric.labels(status="success").inc()
except TooManyRequests:
self.consumer.negative_acknowledge(msg)
print("TooManyRequests: will retry")
__class__.processing_metric.labels(status="rate-limit").inc()
time.sleep(5)
continue
except Exception as e:
print("Exception:", e, flush=True)
# Message failed to be processed
self.consumer.negative_acknowledge(msg)
__class__.processing_metric.labels(status="error").inc()
def send(self, msg, properties={}):
self.producer.send(msg, properties)
__class__.output_metric.inc()
@staticmethod
def add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
):
BaseProcessor.add_args(parser)
parser.add_argument(
'-i', '--input-queue',
default=default_input_queue,
help=f'Input queue (default: {default_input_queue})'
)
parser.add_argument(
'-s', '--subscriber',
default=default_subscriber,
help=f'Queue subscriber name (default: {default_subscriber})'
)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)

View file

@ -1,55 +0,0 @@
from pulsar.schema import JsonSchema
from prometheus_client import Info, Counter
from . base_processor import BaseProcessor
class Producer(BaseProcessor):
def __init__(self, **params):
output_queue = params.get("output_queue")
output_schema = params.get("output_schema")
if not hasattr(__class__, "output_metric"):
__class__.output_metric = Counter(
'output_count', 'Output items created'
)
if not hasattr(__class__, "pubsub_metric"):
__class__.pubsub_metric = Info(
'pubsub', 'Pub/sub configuration'
)
__class__.pubsub_metric.info({
"output_queue": output_queue,
"output_schema": output_schema.__name__,
})
super(Producer, self).__init__(**params)
if output_schema == None:
raise RuntimeError("output_schema must be specified")
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(output_schema),
)
def send(self, msg, properties={}):
self.producer.send(msg, properties)
__class__.output_metric.inc()
@staticmethod
def add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
):
BaseProcessor.add_args(parser)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)

View file

@ -1,3 +0,0 @@
from . chunker import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . chunker import run
if __name__ == '__main__':
run()

View file

@ -1,108 +0,0 @@
"""
Simple decoder, accepts text documents on input, outputs chunks from the
as text as separate output objects.
"""
from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Source
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_ingest_queue
default_output_queue = chunk_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
chunk_size = params.get("chunk_size", 2000)
chunk_overlap = params.get("chunk_overlap", 100)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextDocument,
"output_schema": Chunk,
}
)
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000]
)
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
is_separator_regex=False,
)
def handle(self, msg):
v = msg.value()
print(f"Chunking {v.source.id}...", flush=True)
texts = self.text_splitter.create_documents(
[v.text.decode("utf-8")]
)
for ix, chunk in enumerate(texts):
id = v.source.id + "-c" + str(ix)
r = Chunk(
source=Source(
source=v.source.source,
id=id,
title=v.source.title
),
chunk=chunk.page_content.encode("utf-8"),
)
__class__.chunk_metric.observe(len(chunk.page_content))
self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-z', '--chunk-size',
type=int,
default=2000,
help=f'Chunk size (default: 2000)'
)
parser.add_argument(
'-v', '--chunk-overlap',
type=int,
default=100,
help=f'Chunk overlap (default: 100)'
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . chunker import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . chunker import run
if __name__ == '__main__':
run()

View file

@ -1,107 +0,0 @@
"""
Simple decoder, accepts text documents on input, outputs chunks from the
as text as separate output objects.
"""
from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Source
from ... schema import text_ingest_queue, chunk_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_ingest_queue
default_output_queue = chunk_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
chunk_size = params.get("chunk_size", 250)
chunk_overlap = params.get("chunk_overlap", 15)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextDocument,
"output_schema": Chunk,
}
)
if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size',
buckets=[100, 160, 250, 400, 650, 1000, 1600,
2500, 4000, 6400, 10000, 16000]
)
self.text_splitter = TokenTextSplitter(
encoding_name="cl100k_base",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
def handle(self, msg):
v = msg.value()
print(f"Chunking {v.source.id}...", flush=True)
texts = self.text_splitter.create_documents(
[v.text.decode("utf-8")]
)
for ix, chunk in enumerate(texts):
id = v.source.id + "-c" + str(ix)
r = Chunk(
source=Source(
source=v.source.source,
id=id,
title=v.source.title
),
chunk=chunk.page_content.encode("utf-8"),
)
__class__.chunk_metric.observe(len(chunk.page_content))
self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-z', '--chunk-size',
type=int,
default=250,
help=f'Chunk size (default: 250)'
)
parser.add_argument(
'-v', '--chunk-overlap',
type=int,
default=15,
help=f'Chunk overlap (default: 15)'
)
def run():
Processor.start(module, __doc__)

View file

@ -1,125 +0,0 @@
import pulsar
import _pulsar
import hashlib
import uuid
import time
from pulsar.schema import JsonSchema
from .. exceptions import *
# Default timeout for a request/response. In seconds.
DEFAULT_TIMEOUT=300
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class BaseClient:
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
input_schema=None,
output_schema=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None: raise RuntimeError("Need input_queue")
if output_queue == None: raise RuntimeError("Need output_queue")
if input_schema == None: raise RuntimeError("Need input_schema")
if output_schema == None: raise RuntimeError("Need output_schema")
if subscriber == None:
subscriber = str(uuid.uuid4())
self.client = pulsar.Client(
pulsar_host,
logger=pulsar.ConsoleLogger(log_level),
)
self.producer = self.client.create_producer(
topic=input_queue,
schema=JsonSchema(input_schema),
chunking_enabled=True,
)
self.consumer = self.client.subscribe(
output_queue, subscriber,
schema=JsonSchema(output_schema),
)
self.input_schema = input_schema
self.output_schema = output_schema
def call(self, **args):
timeout = args.get("timeout", DEFAULT_TIMEOUT)
if "timeout" in args:
del args["timeout"]
id = str(uuid.uuid4())
r = self.input_schema(**args)
end_time = time.time() + timeout
self.producer.send(r, properties={ "id": id })
while time.time() < end_time:
try:
msg = self.consumer.receive(timeout_millis=2500)
except pulsar.exceptions.Timeout:
continue
mid = msg.properties()["id"]
if mid == id:
value = msg.value()
if value.error:
self.consumer.acknowledge(msg)
if value.error.type == "llm-error":
raise LlmError(value.error.message)
elif value.error.type == "too-many-requests":
raise TooManyRequests(value.error.message)
elif value.error.type == "ParseError":
raise ParseError(value.error.message)
else:
raise RuntimeError(
f"{value.error.type}: {value.error.message}"
)
resp = msg.value()
self.consumer.acknowledge(msg)
return resp
# Ignore messages with wrong ID
self.consumer.acknowledge(msg)
raise TimeoutError("Timed out waiting for response")
def __del__(self):
if hasattr(self, "consumer"):
self.consumer.close()
if hasattr(self, "producer"):
self.producer.flush()
self.producer.close()
self.client.close()

View file

@ -1,45 +0,0 @@
import _pulsar
from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from .. schema import document_embeddings_request_queue
from .. schema import document_embeddings_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class DocumentEmbeddingsClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = document_embeddings_request_queue
if output_queue == None:
output_queue = document_embeddings_response_queue
super(DocumentEmbeddingsClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=DocumentEmbeddingsRequest,
output_schema=DocumentEmbeddingsResponse,
)
def request(self, vectors, limit=10, timeout=300):
return self.call(
vectors=vectors, limit=limit, timeout=timeout
).documents

View file

@ -1,46 +0,0 @@
import _pulsar
from .. schema import DocumentRagQuery, DocumentRagResponse
from .. schema import document_rag_request_queue, document_rag_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class DocumentRagClient(BaseClient):
def __init__(
self,
log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = document_rag_request_queue
if output_queue == None:
output_queue = document_rag_response_queue
super(DocumentRagClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=DocumentRagQuery,
output_schema=DocumentRagResponse,
)
def request(self, query, timeout=500):
return self.call(
query=query, timeout=timeout
).response

View file

@ -1,44 +0,0 @@
from pulsar.schema import JsonSchema
from .. schema import EmbeddingsRequest, EmbeddingsResponse
from .. schema import embeddings_request_queue, embeddings_response_queue
from . base import BaseClient
import _pulsar
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class EmbeddingsClient(BaseClient):
def __init__(
self, log_level=ERROR,
input_queue=None,
output_queue=None,
subscriber=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue=embeddings_request_queue
if output_queue == None:
output_queue=embeddings_response_queue
super(EmbeddingsClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=EmbeddingsRequest,
output_schema=EmbeddingsResponse,
)
def request(self, text, timeout=300):
return self.call(text=text, timeout=timeout).vectors

View file

@ -1,45 +0,0 @@
import _pulsar
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from .. schema import graph_embeddings_request_queue
from .. schema import graph_embeddings_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class GraphEmbeddingsClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = graph_embeddings_request_queue
if output_queue == None:
output_queue = graph_embeddings_response_queue
super(GraphEmbeddingsClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=GraphEmbeddingsRequest,
output_schema=GraphEmbeddingsResponse,
)
def request(self, vectors, limit=10, timeout=300):
return self.call(
vectors=vectors, limit=limit, timeout=timeout
).entities

View file

@ -1,46 +0,0 @@
import _pulsar
from .. schema import GraphRagQuery, GraphRagResponse
from .. schema import graph_rag_request_queue, graph_rag_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class GraphRagClient(BaseClient):
def __init__(
self,
log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = graph_rag_request_queue
if output_queue == None:
output_queue = graph_rag_response_queue
super(GraphRagClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=GraphRagQuery,
output_schema=GraphRagResponse,
)
def request(self, query, timeout=500):
return self.call(
query=query, timeout=timeout
).response

View file

@ -1,40 +0,0 @@
import _pulsar
from .. schema import TextCompletionRequest, TextCompletionResponse
from .. schema import text_completion_request_queue
from .. schema import text_completion_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class LlmClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue is None: input_queue = text_completion_request_queue
if output_queue is None: output_queue = text_completion_response_queue
super(LlmClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=TextCompletionRequest,
output_schema=TextCompletionResponse,
)
def request(self, prompt, timeout=300):
return self.call(prompt=prompt, timeout=timeout).response

View file

@ -1,100 +0,0 @@
import _pulsar
from .. schema import PromptRequest, PromptResponse, Fact, RowSchema, Field
from .. schema import prompt_request_queue
from .. schema import prompt_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class PromptClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = prompt_request_queue
if output_queue == None:
output_queue = prompt_response_queue
super(PromptClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=PromptRequest,
output_schema=PromptResponse,
)
def request_definitions(self, chunk, timeout=300):
return self.call(
kind="extract-definitions", chunk=chunk,
timeout=timeout
).definitions
def request_topics(self, chunk, timeout=300):
return self.call(
kind="extract-topics", chunk=chunk,
timeout=timeout
).topics
def request_relationships(self, chunk, timeout=300):
return self.call(
kind="extract-relationships", chunk=chunk,
timeout=timeout
).relationships
def request_rows(self, schema, chunk, timeout=300):
return self.call(
kind="extract-rows", chunk=chunk,
row_schema=RowSchema(
name=schema.name,
description=schema.description,
fields=[
Field(
name=f.name, type=str(f.type), size=f.size,
primary=f.primary, description=f.description,
)
for f in schema.fields
]
),
timeout=timeout
).rows
def request_kg_prompt(self, query, kg, timeout=300):
return self.call(
kind="kg-prompt",
query=query,
kg=[
Fact(s=v[0], p=v[1], o=v[2])
for v in kg
],
timeout=timeout
).answer
def request_document_prompt(self, query, documents, timeout=300):
return self.call(
kind="document-prompt",
query=query,
documents=documents,
timeout=timeout
).answer

View file

@ -1,59 +0,0 @@
#!/usr/bin/env python3
import _pulsar
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
from .. schema import triples_request_queue
from .. schema import triples_response_queue
from . base import BaseClient
# Ugly
ERROR=_pulsar.LoggerLevel.Error
WARN=_pulsar.LoggerLevel.Warn
INFO=_pulsar.LoggerLevel.Info
DEBUG=_pulsar.LoggerLevel.Debug
class TriplesQueryClient(BaseClient):
def __init__(
self, log_level=ERROR,
subscriber=None,
input_queue=None,
output_queue=None,
pulsar_host="pulsar://pulsar:6650",
):
if input_queue == None:
input_queue = triples_request_queue
if output_queue == None:
output_queue = triples_response_queue
super(TriplesQueryClient, self).__init__(
log_level=log_level,
subscriber=subscriber,
input_queue=input_queue,
output_queue=output_queue,
pulsar_host=pulsar_host,
input_schema=TriplesQueryRequest,
output_schema=TriplesQueryResponse,
)
def create_value(self, ent):
if ent == None: return None
if ent.startswith("http://") or ent.startswith("https://"):
return Value(value=ent, is_uri=True)
return Value(value=ent, is_uri=False)
def request(self, s, p, o, limit=10, timeout=60):
return self.call(
s=self.create_value(s),
p=self.create_value(p),
o=self.create_value(o),
limit=limit,
timeout=timeout,
).triples

View file

@ -1,3 +0,0 @@
from . pdf_decoder import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . pdf_decoder import run
if __name__ == '__main__':
run()

View file

@ -1,87 +0,0 @@
"""
Simple decoder, accepts PDF documents on input, outputs pages from the
PDF document as text as separate output objects.
"""
import tempfile
import base64
from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Source
from ... schema import document_ingest_queue, text_ingest_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = document_ingest_queue
default_output_queue = text_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": Document,
"output_schema": TextDocument,
}
)
print("PDF inited")
def handle(self, msg):
print("PDF message received")
v = msg.value()
print(f"Decoding {v.source.id}...", flush=True)
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
fp.write(base64.b64decode(v.data))
fp.close()
with open(fp.name, mode='rb') as f:
loader = PyPDFLoader(fp.name)
pages = loader.load()
for ix, page in enumerate(pages):
id = v.source.id + "-p" + str(ix)
r = TextDocument(
source=Source(
source=v.source.source,
title=v.source.title,
id=id,
),
text=page.page_content.encode("utf-8"),
)
self.send(r)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
def run():
Processor.start(module, __doc__)

View file

@ -1,108 +0,0 @@
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
class TrustGraph:
def __init__(self, hosts=None):
if hosts is None:
hosts = ["localhost"]
self.cluster = Cluster(hosts)
self.session = self.cluster.connect()
self.init()
def clear(self):
self.session.execute("""
drop keyspace if exists trustgraph;
""");
self.init()
def init(self):
self.session.execute("""
create keyspace if not exists trustgraph
with replication = {
'class' : 'SimpleStrategy',
'replication_factor' : 1
};
""");
self.session.set_keyspace('trustgraph')
self.session.execute("""
create table if not exists triples (
s text,
p text,
o text,
PRIMARY KEY (s, p, o)
);
""");
self.session.execute("""
create index if not exists triples_p
ON triples (p);
""");
self.session.execute("""
create index if not exists triples_o
ON triples (o);
""");
def insert(self, s, p, o):
self.session.execute(
"insert into triples (s, p, o) values (%s, %s, %s)",
(s, p, o)
)
def get_all(self, limit=50):
return self.session.execute(
f"select s, p, o from triples limit {limit}"
)
def get_s(self, s, limit=10):
return self.session.execute(
f"select p, o from triples where s = %s limit {limit}",
(s,)
)
def get_p(self, p, limit=10):
return self.session.execute(
f"select s, o from triples where p = %s limit {limit}",
(p,)
)
def get_o(self, o, limit=10):
return self.session.execute(
f"select s, p from triples where o = %s limit {limit}",
(o,)
)
def get_sp(self, s, p, limit=10):
return self.session.execute(
f"select o from triples where s = %s and p = %s limit {limit}",
(s, p)
)
def get_po(self, p, o, limit=10):
return self.session.execute(
f"select s from triples where p = %s and o = %s allow filtering limit {limit}",
(p, o)
)
def get_os(self, o, s, limit=10):
return self.session.execute(
f"select p from triples where o = %s and s = %s limit {limit}",
(o, s)
)
def get_spo(self, s, p, o, limit=10):
return self.session.execute(
f"""select s as x from triples where s = %s and p = %s and o = %s limit {limit}""",
(s, p, o)
)

View file

@ -1,138 +0,0 @@
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
class DocVectors:
def __init__(self, uri="http://localhost:19530", prefix='doc'):
self.client = MilvusClient(uri=uri)
# Strategy is to create collections per dimension. Probably only
# going to be using 1 anyway, but that means we don't need to
# hard-code the dimension anywhere, and no big deal if more than
# one are created.
self.collections = {}
self.prefix = prefix
# Time between reloads
self.reload_time = 90
# Next time to reload - this forces a reload at next window
self.next_reload = time.time() + self.reload_time
print("Reload at", self.next_reload)
def init_collection(self, dimension):
collection_name = self.prefix + "_" + str(dimension)
pkey_field = FieldSchema(
name="id",
dtype=DataType.INT64,
is_primary=True,
auto_id=True,
)
vec_field = FieldSchema(
name="vector",
dtype=DataType.FLOAT_VECTOR,
dim=dimension,
)
doc_field = FieldSchema(
name="doc",
dtype=DataType.VARCHAR,
max_length=65535,
)
schema = CollectionSchema(
fields = [pkey_field, vec_field, doc_field],
description = "Document embedding schema",
)
self.client.create_collection(
collection_name=collection_name,
schema=schema,
metric_type="COSINE",
)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
field_name="vector",
metric_type="COSINE",
index_type="IVF_SQ8",
index_name="vector_index",
params={ "nlist": 128 }
)
self.client.create_index(
collection_name=collection_name,
index_params=index_params
)
self.collections[dimension] = collection_name
def insert(self, embeds, doc):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
data = [
{
"vector": embeds,
"doc": doc,
}
]
self.client.insert(
collection_name=self.collections[dim],
data=data
)
def search(self, embeds, fields=["doc"], limit=10):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
coll = self.collections[dim]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
print("Loading...")
self.client.load_collection(
collection_name=coll,
)
print("Searching...")
res = self.client.search(
collection_name=coll,
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
)[0]
# If reload time has passed, unload collection
if time.time() > self.next_reload:
print("Unloading, reload at", self.next_reload)
self.client.release_collection(
collection_name=coll,
)
self.next_reload = time.time() + self.reload_time
return res

View file

@ -1,138 +0,0 @@
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
class EntityVectors:
def __init__(self, uri="http://localhost:19530", prefix='entity'):
self.client = MilvusClient(uri=uri)
# Strategy is to create collections per dimension. Probably only
# going to be using 1 anyway, but that means we don't need to
# hard-code the dimension anywhere, and no big deal if more than
# one are created.
self.collections = {}
self.prefix = prefix
# Time between reloads
self.reload_time = 90
# Next time to reload - this forces a reload at next window
self.next_reload = time.time() + self.reload_time
print("Reload at", self.next_reload)
def init_collection(self, dimension):
collection_name = self.prefix + "_" + str(dimension)
pkey_field = FieldSchema(
name="id",
dtype=DataType.INT64,
is_primary=True,
auto_id=True,
)
vec_field = FieldSchema(
name="vector",
dtype=DataType.FLOAT_VECTOR,
dim=dimension,
)
entity_field = FieldSchema(
name="entity",
dtype=DataType.VARCHAR,
max_length=65535,
)
schema = CollectionSchema(
fields = [pkey_field, vec_field, entity_field],
description = "Graph embedding schema",
)
self.client.create_collection(
collection_name=collection_name,
schema=schema,
metric_type="COSINE",
)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
field_name="vector",
metric_type="COSINE",
index_type="IVF_SQ8",
index_name="vector_index",
params={ "nlist": 128 }
)
self.client.create_index(
collection_name=collection_name,
index_params=index_params
)
self.collections[dimension] = collection_name
def insert(self, embeds, entity):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
data = [
{
"vector": embeds,
"entity": entity,
}
]
self.client.insert(
collection_name=self.collections[dim],
data=data
)
def search(self, embeds, fields=["entity"], limit=10):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim)
coll = self.collections[dim]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
print("Loading...")
self.client.load_collection(
collection_name=coll,
)
print("Searching...")
res = self.client.search(
collection_name=coll,
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
)[0]
# If reload time has passed, unload collection
if time.time() > self.next_reload:
print("Unloading, reload at", self.next_reload)
self.client.release_collection(
collection_name=coll,
)
self.next_reload = time.time() + self.reload_time
return res

View file

@ -1,154 +0,0 @@
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
import time
class ObjectVectors:
def __init__(self, uri="http://localhost:19530", prefix='obj'):
self.client = MilvusClient(uri=uri)
# Strategy is to create collections per dimension. Probably only
# going to be using 1 anyway, but that means we don't need to
# hard-code the dimension anywhere, and no big deal if more than
# one are created.
self.collections = {}
self.prefix = prefix
# Time between reloads
self.reload_time = 90
# Next time to reload - this forces a reload at next window
self.next_reload = time.time() + self.reload_time
print("Reload at", self.next_reload)
def init_collection(self, dimension, name):
collection_name = self.prefix + "_" + name + "_" + str(dimension)
pkey_field = FieldSchema(
name="id",
dtype=DataType.INT64,
is_primary=True,
auto_id=True,
)
vec_field = FieldSchema(
name="vector",
dtype=DataType.FLOAT_VECTOR,
dim=dimension,
)
name_field = FieldSchema(
name="name",
dtype=DataType.VARCHAR,
max_length=65535,
)
key_name_field = FieldSchema(
name="key_name",
dtype=DataType.VARCHAR,
max_length=65535,
)
key_field = FieldSchema(
name="key",
dtype=DataType.VARCHAR,
max_length=65535,
)
schema = CollectionSchema(
fields = [
pkey_field, vec_field, name_field, key_name_field, key_field
],
description = "Object embedding schema",
)
self.client.create_collection(
collection_name=collection_name,
schema=schema,
metric_type="COSINE",
)
index_params = MilvusClient.prepare_index_params()
index_params.add_index(
field_name="vector",
metric_type="COSINE",
index_type="IVF_SQ8",
index_name="vector_index",
params={ "nlist": 128 }
)
self.client.create_index(
collection_name=collection_name,
index_params=index_params
)
self.collections[(dimension, name)] = collection_name
def insert(self, embeds, name, key_name, key):
dim = len(embeds)
if (dim, name) not in self.collections:
self.init_collection(dim, name)
data = [
{
"vector": embeds,
"name": name,
"key_name": key_name,
"key": key,
}
]
self.client.insert(
collection_name=self.collections[(dim, name)],
data=data
)
def search(self, embeds, name, fields=["key_name", "name"], limit=10):
dim = len(embeds)
if dim not in self.collections:
self.init_collection(dim, name)
coll = self.collections[(dim, name)]
search_params = {
"metric_type": "COSINE",
"params": {
"radius": 0.1,
"range_filter": 0.8
}
}
print("Loading...")
self.client.load_collection(
collection_name=coll,
)
print("Searching...")
res = self.client.search(
collection_name=coll,
data=[embeds],
limit=limit,
output_fields=fields,
search_params=search_params,
)[0]
# If reload time has passed, unload collection
if time.time() > self.next_reload:
print("Unloading, reload at", self.next_reload)
self.client.release_collection(
collection_name=coll,
)
self.next_reload = time.time() + self.reload_time
return res

View file

@ -1,132 +0,0 @@
from . clients.document_embeddings_client import DocumentEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient
from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import document_embeddings_request_queue
from . schema import document_embeddings_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class DocumentRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
de_request_queue=None,
de_response_queue=None,
verbose=False,
module="test",
):
self.verbose=verbose
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if de_request_queue is None:
de_request_queue = document_embeddings_request_queue
if de_response_queue is None:
de_response_queue = document_embeddings_response_queue
if self.verbose:
print("Initialising...", flush=True)
# FIXME: Configurable
self.entity_limit = 20
self.de_client = DocumentEmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-de",
input_queue=de_request_queue,
output_queue=de_response_queue,
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
self.lang = PromptClient(
pulsar_host=pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-de-prompt",
)
if self.verbose:
print("Initialised", flush=True)
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_docs(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
docs = self.de_client.request(
vectors, self.entity_limit
)
if self.verbose:
print("Docs:", flush=True)
for doc in docs:
print(doc, flush=True)
return docs
def query(self, query):
if self.verbose:
print("Construct prompt...", flush=True)
docs = self.get_docs(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(docs)
print(query)
resp = self.lang.request_document_prompt(query, docs)
if self.verbose:
print("Done", flush=True)
return resp

View file

@ -1,3 +0,0 @@
from . processor import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . write import run
if __name__ == '__main__':
run()

View file

@ -1,85 +0,0 @@
"""
Write graph embeddings to parquet files in a directory.
"""
import pulsar
import base64
import os
import argparse
import time
from .... schema import GraphEmbeddings
from .... schema import graph_embeddings_store_queue
from .... base import Consumer
from . writer import ParquetWriter
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = graph_embeddings_store_queue
default_subscriber = module
default_graph_host='localhost'
default_directory = "."
default_file_template = "graph-embeds-{id}.parquet"
default_rotation_time = 60
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
directory = params.get("directory", default_directory)
file_template = params.get("file_template", default_file_template)
rotation_time = params.get("rotation_time", default_rotation_time)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": GraphEmbeddings,
}
)
self.writer = ParquetWriter(directory, file_template, rotation_time)
def __del__(self):
if hasattr(self, "writer"):
del self.writer
def handle(self, msg):
v = msg.value()
self.writer.write(v.vectors, v.entity.value)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-d', '--directory',
default=default_directory,
help=f'Directory to write to (default: {default_directory})'
)
parser.add_argument(
'-f', '--file-template',
default=default_file_template,
help=f'Directory to write to (default: {default_file_template})'
)
parser.add_argument(
'-t', '--rotation-time',
type=int,
default=default_rotation_time,
help=f'Rotation time / seconds (default: {default_rotation_time})'
)
def run():
Processor.start(module, __doc__)

View file

@ -1,94 +0,0 @@
import threading
import queue
import time
import uuid
import pyarrow as pa
import pyarrow.parquet as pq
class ParquetWriter:
def __init__(self, directory, file_template, rotation_time):
self.directory = directory
self.file_template = file_template
self.rotation_time = rotation_time
self.q = queue.Queue()
self.running = True
self.thread = threading.Thread(target=(self.writer_thread))
self.thread.start()
def writer_thread(self):
items = []
timeout = None
while self.running:
try:
item = self.q.get(timeout=1)
if timeout == None:
timeout = time.time() + self.rotation_time
items.append(item)
except queue.Empty:
pass
if timeout:
if time.time() > timeout:
self.write_file(items)
timeout = None
items = []
def write_file(self, items):
try:
schema = pa.schema([
pa.field('embeddings', pa.list_(pa.list_(pa.float64()))),
pa.field('entity', pa.string()),
])
fname = self.file_template.format(id=str(uuid.uuid4()))
path = f"{self.directory}/{fname}"
writer = pq.ParquetWriter(path, schema)
batch = pa.record_batch(
[
[i[0] for i in items],
[i[1] for i in items],
],
names=['embeddings', 'entity']
)
writer.write_batch(batch)
writer.close()
print(f"Wrote {path}.")
except Exception as e:
print("Parquet write:", e)
def write(self, embeds, ent):
self.q.put((embeds, ent))
def __del__(self):
self.running = False
if hasattr(self, "q"):
self.thread.join()

View file

@ -1,3 +0,0 @@
from . processor import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . write import run
if __name__ == '__main__':
run()

View file

@ -1,85 +0,0 @@
"""
Write graphs triples to parquet files in a directory.
"""
import pulsar
import base64
import os
import argparse
import time
from .... schema import Triple
from .... schema import triples_store_queue
from .... base import Consumer
from . writer import ParquetWriter
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = triples_store_queue
default_subscriber = module
default_graph_host='localhost'
default_directory = "."
default_file_template = "triples-{id}.parquet"
default_rotation_time = 60
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
directory = params.get("directory", default_directory)
file_template = params.get("file_template", default_file_template)
rotation_time = params.get("rotation_time", default_rotation_time)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": Triple,
}
)
self.writer = ParquetWriter(directory, file_template, rotation_time)
def __del__(self):
if hasattr(self, "writer"):
del self.writer
def handle(self, msg):
v = msg.value()
self.writer.write(v.s.value, v.p.value, v.o.value)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
parser.add_argument(
'-d', '--directory',
default=default_directory,
help=f'Directory to write to (default: {default_directory})'
)
parser.add_argument(
'-f', '--file-template',
default=default_file_template,
help=f'Directory to write to (default: {default_file_template})'
)
parser.add_argument(
'-t', '--rotation-time',
type=int,
default=default_rotation_time,
help=f'Rotation time / seconds (default: {default_rotation_time})'
)
def run():
Processor.start(module, __doc__)

View file

@ -1,96 +0,0 @@
import threading
import queue
import time
import uuid
import pyarrow as pa
import pyarrow.parquet as pq
class ParquetWriter:
def __init__(self, directory, file_template, rotation_time):
self.directory = directory
self.file_template = file_template
self.rotation_time = rotation_time
self.q = queue.Queue()
self.running = True
self.thread = threading.Thread(target=(self.writer_thread))
self.thread.start()
def writer_thread(self):
triples = []
timeout = None
while self.running:
try:
item = self.q.get(timeout=1)
if timeout == None:
timeout = time.time() + self.rotation_time
triples.append(item)
except queue.Empty:
pass
if timeout:
if time.time() > timeout:
self.write_file(triples)
timeout = None
triples = []
def write_file(self, triples):
try:
schema = pa.schema([
pa.field('s', pa.string()),
pa.field('p', pa.string()),
pa.field('o', pa.string()),
])
fname = self.file_template.format(id=str(uuid.uuid4()))
path = f"{self.directory}/{fname}"
writer = pq.ParquetWriter(path, schema)
batch = pa.record_batch(
[
[tpl[0] for tpl in triples],
[tpl[1] for tpl in triples],
[tpl[2] for tpl in triples],
],
names=['s', 'p', 'o']
)
writer.write_batch(batch)
writer.close()
print(f"Wrote {path}.")
except Exception as e:
print("Parquet write:", e)
def write(self, s, p, o):
self.q.put((s, p, o))
def __del__(self):
self.running = False
if hasattr(self, "q"):
self.thread.join()

View file

@ -1,3 +0,0 @@
from . hf import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . hf import run
if __name__ == '__main__':
run()

View file

@ -1,99 +0,0 @@
"""
Embeddings service, applies an embeddings model selected from HuggingFace.
Input is text, output is embeddings vector.
"""
from langchain_huggingface import HuggingFaceEmbeddings
from ... schema import EmbeddingsRequest, EmbeddingsResponse, Error
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue
default_subscriber = module
default_model="all-MiniLM-L6-v2"
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
model = params.get("model", default_model)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": EmbeddingsRequest,
"output_schema": EmbeddingsResponse,
}
)
self.embeddings = HuggingFaceEmbeddings(model_name=model)
def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
try:
text = v.text
embeds = self.embeddings.embed_documents([text])
print("Send response...", flush=True)
r = EmbeddingsResponse(vectors=embeds, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = EmbeddingsResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-m', '--model',
default="all-MiniLM-L6-v2",
help=f'LLM model (default: all-MiniLM-L6-v2)'
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . processor import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . processor import run
if __name__ == '__main__':
run()

View file

@ -1,84 +0,0 @@
"""
Embeddings service, applies an embeddings model selected from HuggingFace.
Input is text, output is embeddings vector.
"""
from langchain_community.embeddings import OllamaEmbeddings
from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = embeddings_request_queue
default_output_queue = embeddings_response_queue
default_subscriber = module
default_model="mxbai-embed-large"
default_ollama = 'http://localhost:11434'
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": EmbeddingsRequest,
"output_schema": EmbeddingsResponse,
}
)
self.embeddings = OllamaEmbeddings(base_url=ollama, model=model)
def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling input {id}...", flush=True)
text = v.text
embeds = self.embeddings.embed_query([text])
print("Send response...", flush=True)
r = EmbeddingsResponse(vectors=[embeds])
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-m', '--model',
default=default_model,
help=f'Embeddings model (default: {default_model})'
)
parser.add_argument(
'-r', '--ollama',
default=default_ollama,
help=f'ollama (default: {default_ollama})'
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . vectorize import *

View file

@ -1,6 +0,0 @@
from . vectorize import run
if __name__ == '__main__':
run()

View file

@ -1,103 +0,0 @@
"""
Vectorizer, calls the embeddings service to get embeddings for a chunk.
Input is text chunk, output is chunk and vectors.
"""
from ... schema import Chunk, ChunkEmbeddings
from ... schema import chunk_ingest_queue, chunk_embeddings_ingest_queue
from ... schema import embeddings_request_queue, embeddings_response_queue
from ... clients.embeddings_client import EmbeddingsClient
from ... log_level import LogLevel
from ... base import ConsumerProducer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_ingest_queue
default_output_queue = chunk_embeddings_ingest_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
emb_request_queue = params.get(
"embeddings_request_queue", embeddings_request_queue
)
emb_response_queue = params.get(
"embeddings_response_queue", embeddings_response_queue
)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"embeddings_request_queue": emb_request_queue,
"embeddings_response_queue": emb_response_queue,
"subscriber": subscriber,
"input_schema": Chunk,
"output_schema": ChunkEmbeddings,
}
)
self.embeddings = EmbeddingsClient(
pulsar_host=self.pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
def emit(self, source, chunk, vectors):
r = ChunkEmbeddings(source=source, chunk=chunk, vectors=vectors)
self.producer.send(r)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
try:
vectors = self.embeddings.request(chunk)
self.emit(
source=v.source,
chunk=chunk.encode("utf-8"),
vectors=vectors
)
except Exception as e:
print("Exception:", e, flush=True)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--embeddings-request-queue',
default=embeddings_request_queue,
help=f'Embeddings request queue (default: {embeddings_request_queue})',
)
parser.add_argument(
'--embeddings-response-queue',
default=embeddings_response_queue,
help=f'Embeddings request queue (default: {embeddings_response_queue})',
)
def run():
Processor.start(module, __doc__)

View file

@ -1,14 +0,0 @@
class TooManyRequests(Exception):
pass
class LlmError(Exception):
pass
class ParseError(Exception):
pass

View file

@ -1,3 +0,0 @@
from . extract import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . extract import run
if __name__ == '__main__':
run()

View file

@ -1,134 +0,0 @@
"""
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
get entity definitions which are output as graph edges.
"""
import urllib.parse
import json
from .... schema import ChunkEmbeddings, Triple, Source, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
from .... base import ConsumerProducer
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue
default_output_queue = triples_store_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings,
"output_schema": Triple,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
}
)
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",
)
def to_uri(self, text):
part = text.replace(" ", "-").lower().encode("utf-8")
quoted = urllib.parse.quote(part)
uri = TRUSTGRAPH_ENTITIES + quoted
return uri
def get_definitions(self, chunk):
return self.prompt.request_definitions(chunk)
def emit_edge(self, s, p, o):
t = Triple(s=s, p=p, o=o)
self.producer.send(t)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
try:
defs = self.get_definitions(chunk)
for defn in defs:
s = defn.name
o = defn.definition
if s == "": continue
if o == "": continue
if s is None: continue
if o is None: continue
s_uri = self.to_uri(s)
s_value = Value(value=str(s_uri), is_uri=True)
o_value = Value(value=str(o), is_uri=False)
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
except Exception as e:
print("Exception: ", e, flush=True)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-completion-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . extract import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . extract import run
if __name__ == '__main__':
run()

View file

@ -1,208 +0,0 @@
"""
Simple decoder, accepts vector+text chunks input, applies entity
relationship analysis to get entity relationship edges which are output as
graph edges.
"""
import urllib.parse
import os
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import graph_embeddings_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
from .... base import ConsumerProducer
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue
default_output_queue = triples_store_queue
default_vector_queue = graph_embeddings_store_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
vector_queue = params.get("vector_queue", default_vector_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings,
"output_schema": Triple,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
}
)
self.vec_prod = self.client.create_producer(
topic=vector_queue,
schema=JsonSchema(GraphEmbeddings),
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"output_queue": output_queue,
"vector_queue": vector_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings.__name__,
"output_schema": Triple.__name__,
"vector_schema": GraphEmbeddings.__name__,
})
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",
)
def to_uri(self, text):
part = text.replace(" ", "-").lower().encode("utf-8")
quoted = urllib.parse.quote(part)
uri = TRUSTGRAPH_ENTITIES + quoted
return uri
def get_relationships(self, chunk):
return self.prompt.request_relationships(chunk)
def emit_edge(self, s, p, o):
t = Triple(s=s, p=p, o=o)
self.producer.send(t)
def emit_vec(self, ent, vec):
r = GraphEmbeddings(entity=ent, vectors=vec)
self.vec_prod.send(r)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
try:
rels = self.get_relationships(chunk)
for rel in rels:
s = rel.s
p = rel.p
o = rel.o
if s == "": continue
if p == "": continue
if o == "": continue
if s is None: continue
if p is None: continue
if o is None: continue
s_uri = self.to_uri(s)
s_value = Value(value=str(s_uri), is_uri=True)
p_uri = self.to_uri(p)
p_value = Value(value=str(p_uri), is_uri=True)
if rel.o_entity:
o_uri = self.to_uri(o)
o_value = Value(value=str(o_uri), is_uri=True)
else:
o_value = Value(value=str(o), is_uri=False)
self.emit_edge(
s_value,
p_value,
o_value
)
# Label for s
self.emit_edge(
s_value,
RDF_LABEL_VALUE,
Value(value=str(s), is_uri=False)
)
# Label for p
self.emit_edge(
p_value,
RDF_LABEL_VALUE,
Value(value=str(p), is_uri=False)
)
if rel.o_entity:
# Label for o
self.emit_edge(
o_value,
RDF_LABEL_VALUE,
Value(value=str(o), is_uri=False)
)
self.emit_vec(s_value, v.vectors)
self.emit_vec(p_value, v.vectors)
if rel.o_entity:
self.emit_vec(o_value, v.vectors)
except Exception as e:
print("Exception: ", e, flush=True)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-c', '--vector-queue',
default=default_vector_queue,
help=f'Vector output queue (default: {default_vector_queue})'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . extract import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . extract import run
if __name__ == '__main__':
run()

View file

@ -1,134 +0,0 @@
"""
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
get entity definitions which are output as graph edges.
"""
import urllib.parse
import json
from .... schema import ChunkEmbeddings, Triple, Source, Value
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
from .... base import ConsumerProducer
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue
default_output_queue = triples_store_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings,
"output_schema": Triple,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
}
)
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",
)
def to_uri(self, text):
part = text.replace(" ", "-").lower().encode("utf-8")
quoted = urllib.parse.quote(part)
uri = TRUSTGRAPH_ENTITIES + quoted
return uri
def get_topics(self, chunk):
return self.prompt.request_topics(chunk)
def emit_edge(self, s, p, o):
t = Triple(s=s, p=p, o=o)
self.producer.send(t)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
try:
defs = self.get_topics(chunk)
for defn in defs:
s = defn.name
o = defn.definition
if s == "": continue
if o == "": continue
if s is None: continue
if o is None: continue
s_uri = self.to_uri(s)
s_value = Value(value=str(s_uri), is_uri=True)
o_value = Value(value=str(o), is_uri=False)
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
except Exception as e:
print("Exception: ", e, flush=True)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-completion-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . extract import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . extract import run
if __name__ == '__main__':
run()

View file

@ -1,220 +0,0 @@
"""
Simple decoder, accepts vector+text chunks input, applies analysis to pull
out a row of fields. Output as a vector plus object.
"""
import urllib.parse
import os
from pulsar.schema import JsonSchema
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source
from .... schema import RowSchema, Field
from .... schema import chunk_embeddings_ingest_queue, rows_store_queue
from .... schema import object_embeddings_store_queue
from .... schema import prompt_request_queue
from .... schema import prompt_response_queue
from .... log_level import LogLevel
from .... clients.prompt_client import PromptClient
from .... base import ConsumerProducer
from .... objects.field import Field as FieldParser
from .... objects.object import Schema
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = chunk_embeddings_ingest_queue
default_output_queue = rows_store_queue
default_vector_queue = object_embeddings_store_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
vector_queue = params.get("vector_queue", default_vector_queue)
subscriber = params.get("subscriber", default_subscriber)
pr_request_queue = params.get(
"prompt_request_queue", prompt_request_queue
)
pr_response_queue = params.get(
"prompt_response_queue", prompt_response_queue
)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings,
"output_schema": Rows,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
}
)
self.vec_prod = self.client.create_producer(
topic=vector_queue,
schema=JsonSchema(ObjectEmbeddings),
)
__class__.pubsub_metric.info({
"input_queue": input_queue,
"output_queue": output_queue,
"vector_queue": vector_queue,
"prompt_request_queue": pr_request_queue,
"prompt_response_queue": pr_response_queue,
"subscriber": subscriber,
"input_schema": ChunkEmbeddings.__name__,
"output_schema": Rows.__name__,
"vector_schema": ObjectEmbeddings.__name__,
})
flds = __class__.parse_fields(params["field"])
for fld in flds:
print(fld)
self.primary = None
for f in flds:
if f.primary:
if self.primary:
raise RuntimeError(
"Only one primary key field is supported"
)
self.primary = f
if self.primary == None:
raise RuntimeError(
"Must have exactly one primary key field"
)
self.schema = Schema(
name = params["name"],
description = params["description"],
fields = flds
)
self.row_schema=RowSchema(
name=self.schema.name,
description=self.schema.description,
fields=[
Field(
name=f.name, type=str(f.type), size=f.size,
primary=f.primary, description=f.description,
)
for f in self.schema.fields
]
)
self.prompt = PromptClient(
pulsar_host=self.pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber = module + "-prompt",
)
@staticmethod
def parse_fields(fields):
return [ FieldParser.parse(f) for f in fields ]
def get_rows(self, chunk):
return self.prompt.request_rows(self.schema, chunk)
def emit_rows(self, source, rows):
t = Rows(
source=source, row_schema=self.row_schema, rows=rows
)
self.producer.send(t)
def emit_vec(self, source, name, vec, key_name, key):
r = ObjectEmbeddings(
source=source, vectors=vec, name=name, key_name=key_name, id=key
)
self.vec_prod.send(r)
def handle(self, msg):
v = msg.value()
print(f"Indexing {v.source.id}...", flush=True)
chunk = v.chunk.decode("utf-8")
try:
rows = self.get_rows(chunk)
self.emit_rows(
source=v.source,
rows=rows
)
for row in rows:
self.emit_vec(
source=v.source, vec=v.vectors,
name=self.schema.name, key_name=self.primary.name,
key=row[self.primary.name]
)
for row in rows:
print(row)
except Exception as e:
print("Exception: ", e, flush=True)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-c', '--vector-queue',
default=default_vector_queue,
help=f'Vector output queue (default: {default_vector_queue})'
)
parser.add_argument(
'--prompt-request-queue',
default=prompt_request_queue,
help=f'Prompt request queue (default: {prompt_request_queue})',
)
parser.add_argument(
'--prompt-response-queue',
default=prompt_response_queue,
help=f'Prompt response queue (default: {prompt_response_queue})',
)
parser.add_argument(
'-f', '--field',
required=True,
action='append',
help=f'Field definition, format name:type:size:pri:descriptionn',
)
parser.add_argument(
'-n', '--name',
required=True,
help=f'Name of row object',
)
parser.add_argument(
'-d', '--description',
required=True,
help=f'Description of object',
)
def run():
Processor.start(module, __doc__)

View file

@ -1,250 +0,0 @@
from . clients.graph_embeddings_client import GraphEmbeddingsClient
from . clients.triples_query_client import TriplesQueryClient
from . clients.embeddings_client import EmbeddingsClient
from . clients.prompt_client import PromptClient
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
from . schema import TriplesQueryRequest, TriplesQueryResponse
from . schema import prompt_request_queue
from . schema import prompt_response_queue
from . schema import embeddings_request_queue
from . schema import embeddings_response_queue
from . schema import graph_embeddings_request_queue
from . schema import graph_embeddings_response_queue
from . schema import triples_request_queue
from . schema import triples_response_queue
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
class GraphRag:
def __init__(
self,
pulsar_host="pulsar://pulsar:6650",
pr_request_queue=None,
pr_response_queue=None,
emb_request_queue=None,
emb_response_queue=None,
ge_request_queue=None,
ge_response_queue=None,
tpl_request_queue=None,
tpl_response_queue=None,
verbose=False,
entity_limit=50,
triple_limit=30,
max_subgraph_size=3000,
module="test",
):
self.verbose=verbose
if pr_request_queue is None:
pr_request_queue = prompt_request_queue
if pr_response_queue is None:
pr_response_queue = prompt_response_queue
if emb_request_queue is None:
emb_request_queue = embeddings_request_queue
if emb_response_queue is None:
emb_response_queue = embeddings_response_queue
if ge_request_queue is None:
ge_request_queue = graph_embeddings_request_queue
if ge_response_queue is None:
ge_response_queue = graph_embeddings_response_queue
if tpl_request_queue is None:
tpl_request_queue = triples_request_queue
if tpl_response_queue is None:
tpl_response_queue = triples_response_queue
if self.verbose:
print("Initialising...", flush=True)
self.ge_client = GraphEmbeddingsClient(
pulsar_host=pulsar_host,
subscriber=module + "-ge",
input_queue=ge_request_queue,
output_queue=ge_response_queue,
)
self.triples_client = TriplesQueryClient(
pulsar_host=pulsar_host,
subscriber=module + "-tpl",
input_queue=tpl_request_queue,
output_queue=tpl_response_queue
)
self.embeddings = EmbeddingsClient(
pulsar_host=pulsar_host,
input_queue=emb_request_queue,
output_queue=emb_response_queue,
subscriber=module + "-emb",
)
self.entity_limit=entity_limit
self.query_limit=triple_limit
self.max_subgraph_size=max_subgraph_size
self.label_cache = {}
self.lang = PromptClient(
pulsar_host=pulsar_host,
input_queue=pr_request_queue,
output_queue=pr_response_queue,
subscriber=module + "-prompt",
)
if self.verbose:
print("Initialised", flush=True)
def get_vector(self, query):
if self.verbose:
print("Compute embeddings...", flush=True)
qembeds = self.embeddings.request(query)
if self.verbose:
print("Done.", flush=True)
return qembeds
def get_entities(self, query):
vectors = self.get_vector(query)
if self.verbose:
print("Get entities...", flush=True)
entities = self.ge_client.request(
vectors, self.entity_limit
)
entities = [
e.value
for e in entities
]
if self.verbose:
print("Entities:", flush=True)
for ent in entities:
print(" ", ent, flush=True)
return entities
def maybe_label(self, e):
if e in self.label_cache:
return self.label_cache[e]
res = self.triples_client.request(
e, LABEL, None, limit=1
)
if len(res) == 0:
self.label_cache[e] = e
return e
self.label_cache[e] = res[0].o.value
return self.label_cache[e]
def get_subgraph(self, query):
entities = self.get_entities(query)
subgraph = set()
if self.verbose:
print("Get subgraph...", flush=True)
for e in entities:
res = self.triples_client.request(
e, None, None,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.triples_client.request(
None, e, None,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
res = self.triples_client.request(
None, None, e,
limit=self.query_limit
)
for triple in res:
subgraph.add(
(triple.s.value, triple.p.value, triple.o.value)
)
subgraph = list(subgraph)
subgraph = subgraph[0:self.max_subgraph_size]
if self.verbose:
print("Subgraph:", flush=True)
for edge in subgraph:
print(" ", str(edge), flush=True)
if self.verbose:
print("Done.", flush=True)
return subgraph
def get_labelgraph(self, query):
subgraph = self.get_subgraph(query)
sg2 = []
for edge in subgraph:
if edge[1] == LABEL:
continue
s = self.maybe_label(edge[0])
p = self.maybe_label(edge[1])
o = self.maybe_label(edge[2])
sg2.append((s, p, o))
return sg2
def query(self, query):
if self.verbose:
print("Construct prompt...", flush=True)
kg = self.get_labelgraph(query)
if self.verbose:
print("Invoke LLM...", flush=True)
print(kg)
print(query)
resp = self.lang.request_kg_prompt(query, kg)
if self.verbose:
print("Done", flush=True)
return resp

View file

@ -1,20 +0,0 @@
from enum import Enum
import _pulsar
class LogLevel(Enum):
DEBUG = 'debug'
INFO = 'info'
WARN = 'warn'
ERROR = 'error'
def __str__(self):
return self.value
def to_pulsar(self):
if self == LogLevel.DEBUG: return _pulsar.LoggerLevel.Debug
if self == LogLevel.INFO: return _pulsar.LoggerLevel.Info
if self == LogLevel.WARN: return _pulsar.LoggerLevel.Warn
if self == LogLevel.ERROR: return _pulsar.LoggerLevel.Error
raise RuntimeError("Log level mismatch")

View file

@ -1,3 +0,0 @@
from . counter import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . counter import run
if __name__ == '__main__':
run()

View file

@ -1,75 +0,0 @@
"""
Simple token counter for each LLM response.
"""
from prometheus_client import Histogram, Info
from . pricelist import price_list
from .. schema import TextCompletionResponse, Error
from .. schema import text_completion_response_queue
from .. log_level import LogLevel
from .. base import Consumer
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_response_queue
default_subscriber = module
class Processor(Consumer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
subscriber = params.get("subscriber", default_subscriber)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"subscriber": subscriber,
"input_schema": TextCompletionResponse,
}
)
def get_prices(self, prices, modelname):
for model in prices["price_list"]:
if model["model_name"] == modelname:
return model["input_price"], model["output_price"]
return None, None # Return None if model is not found
def handle(self, msg):
v = msg.value()
modelname = v.model
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling response {id}...", flush=True)
num_in = v.in_token
num_out = v.out_token
model_input_price, model_output_price = self.get_prices(price_list, modelname)
if model_input_price == None:
cost_per_call = f"Model Not Found in Price list"
else:
cost_in = num_in * model_input_price
cost_out = num_out * model_output_price
cost_per_call = round(cost_in + cost_out, 6)
print(f"Input Tokens: {num_in}", flush=True)
print(f"Output Tokens: {num_out}", flush=True)
print(f"Cost for call: ${cost_per_call}", flush=True)
@staticmethod
def add_args(parser):
Consumer.add_args(
parser, default_input_queue, default_subscriber,
)
def run():
Processor.start(module, __doc__)

View file

@ -1,104 +0,0 @@
price_list = {
"price_list": [
{
"model_name": "mistral.mistral-large-2407-v1:0",
"input_price": 0.000004,
"output_price": 0.000012
},
{
"model_name": "meta.llama3-1-405b-instruct-v1:0",
"input_price": 0.00000532,
"output_price": 0.000016
},
{
"model_name": "mistral.mixtral-8x7b-instruct-v0:1",
"input_price": 0.00000045,
"output_price": 0.0000007
},
{
"model_name": "meta.llama3-1-70b-instruct-v1:0",
"input_price": 0.00000099,
"output_price": 0.00000099
},
{
"model_name": "meta.llama3-1-8b-instruct-v1:0",
"input_price": 0.00000022,
"output_price": 0.00000022
},
{
"model_name": "anthropic.claude-3-haiku-20240307-v1:0",
"input_price": 0.00000025,
"output_price": 0.00000125
},
{
"model_name": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"input_price": 0.000003,
"output_price": 0.000015
},
{
"model_name": "cohere.command-r-plus-v1:0",
"input_price": 0.0000030,
"output_price": 0.0000150
},
{
"model_name": "ollama",
"input_price": 0,
"output_price": 0
},
{
"model_name": "claude-3-haiku-20240307",
"input_price": 0.00000025,
"output_price": 0.00000125
},
{
"model_name": "claude-3-5-sonnet-20240620",
"input_price": 0.000003,
"output_price": 0.000015
},
{
"model_name": "claude-3-opus-20240229",
"input_price": 0.000015,
"output_price": 0.000075
},
{
"model_name": "claude-3-sonnet-20240229",
"input_price": 0.000003,
"output_price": 0.000015
},
{
"model_name": "command-r-08-202",
"input_price": 0.0000025,
"output_price": 0.000010
},
{
"model_name": "c4ai-aya-23-8b",
"input_price": 0,
"output_price": 0
},
{
"model_name": "llama.cpp",
"input_price": 0,
"output_price": 0
},
{
"model_name": "gpt-4o",
"input_price": 0.000005,
"output_price": 0.000015
},
{
"model_name": "gpt-4o-2024-08-06",
"input_price": 0.0000025,
"output_price": 0.000010
},
{
"model_name": "gpt-4o-2024-05-13",
"input_price": 0.000005,
"output_price": 0.000015
},
{
"model_name": "gpt-4o-mini",
"input_price": 0.00000015,
"output_price": 0.0000006
},
]
}

View file

@ -1,3 +0,0 @@
from . service import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . service import run
if __name__ == '__main__':
run()

View file

@ -1,176 +0,0 @@
def to_relationships(text):
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.
Read the provided text. You will model the text as an information network for a RDF knowledge graph in JSON.
Information Network Rules:
- An information network has subjects connected by predicates to objects.
- A subject is a named-entity or a conceptual topic.
- One subject can have many predicates and objects.
- An object is a property or attribute of a subject.
- A subject can be connected by a predicate to another subject.
Reading Instructions:
- Ignore document formatting in the provided text.
- Study the provided text carefully.
Here is the text:
{text}
Response Instructions:
- Obey the information network rules.
- Do not return special characters.
- Respond only with well-formed JSON.
- The JSON response shall be an array of JSON objects with keys "subject", "predicate", "object", and "object-entity".
- The JSON response shall use the following structure:
```json
[{{"subject": string, "predicate": string, "object": string, "object-entity": boolean}}]
```
- The key "object-entity" is TRUE only if the "object" is a subject.
- Do not write any additional text or explanations.
"""
return prompt
def to_topics(text):
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify topics and their definitions in JSON.
Reading Instructions:
- Ignore document formatting in the provided text.
- Study the provided text carefully.
Here is the text:
{text}
Response Instructions:
- Do not respond with special characters.
- Return only topics that are concepts and unique to the provided text.
- Respond only with well-formed JSON.
- The JSON response shall be an array of objects with keys "topic" and "definition".
- The JSON response shall use the following structure:
```json
[{{"topic": string, "definition": string}}]
```
- Do not write any additional text or explanations.
"""
return prompt
def to_definitions(text):
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify entities and their definitions in JSON.
Reading Instructions:
- Ignore document formatting in the provided text.
- Study the provided text carefully.
Here is the text:
{text}
Response Instructions:
- Do not respond with special characters.
- Return only entities that are named-entities such as: people, organizations, physical objects, locations, animals, products, commodotities, or substances.
- Respond only with well-formed JSON.
- The JSON response shall be an array of objects with keys "entity" and "definition".
- The JSON response shall use the following structure:
```json
[{{"entity": string, "definition": string}}]
```
- Do not write any additional text or explanations.
"""
return prompt
def to_rows(schema, text):
field_schema = [
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
for f in schema.fields
]
field_schema = "\n".join(field_schema)
schema = f"""Object name: {schema.name}
Description: {schema.description}
Fields:
{field_schema}"""
prompt = f"""<instructions>
Study the following text and derive objects which match the schema provided.
You must output an array of JSON objects for each object you discover
which matches the schema. For each object, output a JSON object whose fields
carry the name field specified in the schema.
</instructions>
<schema>
{schema}
</schema>
<text>
{text}
</text>
<requirements>
You will respond only with raw JSON format data. Do not provide
explanations. Do not add markdown formatting or headers or prefixes.
</requirements>"""
return prompt
def get_cypher(kg):
sg2 = []
for f in kg:
print(f)
sg2.append(f"({f.s})-[{f.p}]->({f.o})")
print(sg2)
kg = "\n".join(sg2)
kg = kg.replace("\\", "-")
return kg
def to_kg_query(query, kg):
cypher = get_cypher(kg)
prompt=f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
Here's the knowledge statements:
{cypher}
Use only the provided knowledge statements to respond to the following:
{query}
"""
return prompt
def to_document_query(query, documents):
documents = "\n\n".join(documents)
prompt=f"""Study the following context. Use only the information provided in the context in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
Here is the context:
{documents}
Use only the provided knowledge statements to respond to the following:
{query}
"""
return prompt

View file

@ -1,473 +0,0 @@
"""
Language service abstracts prompt engineering from LLM.
"""
import json
import re
from .... schema import Definition, Relationship, Triple
from .... schema import Topic
from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... schema import prompt_request_queue, prompt_response_queue
from .... base import ConsumerProducer
from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships, to_topics
from . prompts import to_kg_query, to_document_query, to_rows
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
tc_request_queue = params.get(
"text_completion_request_queue", text_completion_request_queue
)
tc_response_queue = params.get(
"text_completion_response_queue", text_completion_response_queue
)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": PromptRequest,
"output_schema": PromptResponse,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
}
)
self.llm = LlmClient(
subscriber=subscriber,
input_queue=tc_request_queue,
output_queue=tc_response_queue,
pulsar_host = self.pulsar_host
)
def parse_json(self, text):
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
if json_match:
json_str = json_match.group(1).strip()
else:
# If no delimiters, assume the entire output is JSON
json_str = text.strip()
return json.loads(json_str)
def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
kind = v.kind
print(f"Handling kind {kind}...", flush=True)
if kind == "extract-definitions":
self.handle_extract_definitions(id, v)
return
elif kind == "extract-topics":
self.handle_extract_topics(id, v)
return
elif kind == "extract-relationships":
self.handle_extract_relationships(id, v)
return
elif kind == "extract-rows":
self.handle_extract_rows(id, v)
return
elif kind == "kg-prompt":
self.handle_kg_prompt(id, v)
return
elif kind == "document-prompt":
self.handle_document_prompt(id, v)
return
else:
print("Invalid kind.", flush=True)
return
def handle_extract_definitions(self, id, v):
try:
prompt = to_definitions(v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try:
defs = self.parse_json(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = []
for defn in defs:
try:
e = defn["entity"]
d = defn["definition"]
if e == "": continue
if e is None: continue
if d == "": continue
if d is None: continue
output.append(
Definition(
name=e, definition=d
)
)
except:
print("definition fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(definitions=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_extract_topics(self, id, v):
try:
prompt = to_topics(v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try:
defs = self.parse_json(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = []
for defn in defs:
try:
e = defn["topic"]
d = defn["definition"]
if e == "": continue
if e is None: continue
if d == "": continue
if d is None: continue
output.append(
Topic(
name=e, definition=d
)
)
except:
print("definition fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(topics=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_extract_relationships(self, id, v):
try:
prompt = to_relationships(v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try:
defs = self.parse_json(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = []
for defn in defs:
try:
s = defn["subject"]
p = defn["predicate"]
o = defn["object"]
o_entity = defn["object-entity"]
if s == "": continue
if s is None: continue
if p == "": continue
if p is None: continue
if o == "": continue
if o is None: continue
if o_entity == "" or o_entity is None:
o_entity = False
output.append(
Relationship(
s = s,
p = p,
o = o,
o_entity = o_entity,
)
)
except Exception as e:
print("relationship fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(relationships=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_extract_rows(self, id, v):
try:
fields = v.row_schema.fields
prompt = to_rows(v.row_schema, v.chunk)
print(prompt)
ans = self.llm.request(prompt)
print(ans)
# Silently ignore JSON parse error
try:
objs = self.parse_json(ans)
except:
print("JSON parse error, ignored", flush=True)
objs = []
output = []
for obj in objs:
try:
row = {}
for f in fields:
if f.name not in obj:
print(f"Object ignored, missing field {f.name}")
row = {}
break
row[f.name] = obj[f.name]
if row == {}:
continue
output.append(row)
except Exception as e:
print("row fields missing, ignored", flush=True)
for row in output:
print(row)
print("Send response...", flush=True)
r = PromptResponse(rows=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_kg_prompt(self, id, v):
try:
prompt = to_kg_query(v.query, v.kg)
print(prompt)
ans = self.llm.request(prompt)
print(ans)
print("Send response...", flush=True)
r = PromptResponse(answer=ans, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_document_prompt(self, id, v):
try:
prompt = to_document_query(v.query, v.documents)
print("prompt")
print(prompt)
print("Call LLM...")
ans = self.llm.request(prompt)
print(ans)
print("Send response...", flush=True)
r = PromptResponse(answer=ans, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--text-completion-request-queue',
default=text_completion_request_queue,
help=f'Text completion request queue (default: {text_completion_request_queue})',
)
parser.add_argument(
'--text-completion-response-queue',
default=text_completion_response_queue,
help=f'Text completion response queue (default: {text_completion_response_queue})',
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . service import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . service import run
if __name__ == '__main__':
run()

View file

@ -1,47 +0,0 @@
def to_relationships(template, text):
return template.format(text=text)
def to_definitions(template, text):
return template.format(text=text)
def to_topics(template, text):
return template.format(text=text)
def to_rows(template, schema, text):
field_schema = [
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
for f in schema.fields
]
field_schema = "\n".join(field_schema)
return template.format(schema=schema, text=text)
schema = f"""Object name: {schema.name}
Description: {schema.description}
Fields:
{schema}"""
prompt = f""""""
return prompt
def get_cypher(kg):
sg2 = []
for f in kg:
sg2.append(f"({f.s})-[{f.p}]->({f.o})")
kg = "\n".join(sg2)
kg = kg.replace("\\", "-")
return kg
def to_kg_query(template, query, kg):
cypher = get_cypher(kg)
return template.format(query=query, graph=cypher)
def to_document_query(template, query, docs):
docs = "\n\n".join(docs)
return template.format(query=query, documents=docs)

View file

@ -1,523 +0,0 @@
"""
Language service abstracts prompt engineering from LLM.
"""
import json
import re
from .... schema import Definition, Relationship, Triple
from .... schema import Topic
from .... schema import PromptRequest, PromptResponse, Error
from .... schema import TextCompletionRequest, TextCompletionResponse
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... schema import prompt_request_queue, prompt_response_queue
from .... base import ConsumerProducer
from .... clients.llm_client import LlmClient
from . prompts import to_definitions, to_relationships, to_rows
from . prompts import to_kg_query, to_document_query, to_topics
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = prompt_request_queue
default_output_queue = prompt_response_queue
default_subscriber = module
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
tc_request_queue = params.get(
"text_completion_request_queue", text_completion_request_queue
)
tc_response_queue = params.get(
"text_completion_response_queue", text_completion_response_queue
)
definition_template = params.get("definition_template")
relationship_template = params.get("relationship_template")
topic_template = params.get("topic_template")
rows_template = params.get("rows_template")
knowledge_query_template = params.get("knowledge_query_template")
document_query_template = params.get("document_query_template")
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": PromptRequest,
"output_schema": PromptResponse,
"text_completion_request_queue": tc_request_queue,
"text_completion_response_queue": tc_response_queue,
}
)
self.llm = LlmClient(
subscriber=subscriber,
input_queue=tc_request_queue,
output_queue=tc_response_queue,
pulsar_host = self.pulsar_host
)
self.definition_template = definition_template
self.topic_template = topic_template
self.relationship_template = relationship_template
self.rows_template = rows_template
self.knowledge_query_template = knowledge_query_template
self.document_query_template = document_query_template
def parse_json(self, text):
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
if json_match:
json_str = json_match.group(1).strip()
else:
# If no delimiters, assume the entire output is JSON
json_str = text.strip()
return json.loads(json_str)
def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
kind = v.kind
print(f"Handling kind {kind}...", flush=True)
if kind == "extract-definitions":
self.handle_extract_definitions(id, v)
return
elif kind == "extract-topics":
self.handle_extract_topics(id, v)
return
elif kind == "extract-relationships":
self.handle_extract_relationships(id, v)
return
elif kind == "extract-rows":
self.handle_extract_rows(id, v)
return
elif kind == "kg-prompt":
self.handle_kg_prompt(id, v)
return
elif kind == "document-prompt":
self.handle_document_prompt(id, v)
return
else:
print("Invalid kind.", flush=True)
return
def handle_extract_definitions(self, id, v):
try:
prompt = to_definitions(self.definition_template, v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try:
defs = self.parse_json(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = []
for defn in defs:
try:
e = defn["entity"]
d = defn["definition"]
if e == "": continue
if e is None: continue
if d == "": continue
if d is None: continue
output.append(
Definition(
name=e, definition=d
)
)
except:
print("definition fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(definitions=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_extract_topics(self, id, v):
try:
prompt = to_topics(self.topic_template, v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try:
defs = self.parse_json(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = []
for defn in defs:
try:
e = defn["topic"]
d = defn["definition"]
if e == "": continue
if e is None: continue
if d == "": continue
if d is None: continue
output.append(
Topic(
name=e, definition=d
)
)
except:
print("definition fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(topics=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_extract_relationships(self, id, v):
try:
prompt = to_relationships(self.relationship_template, v.chunk)
ans = self.llm.request(prompt)
# Silently ignore JSON parse error
try:
defs = self.parse_json(ans)
except:
print("JSON parse error, ignored", flush=True)
defs = []
output = []
for defn in defs:
try:
s = defn["subject"]
p = defn["predicate"]
o = defn["object"]
o_entity = defn["object-entity"]
if s == "": continue
if s is None: continue
if p == "": continue
if p is None: continue
if o == "": continue
if o is None: continue
if o_entity == "" or o_entity is None:
o_entity = False
output.append(
Relationship(
s = s,
p = p,
o = o,
o_entity = o_entity,
)
)
except Exception as e:
print("relationship fields missing, ignored", flush=True)
print("Send response...", flush=True)
r = PromptResponse(relationships=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_extract_rows(self, id, v):
try:
fields = v.row_schema.fields
prompt = to_rows(self.rows_template, v.row_schema, v.chunk)
print(prompt)
ans = self.llm.request(prompt)
print(ans)
# Silently ignore JSON parse error
try:
objs = self.parse_json(ans)
except:
print("JSON parse error, ignored", flush=True)
objs = []
output = []
for obj in objs:
try:
row = {}
for f in fields:
if f.name not in obj:
print(f"Object ignored, missing field {f.name}")
row = {}
break
row[f.name] = obj[f.name]
if row == {}:
continue
output.append(row)
except Exception as e:
print("row fields missing, ignored", flush=True)
for row in output:
print(row)
print("Send response...", flush=True)
r = PromptResponse(rows=output, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_kg_prompt(self, id, v):
try:
prompt = to_kg_query(self.knowledge_query_template, v.query, v.kg)
print(prompt)
ans = self.llm.request(prompt)
print(ans)
print("Send response...", flush=True)
r = PromptResponse(answer=ans, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
def handle_document_prompt(self, id, v):
try:
prompt = to_document_query(
self.document_query_template, v.query, v.documents
)
print(prompt)
ans = self.llm.request(prompt)
print(ans)
print("Send response...", flush=True)
r = PromptResponse(answer=ans, error=None)
self.producer.send(r, properties={"id": id})
print("Done.", flush=True)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = PromptResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
)
self.producer.send(r, properties={"id": id})
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'--text-completion-request-queue',
default=text_completion_request_queue,
help=f'Text completion request queue (default: {text_completion_request_queue})',
)
parser.add_argument(
'--text-completion-response-queue',
default=text_completion_response_queue,
help=f'Text completion response queue (default: {text_completion_response_queue})',
)
parser.add_argument(
'--definition-template',
required=True,
help=f'Definition extraction template',
)
parser.add_argument(
'--topic-template',
required=True,
help=f'Topic extraction template',
)
parser.add_argument(
'--rows-template',
required=True,
help=f'Rows extraction template',
)
parser.add_argument(
'--relationship-template',
required=True,
help=f'Relationship extraction template',
)
parser.add_argument(
'--knowledge-query-template',
required=True,
help=f'Knowledge query template',
)
parser.add_argument(
'--document-query-template',
required=True,
help=f'Document query template',
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . llm import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . llm import run
if __name__ == '__main__':
run()

View file

@ -1,226 +0,0 @@
"""
Simple LLM service, performs text prompt completion using the Azure
serverless endpoint service. Input is prompt, output is response.
"""
import requests
import json
from prometheus_client import Histogram
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
default_subscriber = module
default_temperature = 0.0
default_max_output = 4192
default_model = "AzureAI"
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
endpoint = params.get("endpoint")
token = params.get("token")
temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output)
model = default_model
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextCompletionRequest,
"output_schema": TextCompletionResponse,
"temperature": temperature,
"max_output": max_output,
"model": model,
}
)
if not hasattr(__class__, "text_completion_metric"):
__class__.text_completion_metric = Histogram(
'text_completion_duration',
'Text completion duration (seconds)',
buckets=[
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
120.0
]
)
self.endpoint = endpoint
self.token = token
self.temperature = temperature
self.max_output = max_output
self.model = model
def build_prompt(self, system, content):
data = {
"messages": [
{
"role": "system", "content": system
},
{
"role": "user", "content": content
}
],
"max_tokens": self.max_output,
"temperature": self.temperature,
"top_p": 1
}
body = json.dumps(data)
return body
def call_llm(self, body):
url = self.endpoint
# Replace this with the primary/secondary key, AMLToken, or
# Microsoft Entra ID token for the endpoint
api_key = self.token
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
resp = requests.post(url, data=body, headers=headers)
if resp.status_code == 429:
raise TooManyRequests()
if resp.status_code != 200:
raise RuntimeError("LLM failure")
result = resp.json()
return result
def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling prompt {id}...", flush=True)
try:
prompt = self.build_prompt(
"You are a helpful chatbot",
v.prompt
)
with __class__.text_completion_metric.time():
response = self.call_llm(prompt)
resp = response['choices'][0]['message']['content']
inputtokens = response['usage']['prompt_tokens']
outputtokens = response['usage']['completion_tokens']
print(resp, flush=True)
print(f"Input Tokens: {inputtokens}", flush=True)
print(f"Output Tokens: {outputtokens}", flush=True)
print("Send response...", flush=True)
r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
self.producer.send(r, properties={"id": id})
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
in_token=None,
out_token=None,
model=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
in_token=None,
out_token=None,
model=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
print("Done.", flush=True)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-e', '--endpoint',
help=f'LLM model endpoint'
)
parser.add_argument(
'-k', '--token',
help=f'LLM model token'
)
parser.add_argument(
'-t', '--temperature',
type=float,
default=default_temperature,
help=f'LLM temperature parameter (default: {default_temperature})'
)
parser.add_argument(
'-x', '--max-output',
type=int,
default=default_max_output,
help=f'LLM max output tokens (default: {default_max_output})'
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . llm import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . llm import run
if __name__ == '__main__':
run()

View file

@ -1,323 +0,0 @@
"""
Simple LLM service, performs text prompt completion using AWS Bedrock.
Input is prompt, output is response. Mistral is default.
"""
import boto3
import json
from prometheus_client import Histogram
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
default_subscriber = module
default_model = 'mistral.mistral-large-2407-v1:0'
default_region = 'us-west-2'
default_temperature = 0.0
default_max_output = 2048
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
model = params.get("model", default_model)
aws_id = params.get("aws_id_key")
aws_secret = params.get("aws_secret")
aws_region = params.get("aws_region", default_region)
temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextCompletionRequest,
"output_schema": TextCompletionResponse,
"model": model,
"temperature": temperature,
"max_output": max_output,
}
)
if not hasattr(__class__, "text_completion_metric"):
__class__.text_completion_metric = Histogram(
'text_completion_duration',
'Text completion duration (seconds)',
buckets=[
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
120.0
]
)
self.model = model
self.temperature = temperature
self.max_output = max_output
self.session = boto3.Session(
aws_access_key_id=aws_id,
aws_secret_access_key=aws_secret,
region_name=aws_region
)
self.bedrock = self.session.client(service_name='bedrock-runtime')
print("Initialised", flush=True)
def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling prompt {id}...", flush=True)
prompt = v.prompt
try:
# Mistral Input Format
if self.model.startswith("mistral"):
promptbody = json.dumps({
"prompt": prompt,
"max_tokens": self.max_output,
"temperature": self.temperature,
"top_p": 0.99,
"top_k": 40
})
# Llama 3.1 Input Format
elif self.model.startswith("meta"):
promptbody = json.dumps({
"prompt": prompt,
"max_gen_len": self.max_output,
"temperature": self.temperature,
"top_p": 0.95,
})
# Anthropic Input Format
elif self.model.startswith("anthropic"):
promptbody = json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": self.max_output,
"temperature": self.temperature,
"top_p": 0.999,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
]
})
# Jamba Input Format
elif self.model.startswith("ai21"):
promptbody = json.dumps({
"max_tokens": self.max_output,
"temperature": self.temperature,
"top_p": 0.9,
"messages": [
{
"role": "user",
"content": prompt
}
]
})
# Cohere Input Format
elif self.model.startswith("cohere"):
promptbody = json.dumps({
"max_tokens": self.max_output,
"temperature": self.temperature,
"message": prompt
})
# Use Mistral format as defualt
else:
promptbody = json.dumps({
"prompt": prompt,
"max_tokens": self.max_output,
"temperature": self.temperature,
"top_p": 0.99,
"top_k": 40
})
accept = 'application/json'
contentType = 'application/json'
# FIXME: Consider catching request limits and raise TooManyRequests
# See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
with __class__.text_completion_metric.time():
response = self.bedrock.invoke_model(
body=promptbody, modelId=self.model, accept=accept,
contentType=contentType
)
# Mistral Response Structure
if self.model.startswith("mistral"):
response_body = json.loads(response.get("body").read())
outputtext = response_body['outputs'][0]['text']
# Claude Response Structure
elif self.model.startswith("anthropic"):
model_response = json.loads(response["body"].read())
outputtext = model_response['content'][0]['text']
# Llama 3.1 Response Structure
elif self.model.startswith("meta"):
model_response = json.loads(response["body"].read())
outputtext = model_response["generation"]
# Jamba Response Structure
elif self.model.startswith("ai21"):
content = response['body'].read()
content_str = content.decode('utf-8')
content_json = json.loads(content_str)
outputtext = content_json['choices'][0]['message']['content']
# Cohere Input Format
elif self.model.startswith("cohere"):
content = response['body'].read()
content_str = content.decode('utf-8')
content_json = json.loads(content_str)
outputtext = content_json['text']
# Use Mistral as default
else:
response_body = json.loads(response.get("body").read())
outputtext = response_body['outputs'][0]['text']
metadata = response['ResponseMetadata']['HTTPHeaders']
inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])
print(outputtext, flush=True)
print(f"Input Tokens: {inputtokens}", flush=True)
print(f"Output Tokens: {outputtokens}", flush=True)
print("Send response...", flush=True)
r = TextCompletionResponse(
error=None,
response=outputtext,
in_token=inputtokens,
out_token=outputtokens,
model=str(self.model),
)
self.send(r, properties={"id": id})
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what Bedrock throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
in_token=None,
out_token=None,
model=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
in_token=None,
out_token=None,
model=None,
)
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-m', '--model',
default="mistral.mistral-large-2407-v1:0",
help=f'Bedrock model (default: Mistral-Large-2407)'
)
parser.add_argument(
'-z', '--aws-id-key',
help=f'AWS ID Key'
)
parser.add_argument(
'-k', '--aws-secret',
help=f'AWS Secret Key'
)
parser.add_argument(
'-r', '--aws-region',
help=f'AWS Region (default: us-west-2)'
)
parser.add_argument(
'-t', '--temperature',
type=float,
default=default_temperature,
help=f'LLM temperature parameter (default: {default_temperature})'
)
parser.add_argument(
'-x', '--max-output',
type=int,
default=default_max_output,
help=f'LLM max output tokens (default: {default_max_output})'
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . llm import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . llm import run
if __name__ == '__main__':
run()

View file

@ -1,199 +0,0 @@
"""
Simple LLM service, performs text prompt completion using Claude.
Input is prompt, output is response.
"""
import anthropic
from prometheus_client import Histogram
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
default_subscriber = module
default_model = 'claude-3-5-sonnet-20240620'
default_temperature = 0.0
default_max_output = 8192
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
model = params.get("model", default_model)
api_key = params.get("api_key")
temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextCompletionRequest,
"output_schema": TextCompletionResponse,
"model": model,
"temperature": temperature,
"max_output": max_output,
}
)
if not hasattr(__class__, "text_completion_metric"):
__class__.text_completion_metric = Histogram(
'text_completion_duration',
'Text completion duration (seconds)',
buckets=[
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
120.0
]
)
self.model = model
self.claude = anthropic.Anthropic(api_key=api_key)
self.temperature = temperature
self.max_output = max_output
print("Initialised", flush=True)
def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling prompt {id}...", flush=True)
prompt = v.prompt
try:
# FIXME: Rate limits?
with __class__.text_completion_metric.time():
response = message = self.claude.messages.create(
model=self.model,
max_tokens=self.max_output,
temperature=self.temperature,
system = "You are a helpful chatbot.",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
}
]
}
]
)
resp = response.content[0].text
inputtokens = response.usage.input_tokens
outputtokens = response.usage.output_tokens
print(resp, flush=True)
print(f"Input Tokens: {inputtokens}", flush=True)
print(f"Output Tokens: {outputtokens}", flush=True)
print("Send response...", flush=True)
r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
self.send(r, properties={"id": id})
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
in_token=None,
out_token=None,
model=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
in_token=None,
out_token=None,
model=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-m', '--model',
default="claude-3-5-sonnet-20240620",
help=f'LLM model (default: claude-3-5-sonnet-20240620)'
)
parser.add_argument(
'-k', '--api-key',
help=f'Claude API key'
)
parser.add_argument(
'-t', '--temperature',
type=float,
default=default_temperature,
help=f'LLM temperature parameter (default: {default_temperature})'
)
parser.add_argument(
'-x', '--max-output',
type=int,
default=default_max_output,
help=f'LLM max output tokens (default: {default_max_output})'
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . llm import *

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python3
from . llm import run
if __name__ == '__main__':
run()

View file

@ -1,179 +0,0 @@
"""
Simple LLM service, performs text prompt completion using Cohere.
Input is prompt, output is response.
"""
import cohere
from prometheus_client import Histogram
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
from .... schema import text_completion_request_queue
from .... schema import text_completion_response_queue
from .... log_level import LogLevel
from .... base import ConsumerProducer
from .... exceptions import TooManyRequests
module = ".".join(__name__.split(".")[1:-1])
default_input_queue = text_completion_request_queue
default_output_queue = text_completion_response_queue
default_subscriber = module
default_model = 'c4ai-aya-23-8b'
default_temperature = 0.0
class Processor(ConsumerProducer):
def __init__(self, **params):
input_queue = params.get("input_queue", default_input_queue)
output_queue = params.get("output_queue", default_output_queue)
subscriber = params.get("subscriber", default_subscriber)
model = params.get("model", default_model)
api_key = params.get("api_key")
temperature = params.get("temperature", default_temperature)
super(Processor, self).__init__(
**params | {
"input_queue": input_queue,
"output_queue": output_queue,
"subscriber": subscriber,
"input_schema": TextCompletionRequest,
"output_schema": TextCompletionResponse,
"model": model,
"temperature": temperature,
}
)
if not hasattr(__class__, "text_completion_metric"):
__class__.text_completion_metric = Histogram(
'text_completion_duration',
'Text completion duration (seconds)',
buckets=[
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
120.0
]
)
self.model = model
self.temperature = temperature
self.cohere = cohere.Client(api_key=api_key)
print("Initialised", flush=True)
def handle(self, msg):
v = msg.value()
# Sender-produced ID
id = msg.properties()["id"]
print(f"Handling prompt {id}...", flush=True)
prompt = v.prompt
try:
with __class__.text_completion_metric.time():
output = self.cohere.chat(
model=self.model,
message=prompt,
preamble = "You are a helpful AI-assistant.",
temperature=self.temperature,
chat_history=[],
prompt_truncation='auto',
connectors=[]
)
resp = output.text
inputtokens = int(output.meta.billed_units.input_tokens)
outputtokens = int(output.meta.billed_units.output_tokens)
print(resp, flush=True)
print(f"Input Tokens: {inputtokens}", flush=True)
print(f"Output Tokens: {outputtokens}", flush=True)
print("Send response...", flush=True)
r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
self.send(r, properties={"id": id})
print("Done.", flush=True)
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except TooManyRequests:
print("Send rate limit response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "rate-limit",
message = str(e),
),
response=None,
in_token=None,
out_token=None,
model=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
except Exception as e:
print(f"Exception: {e}")
print("Send error response...", flush=True)
r = TextCompletionResponse(
error=Error(
type = "llm-error",
message = str(e),
),
response=None,
in_token=None,
out_token=None,
model=None,
)
self.producer.send(r, properties={"id": id})
self.consumer.acknowledge(msg)
@staticmethod
def add_args(parser):
ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
default_output_queue,
)
parser.add_argument(
'-m', '--model',
default="c4ai-aya-23-8b",
help=f'Cohere model (default: c4ai-aya-23-8b)'
)
parser.add_argument(
'-k', '--api-key',
help=f'Cohere API key'
)
parser.add_argument(
'-t', '--temperature',
type=float,
default=default_temperature,
help=f'LLM temperature parameter (default: {default_temperature})'
)
def run():
Processor.start(module, __doc__)

View file

@ -1,3 +0,0 @@
from . llm import *

Some files were not shown because too many files have changed in this diff Show more