mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-07 05:45:13 +02:00
Feature/subpackages (#80)
* Renaming what will become the core package * Tweaking to get package build working * Fix metering merge * Rename to core directory * Bump version. Use namespace searching for packaging trustgraph-core * Change references to trustgraph-core * Forming embeddings-hf package * Reference modules in core package. * Build both packages to one container, bump version * Update YAMLs
This commit is contained in:
parent
14d79ef9f1
commit
f081933217
303 changed files with 681 additions and 624 deletions
|
|
@ -1,6 +0,0 @@
|
|||
|
||||
from . base_processor import BaseProcessor
|
||||
from . consumer import Consumer
|
||||
from . producer import Producer
|
||||
from . consumer_producer import ConsumerProducer
|
||||
|
||||
|
|
@ -1,119 +0,0 @@
|
|||
|
||||
import os
|
||||
import argparse
|
||||
import pulsar
|
||||
import _pulsar
|
||||
import time
|
||||
from prometheus_client import start_http_server, Info
|
||||
|
||||
from .. log_level import LogLevel
|
||||
|
||||
class BaseProcessor:
|
||||
|
||||
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
self.client = None
|
||||
|
||||
if not hasattr(__class__, "params_metric"):
|
||||
__class__.params_metric = Info(
|
||||
'params', 'Parameters configuration'
|
||||
)
|
||||
|
||||
# FIXME: Maybe outputs information it should not
|
||||
__class__.params_metric.info({
|
||||
k: str(params[k])
|
||||
for k in params
|
||||
})
|
||||
|
||||
pulsar_host = params.get("pulsar_host", self.default_pulsar_host)
|
||||
log_level = params.get("log_level", LogLevel.INFO)
|
||||
|
||||
self.pulsar_host = pulsar_host
|
||||
|
||||
self.client = pulsar.Client(
|
||||
pulsar_host,
|
||||
logger=pulsar.ConsoleLogger(log_level.to_pulsar())
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
|
||||
if self.client:
|
||||
self.client.close()
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
parser.add_argument(
|
||||
'-p', '--pulsar-host',
|
||||
default=__class__.default_pulsar_host,
|
||||
help=f'Pulsar host (default: {__class__.default_pulsar_host})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-l', '--log-level',
|
||||
type=LogLevel,
|
||||
default=LogLevel.INFO,
|
||||
choices=list(LogLevel),
|
||||
help=f'Output queue (default: info)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--metrics',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help=f'Metrics enabled (default: true)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-P', '--metrics-port',
|
||||
type=int,
|
||||
default=8000,
|
||||
help=f'Pulsar host (default: 8000)',
|
||||
)
|
||||
|
||||
def run(self):
|
||||
raise RuntimeError("Something should have implemented the run method")
|
||||
|
||||
@classmethod
|
||||
def start(cls, prog, doc):
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog=prog,
|
||||
description=doc
|
||||
)
|
||||
|
||||
cls.add_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
args = vars(args)
|
||||
|
||||
print(args)
|
||||
|
||||
if args["metrics"]:
|
||||
start_http_server(args["metrics_port"])
|
||||
|
||||
while True:
|
||||
|
||||
try:
|
||||
|
||||
p = cls(**args)
|
||||
p.run()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Keyboard interrupt.")
|
||||
return
|
||||
|
||||
except _pulsar.Interrupted:
|
||||
print("Pulsar Interrupted.")
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(type(e))
|
||||
|
||||
print("Exception:", e, flush=True)
|
||||
print("Will retry...", flush=True)
|
||||
|
||||
time.sleep(4)
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
|
||||
from pulsar.schema import JsonSchema
|
||||
from prometheus_client import Histogram, Info, Counter, Enum
|
||||
import time
|
||||
|
||||
from . base_processor import BaseProcessor
|
||||
from .. exceptions import TooManyRequests
|
||||
|
||||
class Consumer(BaseProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
if not hasattr(__class__, "state_metric"):
|
||||
__class__.state_metric = Enum(
|
||||
'processor_state', 'Processor state',
|
||||
states=['starting', 'running', 'stopped']
|
||||
)
|
||||
__class__.state_metric.state('starting')
|
||||
|
||||
__class__.state_metric.state('starting')
|
||||
|
||||
super(Consumer, self).__init__(**params)
|
||||
|
||||
input_queue = params.get("input_queue")
|
||||
subscriber = params.get("subscriber")
|
||||
input_schema = params.get("input_schema")
|
||||
|
||||
if input_schema == None:
|
||||
raise RuntimeError("input_schema must be specified")
|
||||
|
||||
if not hasattr(__class__, "request_metric"):
|
||||
__class__.request_metric = Histogram(
|
||||
'request_latency', 'Request latency (seconds)'
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "pubsub_metric"):
|
||||
__class__.pubsub_metric = Info(
|
||||
'pubsub', 'Pub/sub configuration'
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "processing_metric"):
|
||||
__class__.processing_metric = Counter(
|
||||
'processing_count', 'Processing count', ["status"]
|
||||
)
|
||||
|
||||
__class__.pubsub_metric.info({
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": input_schema.__name__,
|
||||
})
|
||||
|
||||
self.consumer = self.client.subscribe(
|
||||
input_queue, subscriber,
|
||||
schema=JsonSchema(input_schema),
|
||||
)
|
||||
|
||||
def run(self):
|
||||
|
||||
__class__.state_metric.state('running')
|
||||
|
||||
while True:
|
||||
|
||||
msg = self.consumer.receive()
|
||||
|
||||
try:
|
||||
|
||||
with __class__.request_metric.time():
|
||||
self.handle(msg)
|
||||
|
||||
# Acknowledge successful processing of the message
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
__class__.processing_metric.labels(status="success").inc()
|
||||
|
||||
except TooManyRequests:
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
print("TooManyRequests: will retry")
|
||||
__class__.processing_metric.labels(status="rate-limit").inc()
|
||||
time.sleep(5)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
# Message failed to be processed
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
|
||||
__class__.processing_metric.labels(status="error").inc()
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser, default_input_queue, default_subscriber):
|
||||
|
||||
BaseProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-i', '--input-queue',
|
||||
default=default_input_queue,
|
||||
help=f'Input queue (default: {default_input_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-s', '--subscriber',
|
||||
default=default_subscriber,
|
||||
help=f'Queue subscriber name (default: {default_subscriber})'
|
||||
)
|
||||
|
||||
|
|
@ -1,139 +0,0 @@
|
|||
|
||||
from pulsar.schema import JsonSchema
|
||||
from prometheus_client import Histogram, Info, Counter, Enum
|
||||
import time
|
||||
|
||||
from . base_processor import BaseProcessor
|
||||
from .. exceptions import TooManyRequests
|
||||
|
||||
# FIXME: Derive from consumer? And producer?
|
||||
|
||||
class ConsumerProducer(BaseProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
if not hasattr(__class__, "state_metric"):
|
||||
__class__.state_metric = Enum(
|
||||
'processor_state', 'Processor state',
|
||||
states=['starting', 'running', 'stopped']
|
||||
)
|
||||
__class__.state_metric.state('starting')
|
||||
|
||||
__class__.state_metric.state('starting')
|
||||
|
||||
input_queue = params.get("input_queue")
|
||||
output_queue = params.get("output_queue")
|
||||
subscriber = params.get("subscriber")
|
||||
input_schema = params.get("input_schema")
|
||||
output_schema = params.get("output_schema")
|
||||
|
||||
if not hasattr(__class__, "request_metric"):
|
||||
__class__.request_metric = Histogram(
|
||||
'request_latency', 'Request latency (seconds)'
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "output_metric"):
|
||||
__class__.output_metric = Counter(
|
||||
'output_count', 'Output items created'
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "pubsub_metric"):
|
||||
__class__.pubsub_metric = Info(
|
||||
'pubsub', 'Pub/sub configuration'
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "processing_metric"):
|
||||
__class__.processing_metric = Counter(
|
||||
'processing_count', 'Processing count', ["status"]
|
||||
)
|
||||
|
||||
__class__.pubsub_metric.info({
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": input_schema.__name__,
|
||||
"output_schema": output_schema.__name__,
|
||||
})
|
||||
|
||||
super(ConsumerProducer, self).__init__(**params)
|
||||
|
||||
if input_schema == None:
|
||||
raise RuntimeError("input_schema must be specified")
|
||||
|
||||
if output_schema == None:
|
||||
raise RuntimeError("output_schema must be specified")
|
||||
|
||||
self.producer = self.client.create_producer(
|
||||
topic=output_queue,
|
||||
schema=JsonSchema(output_schema),
|
||||
)
|
||||
|
||||
self.consumer = self.client.subscribe(
|
||||
input_queue, subscriber,
|
||||
schema=JsonSchema(input_schema),
|
||||
)
|
||||
|
||||
def run(self):
|
||||
|
||||
__class__.state_metric.state('running')
|
||||
|
||||
while True:
|
||||
|
||||
msg = self.consumer.receive()
|
||||
|
||||
try:
|
||||
|
||||
with __class__.request_metric.time():
|
||||
resp = self.handle(msg)
|
||||
|
||||
# Acknowledge successful processing of the message
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
__class__.processing_metric.labels(status="success").inc()
|
||||
|
||||
except TooManyRequests:
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
print("TooManyRequests: will retry")
|
||||
__class__.processing_metric.labels(status="rate-limit").inc()
|
||||
time.sleep(5)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
# Message failed to be processed
|
||||
self.consumer.negative_acknowledge(msg)
|
||||
|
||||
__class__.processing_metric.labels(status="error").inc()
|
||||
|
||||
def send(self, msg, properties={}):
|
||||
self.producer.send(msg, properties)
|
||||
__class__.output_metric.inc()
|
||||
|
||||
@staticmethod
|
||||
def add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
):
|
||||
|
||||
BaseProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-i', '--input-queue',
|
||||
default=default_input_queue,
|
||||
help=f'Input queue (default: {default_input_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-s', '--subscriber',
|
||||
default=default_subscriber,
|
||||
help=f'Queue subscriber name (default: {default_subscriber})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-o', '--output-queue',
|
||||
default=default_output_queue,
|
||||
help=f'Output queue (default: {default_output_queue})'
|
||||
)
|
||||
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
|
||||
from pulsar.schema import JsonSchema
|
||||
from prometheus_client import Info, Counter
|
||||
|
||||
from . base_processor import BaseProcessor
|
||||
|
||||
class Producer(BaseProcessor):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
output_queue = params.get("output_queue")
|
||||
output_schema = params.get("output_schema")
|
||||
|
||||
if not hasattr(__class__, "output_metric"):
|
||||
__class__.output_metric = Counter(
|
||||
'output_count', 'Output items created'
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "pubsub_metric"):
|
||||
__class__.pubsub_metric = Info(
|
||||
'pubsub', 'Pub/sub configuration'
|
||||
)
|
||||
|
||||
__class__.pubsub_metric.info({
|
||||
"output_queue": output_queue,
|
||||
"output_schema": output_schema.__name__,
|
||||
})
|
||||
|
||||
super(Producer, self).__init__(**params)
|
||||
|
||||
if output_schema == None:
|
||||
raise RuntimeError("output_schema must be specified")
|
||||
|
||||
self.producer = self.client.create_producer(
|
||||
topic=output_queue,
|
||||
schema=JsonSchema(output_schema),
|
||||
)
|
||||
|
||||
def send(self, msg, properties={}):
|
||||
self.producer.send(msg, properties)
|
||||
__class__.output_metric.inc()
|
||||
|
||||
@staticmethod
|
||||
def add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
):
|
||||
|
||||
BaseProcessor.add_args(parser)
|
||||
|
||||
parser.add_argument(
|
||||
'-o', '--output-queue',
|
||||
default=default_output_queue,
|
||||
help=f'Output queue (default: {default_output_queue})'
|
||||
)
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . chunker import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . chunker import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,108 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts text documents on input, outputs chunks from the
|
||||
as text as separate output objects.
|
||||
"""
|
||||
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk, Source
|
||||
from ... schema import text_ingest_queue, chunk_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_ingest_queue
|
||||
default_output_queue = chunk_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
chunk_size = params.get("chunk_size", 2000)
|
||||
chunk_overlap = params.get("chunk_overlap", 100)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextDocument,
|
||||
"output_schema": Chunk,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
'chunk_size', 'Chunk size',
|
||||
buckets=[100, 160, 250, 400, 650, 1000, 1600,
|
||||
2500, 4000, 6400, 10000, 16000]
|
||||
)
|
||||
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len,
|
||||
is_separator_regex=False,
|
||||
)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Chunking {v.source.id}...", flush=True)
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
)
|
||||
|
||||
for ix, chunk in enumerate(texts):
|
||||
|
||||
id = v.source.id + "-c" + str(ix)
|
||||
|
||||
r = Chunk(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
id=id,
|
||||
title=v.source.title
|
||||
),
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
)
|
||||
|
||||
__class__.chunk_metric.observe(len(chunk.page_content))
|
||||
|
||||
self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--chunk-size',
|
||||
type=int,
|
||||
default=2000,
|
||||
help=f'Chunk size (default: 2000)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-v', '--chunk-overlap',
|
||||
type=int,
|
||||
default=100,
|
||||
help=f'Chunk overlap (default: 100)'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . chunker import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . chunker import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,107 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts text documents on input, outputs chunks from the
|
||||
as text as separate output objects.
|
||||
"""
|
||||
|
||||
from langchain_text_splitters import TokenTextSplitter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from ... schema import TextDocument, Chunk, Source
|
||||
from ... schema import text_ingest_queue, chunk_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_ingest_queue
|
||||
default_output_queue = chunk_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
chunk_size = params.get("chunk_size", 250)
|
||||
chunk_overlap = params.get("chunk_overlap", 15)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextDocument,
|
||||
"output_schema": Chunk,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
'chunk_size', 'Chunk size',
|
||||
buckets=[100, 160, 250, 400, 650, 1000, 1600,
|
||||
2500, 4000, 6400, 10000, 16000]
|
||||
)
|
||||
|
||||
self.text_splitter = TokenTextSplitter(
|
||||
encoding_name="cl100k_base",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Chunking {v.source.id}...", flush=True)
|
||||
|
||||
texts = self.text_splitter.create_documents(
|
||||
[v.text.decode("utf-8")]
|
||||
)
|
||||
|
||||
for ix, chunk in enumerate(texts):
|
||||
|
||||
id = v.source.id + "-c" + str(ix)
|
||||
|
||||
r = Chunk(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
id=id,
|
||||
title=v.source.title
|
||||
),
|
||||
chunk=chunk.page_content.encode("utf-8"),
|
||||
)
|
||||
|
||||
__class__.chunk_metric.observe(len(chunk.page_content))
|
||||
|
||||
self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--chunk-size',
|
||||
type=int,
|
||||
default=250,
|
||||
help=f'Chunk size (default: 250)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-v', '--chunk-overlap',
|
||||
type=int,
|
||||
default=15,
|
||||
help=f'Chunk overlap (default: 15)'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,125 +0,0 @@
|
|||
|
||||
import pulsar
|
||||
import _pulsar
|
||||
import hashlib
|
||||
import uuid
|
||||
import time
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .. exceptions import *
|
||||
|
||||
# Default timeout for a request/response. In seconds.
|
||||
DEFAULT_TIMEOUT=300
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class BaseClient:
|
||||
|
||||
def __init__(
|
||||
self, log_level=ERROR,
|
||||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
input_schema=None,
|
||||
output_schema=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
):
|
||||
|
||||
if input_queue == None: raise RuntimeError("Need input_queue")
|
||||
if output_queue == None: raise RuntimeError("Need output_queue")
|
||||
if input_schema == None: raise RuntimeError("Need input_schema")
|
||||
if output_schema == None: raise RuntimeError("Need output_schema")
|
||||
|
||||
if subscriber == None:
|
||||
subscriber = str(uuid.uuid4())
|
||||
|
||||
self.client = pulsar.Client(
|
||||
pulsar_host,
|
||||
logger=pulsar.ConsoleLogger(log_level),
|
||||
)
|
||||
|
||||
self.producer = self.client.create_producer(
|
||||
topic=input_queue,
|
||||
schema=JsonSchema(input_schema),
|
||||
chunking_enabled=True,
|
||||
)
|
||||
|
||||
self.consumer = self.client.subscribe(
|
||||
output_queue, subscriber,
|
||||
schema=JsonSchema(output_schema),
|
||||
)
|
||||
|
||||
self.input_schema = input_schema
|
||||
self.output_schema = output_schema
|
||||
|
||||
def call(self, **args):
|
||||
|
||||
timeout = args.get("timeout", DEFAULT_TIMEOUT)
|
||||
|
||||
if "timeout" in args:
|
||||
del args["timeout"]
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
r = self.input_schema(**args)
|
||||
|
||||
end_time = time.time() + timeout
|
||||
|
||||
self.producer.send(r, properties={ "id": id })
|
||||
|
||||
while time.time() < end_time:
|
||||
|
||||
try:
|
||||
msg = self.consumer.receive(timeout_millis=2500)
|
||||
except pulsar.exceptions.Timeout:
|
||||
continue
|
||||
|
||||
mid = msg.properties()["id"]
|
||||
|
||||
if mid == id:
|
||||
|
||||
value = msg.value()
|
||||
|
||||
if value.error:
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
if value.error.type == "llm-error":
|
||||
raise LlmError(value.error.message)
|
||||
|
||||
elif value.error.type == "too-many-requests":
|
||||
raise TooManyRequests(value.error.message)
|
||||
|
||||
elif value.error.type == "ParseError":
|
||||
raise ParseError(value.error.message)
|
||||
|
||||
else:
|
||||
|
||||
raise RuntimeError(
|
||||
f"{value.error.type}: {value.error.message}"
|
||||
)
|
||||
|
||||
resp = msg.value()
|
||||
self.consumer.acknowledge(msg)
|
||||
return resp
|
||||
|
||||
# Ignore messages with wrong ID
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
raise TimeoutError("Timed out waiting for response")
|
||||
|
||||
def __del__(self):
|
||||
|
||||
if hasattr(self, "consumer"):
|
||||
self.consumer.close()
|
||||
|
||||
if hasattr(self, "producer"):
|
||||
self.producer.flush()
|
||||
self.producer.close()
|
||||
|
||||
self.client.close()
|
||||
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from .. schema import document_embeddings_request_queue
|
||||
from .. schema import document_embeddings_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class DocumentEmbeddingsClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
self, log_level=ERROR,
|
||||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
):
|
||||
|
||||
if input_queue == None:
|
||||
input_queue = document_embeddings_request_queue
|
||||
|
||||
if output_queue == None:
|
||||
output_queue = document_embeddings_response_queue
|
||||
|
||||
super(DocumentEmbeddingsClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
pulsar_host=pulsar_host,
|
||||
input_schema=DocumentEmbeddingsRequest,
|
||||
output_schema=DocumentEmbeddingsResponse,
|
||||
)
|
||||
|
||||
def request(self, vectors, limit=10, timeout=300):
|
||||
return self.call(
|
||||
vectors=vectors, limit=limit, timeout=timeout
|
||||
).documents
|
||||
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import DocumentRagQuery, DocumentRagResponse
|
||||
from .. schema import document_rag_request_queue, document_rag_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class DocumentRagClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_level=ERROR,
|
||||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
):
|
||||
|
||||
if input_queue == None:
|
||||
input_queue = document_rag_request_queue
|
||||
|
||||
if output_queue == None:
|
||||
output_queue = document_rag_response_queue
|
||||
|
||||
super(DocumentRagClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
pulsar_host=pulsar_host,
|
||||
input_schema=DocumentRagQuery,
|
||||
output_schema=DocumentRagResponse,
|
||||
)
|
||||
|
||||
def request(self, query, timeout=500):
|
||||
|
||||
return self.call(
|
||||
query=query, timeout=timeout
|
||||
).response
|
||||
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
|
||||
from pulsar.schema import JsonSchema
|
||||
from .. schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from .. schema import embeddings_request_queue, embeddings_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
import _pulsar
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class EmbeddingsClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
self, log_level=ERROR,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
subscriber=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
):
|
||||
|
||||
if input_queue == None:
|
||||
input_queue=embeddings_request_queue
|
||||
|
||||
if output_queue == None:
|
||||
output_queue=embeddings_response_queue
|
||||
|
||||
super(EmbeddingsClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
pulsar_host=pulsar_host,
|
||||
input_schema=EmbeddingsRequest,
|
||||
output_schema=EmbeddingsResponse,
|
||||
)
|
||||
|
||||
def request(self, text, timeout=300):
|
||||
return self.call(text=text, timeout=timeout).vectors
|
||||
|
||||
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from .. schema import graph_embeddings_request_queue
|
||||
from .. schema import graph_embeddings_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class GraphEmbeddingsClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
self, log_level=ERROR,
|
||||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
):
|
||||
|
||||
if input_queue == None:
|
||||
input_queue = graph_embeddings_request_queue
|
||||
|
||||
if output_queue == None:
|
||||
output_queue = graph_embeddings_response_queue
|
||||
|
||||
super(GraphEmbeddingsClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
pulsar_host=pulsar_host,
|
||||
input_schema=GraphEmbeddingsRequest,
|
||||
output_schema=GraphEmbeddingsResponse,
|
||||
)
|
||||
|
||||
def request(self, vectors, limit=10, timeout=300):
|
||||
return self.call(
|
||||
vectors=vectors, limit=limit, timeout=timeout
|
||||
).entities
|
||||
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import GraphRagQuery, GraphRagResponse
|
||||
from .. schema import graph_rag_request_queue, graph_rag_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class GraphRagClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
log_level=ERROR,
|
||||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
):
|
||||
|
||||
if input_queue == None:
|
||||
input_queue = graph_rag_request_queue
|
||||
|
||||
if output_queue == None:
|
||||
output_queue = graph_rag_response_queue
|
||||
|
||||
super(GraphRagClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
pulsar_host=pulsar_host,
|
||||
input_schema=GraphRagQuery,
|
||||
output_schema=GraphRagResponse,
|
||||
)
|
||||
|
||||
def request(self, query, timeout=500):
|
||||
|
||||
return self.call(
|
||||
query=query, timeout=timeout
|
||||
).response
|
||||
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .. schema import text_completion_request_queue
|
||||
from .. schema import text_completion_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class LlmClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
self, log_level=ERROR,
|
||||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
):
|
||||
|
||||
if input_queue is None: input_queue = text_completion_request_queue
|
||||
if output_queue is None: output_queue = text_completion_response_queue
|
||||
|
||||
super(LlmClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
pulsar_host=pulsar_host,
|
||||
input_schema=TextCompletionRequest,
|
||||
output_schema=TextCompletionResponse,
|
||||
)
|
||||
|
||||
def request(self, prompt, timeout=300):
|
||||
return self.call(prompt=prompt, timeout=timeout).response
|
||||
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import PromptRequest, PromptResponse, Fact, RowSchema, Field
|
||||
from .. schema import prompt_request_queue
|
||||
from .. schema import prompt_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class PromptClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
self, log_level=ERROR,
|
||||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
):
|
||||
|
||||
if input_queue == None:
|
||||
input_queue = prompt_request_queue
|
||||
|
||||
if output_queue == None:
|
||||
output_queue = prompt_response_queue
|
||||
|
||||
super(PromptClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
pulsar_host=pulsar_host,
|
||||
input_schema=PromptRequest,
|
||||
output_schema=PromptResponse,
|
||||
)
|
||||
|
||||
def request_definitions(self, chunk, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="extract-definitions", chunk=chunk,
|
||||
timeout=timeout
|
||||
).definitions
|
||||
|
||||
def request_topics(self, chunk, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="extract-topics", chunk=chunk,
|
||||
timeout=timeout
|
||||
).topics
|
||||
|
||||
def request_relationships(self, chunk, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="extract-relationships", chunk=chunk,
|
||||
timeout=timeout
|
||||
).relationships
|
||||
|
||||
def request_rows(self, schema, chunk, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="extract-rows", chunk=chunk,
|
||||
row_schema=RowSchema(
|
||||
name=schema.name,
|
||||
description=schema.description,
|
||||
fields=[
|
||||
Field(
|
||||
name=f.name, type=str(f.type), size=f.size,
|
||||
primary=f.primary, description=f.description,
|
||||
)
|
||||
for f in schema.fields
|
||||
]
|
||||
),
|
||||
timeout=timeout
|
||||
).rows
|
||||
|
||||
def request_kg_prompt(self, query, kg, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="kg-prompt",
|
||||
query=query,
|
||||
kg=[
|
||||
Fact(s=v[0], p=v[1], o=v[2])
|
||||
for v in kg
|
||||
],
|
||||
timeout=timeout
|
||||
).answer
|
||||
|
||||
def request_document_prompt(self, query, documents, timeout=300):
|
||||
|
||||
return self.call(
|
||||
kind="document-prompt",
|
||||
query=query,
|
||||
documents=documents,
|
||||
timeout=timeout
|
||||
).answer
|
||||
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import _pulsar
|
||||
|
||||
from .. schema import TriplesQueryRequest, TriplesQueryResponse, Value
|
||||
from .. schema import triples_request_queue
|
||||
from .. schema import triples_response_queue
|
||||
from . base import BaseClient
|
||||
|
||||
# Ugly
|
||||
ERROR=_pulsar.LoggerLevel.Error
|
||||
WARN=_pulsar.LoggerLevel.Warn
|
||||
INFO=_pulsar.LoggerLevel.Info
|
||||
DEBUG=_pulsar.LoggerLevel.Debug
|
||||
|
||||
class TriplesQueryClient(BaseClient):
|
||||
|
||||
def __init__(
|
||||
self, log_level=ERROR,
|
||||
subscriber=None,
|
||||
input_queue=None,
|
||||
output_queue=None,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
):
|
||||
|
||||
if input_queue == None:
|
||||
input_queue = triples_request_queue
|
||||
|
||||
if output_queue == None:
|
||||
output_queue = triples_response_queue
|
||||
|
||||
super(TriplesQueryClient, self).__init__(
|
||||
log_level=log_level,
|
||||
subscriber=subscriber,
|
||||
input_queue=input_queue,
|
||||
output_queue=output_queue,
|
||||
pulsar_host=pulsar_host,
|
||||
input_schema=TriplesQueryRequest,
|
||||
output_schema=TriplesQueryResponse,
|
||||
)
|
||||
|
||||
def create_value(self, ent):
|
||||
|
||||
if ent == None: return None
|
||||
|
||||
if ent.startswith("http://") or ent.startswith("https://"):
|
||||
return Value(value=ent, is_uri=True)
|
||||
|
||||
return Value(value=ent, is_uri=False)
|
||||
|
||||
def request(self, s, p, o, limit=10, timeout=60):
|
||||
return self.call(
|
||||
s=self.create_value(s),
|
||||
p=self.create_value(p),
|
||||
o=self.create_value(o),
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
).triples
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . pdf_decoder import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . pdf_decoder import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,87 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts PDF documents on input, outputs pages from the
|
||||
PDF document as text as separate output objects.
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import base64
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
|
||||
from ... schema import Document, TextDocument, Source
|
||||
from ... schema import document_ingest_queue, text_ingest_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = document_ingest_queue
|
||||
default_output_queue = text_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Document,
|
||||
"output_schema": TextDocument,
|
||||
}
|
||||
)
|
||||
|
||||
print("PDF inited")
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
print("PDF message received")
|
||||
|
||||
v = msg.value()
|
||||
|
||||
print(f"Decoding {v.source.id}...", flush=True)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete_on_close=False) as fp:
|
||||
|
||||
fp.write(base64.b64decode(v.data))
|
||||
fp.close()
|
||||
|
||||
with open(fp.name, mode='rb') as f:
|
||||
|
||||
loader = PyPDFLoader(fp.name)
|
||||
pages = loader.load()
|
||||
|
||||
for ix, page in enumerate(pages):
|
||||
|
||||
id = v.source.id + "-p" + str(ix)
|
||||
r = TextDocument(
|
||||
source=Source(
|
||||
source=v.source.source,
|
||||
title=v.source.title,
|
||||
id=id,
|
||||
),
|
||||
text=page.page_content.encode("utf-8"),
|
||||
)
|
||||
|
||||
self.send(r)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,108 +0,0 @@
|
|||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.auth import PlainTextAuthProvider
|
||||
|
||||
class TrustGraph:
|
||||
|
||||
def __init__(self, hosts=None):
|
||||
|
||||
if hosts is None:
|
||||
hosts = ["localhost"]
|
||||
|
||||
self.cluster = Cluster(hosts)
|
||||
self.session = self.cluster.connect()
|
||||
|
||||
self.init()
|
||||
|
||||
def clear(self):
|
||||
|
||||
self.session.execute("""
|
||||
drop keyspace if exists trustgraph;
|
||||
""");
|
||||
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
|
||||
self.session.execute("""
|
||||
create keyspace if not exists trustgraph
|
||||
with replication = {
|
||||
'class' : 'SimpleStrategy',
|
||||
'replication_factor' : 1
|
||||
};
|
||||
""");
|
||||
|
||||
self.session.set_keyspace('trustgraph')
|
||||
|
||||
self.session.execute("""
|
||||
create table if not exists triples (
|
||||
s text,
|
||||
p text,
|
||||
o text,
|
||||
PRIMARY KEY (s, p, o)
|
||||
);
|
||||
""");
|
||||
|
||||
self.session.execute("""
|
||||
create index if not exists triples_p
|
||||
ON triples (p);
|
||||
""");
|
||||
|
||||
self.session.execute("""
|
||||
create index if not exists triples_o
|
||||
ON triples (o);
|
||||
""");
|
||||
|
||||
def insert(self, s, p, o):
|
||||
|
||||
self.session.execute(
|
||||
"insert into triples (s, p, o) values (%s, %s, %s)",
|
||||
(s, p, o)
|
||||
)
|
||||
|
||||
def get_all(self, limit=50):
|
||||
return self.session.execute(
|
||||
f"select s, p, o from triples limit {limit}"
|
||||
)
|
||||
|
||||
def get_s(self, s, limit=10):
|
||||
return self.session.execute(
|
||||
f"select p, o from triples where s = %s limit {limit}",
|
||||
(s,)
|
||||
)
|
||||
|
||||
def get_p(self, p, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s, o from triples where p = %s limit {limit}",
|
||||
(p,)
|
||||
)
|
||||
|
||||
def get_o(self, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s, p from triples where o = %s limit {limit}",
|
||||
(o,)
|
||||
)
|
||||
|
||||
def get_sp(self, s, p, limit=10):
|
||||
return self.session.execute(
|
||||
f"select o from triples where s = %s and p = %s limit {limit}",
|
||||
(s, p)
|
||||
)
|
||||
|
||||
def get_po(self, p, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"select s from triples where p = %s and o = %s allow filtering limit {limit}",
|
||||
(p, o)
|
||||
)
|
||||
|
||||
def get_os(self, o, s, limit=10):
|
||||
return self.session.execute(
|
||||
f"select p from triples where o = %s and s = %s limit {limit}",
|
||||
(o, s)
|
||||
)
|
||||
|
||||
def get_spo(self, s, p, o, limit=10):
|
||||
return self.session.execute(
|
||||
f"""select s as x from triples where s = %s and p = %s and o = %s limit {limit}""",
|
||||
(s, p, o)
|
||||
)
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
|
||||
class DocVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='doc'):
|
||||
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
# Strategy is to create collections per dimension. Probably only
|
||||
# going to be using 1 anyway, but that means we don't need to
|
||||
# hard-code the dimension anywhere, and no big deal if more than
|
||||
# one are created.
|
||||
self.collections = {}
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
# Time between reloads
|
||||
self.reload_time = 90
|
||||
|
||||
# Next time to reload - this forces a reload at next window
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
print("Reload at", self.next_reload)
|
||||
|
||||
def init_collection(self, dimension):
|
||||
|
||||
collection_name = self.prefix + "_" + str(dimension)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
is_primary=True,
|
||||
auto_id=True,
|
||||
)
|
||||
|
||||
vec_field = FieldSchema(
|
||||
name="vector",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
)
|
||||
|
||||
doc_field = FieldSchema(
|
||||
name="doc",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [pkey_field, vec_field, doc_field],
|
||||
description = "Document embedding schema",
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
metric_type="COSINE",
|
||||
index_type="IVF_SQ8",
|
||||
index_name="vector_index",
|
||||
params={ "nlist": 128 }
|
||||
)
|
||||
|
||||
self.client.create_index(
|
||||
collection_name=collection_name,
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[dimension] = collection_name
|
||||
|
||||
def insert(self, embeds, doc):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"doc": doc,
|
||||
}
|
||||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[dim],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, fields=["doc"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
|
||||
coll = self.collections[dim]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
print("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
|
||||
print("Searching...")
|
||||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
)[0]
|
||||
|
||||
|
||||
# If reload time has passed, unload collection
|
||||
if time.time() > self.next_reload:
|
||||
print("Unloading, reload at", self.next_reload)
|
||||
self.client.release_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
|
||||
return res
|
||||
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
|
||||
class EntityVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='entity'):
|
||||
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
# Strategy is to create collections per dimension. Probably only
|
||||
# going to be using 1 anyway, but that means we don't need to
|
||||
# hard-code the dimension anywhere, and no big deal if more than
|
||||
# one are created.
|
||||
self.collections = {}
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
# Time between reloads
|
||||
self.reload_time = 90
|
||||
|
||||
# Next time to reload - this forces a reload at next window
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
print("Reload at", self.next_reload)
|
||||
|
||||
def init_collection(self, dimension):
|
||||
|
||||
collection_name = self.prefix + "_" + str(dimension)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
is_primary=True,
|
||||
auto_id=True,
|
||||
)
|
||||
|
||||
vec_field = FieldSchema(
|
||||
name="vector",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
)
|
||||
|
||||
entity_field = FieldSchema(
|
||||
name="entity",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [pkey_field, vec_field, entity_field],
|
||||
description = "Graph embedding schema",
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
metric_type="COSINE",
|
||||
index_type="IVF_SQ8",
|
||||
index_name="vector_index",
|
||||
params={ "nlist": 128 }
|
||||
)
|
||||
|
||||
self.client.create_index(
|
||||
collection_name=collection_name,
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[dimension] = collection_name
|
||||
|
||||
def insert(self, embeds, entity):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"entity": entity,
|
||||
}
|
||||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[dim],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, fields=["entity"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim)
|
||||
|
||||
coll = self.collections[dim]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
print("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
|
||||
print("Searching...")
|
||||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
)[0]
|
||||
|
||||
|
||||
# If reload time has passed, unload collection
|
||||
if time.time() > self.next_reload:
|
||||
print("Unloading, reload at", self.next_reload)
|
||||
self.client.release_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
|
||||
return res
|
||||
|
||||
|
|
@ -1,154 +0,0 @@
|
|||
|
||||
from pymilvus import MilvusClient, CollectionSchema, FieldSchema, DataType
|
||||
import time
|
||||
|
||||
class ObjectVectors:
|
||||
|
||||
def __init__(self, uri="http://localhost:19530", prefix='obj'):
|
||||
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
# Strategy is to create collections per dimension. Probably only
|
||||
# going to be using 1 anyway, but that means we don't need to
|
||||
# hard-code the dimension anywhere, and no big deal if more than
|
||||
# one are created.
|
||||
self.collections = {}
|
||||
|
||||
self.prefix = prefix
|
||||
|
||||
# Time between reloads
|
||||
self.reload_time = 90
|
||||
|
||||
# Next time to reload - this forces a reload at next window
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
print("Reload at", self.next_reload)
|
||||
|
||||
def init_collection(self, dimension, name):
|
||||
|
||||
collection_name = self.prefix + "_" + name + "_" + str(dimension)
|
||||
|
||||
pkey_field = FieldSchema(
|
||||
name="id",
|
||||
dtype=DataType.INT64,
|
||||
is_primary=True,
|
||||
auto_id=True,
|
||||
)
|
||||
|
||||
vec_field = FieldSchema(
|
||||
name="vector",
|
||||
dtype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
)
|
||||
|
||||
name_field = FieldSchema(
|
||||
name="name",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
key_name_field = FieldSchema(
|
||||
name="key_name",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
key_field = FieldSchema(
|
||||
name="key",
|
||||
dtype=DataType.VARCHAR,
|
||||
max_length=65535,
|
||||
)
|
||||
|
||||
schema = CollectionSchema(
|
||||
fields = [
|
||||
pkey_field, vec_field, name_field, key_name_field, key_field
|
||||
],
|
||||
description = "Object embedding schema",
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
metric_type="COSINE",
|
||||
)
|
||||
|
||||
index_params = MilvusClient.prepare_index_params()
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
metric_type="COSINE",
|
||||
index_type="IVF_SQ8",
|
||||
index_name="vector_index",
|
||||
params={ "nlist": 128 }
|
||||
)
|
||||
|
||||
self.client.create_index(
|
||||
collection_name=collection_name,
|
||||
index_params=index_params
|
||||
)
|
||||
|
||||
self.collections[(dimension, name)] = collection_name
|
||||
|
||||
def insert(self, embeds, name, key_name, key):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if (dim, name) not in self.collections:
|
||||
self.init_collection(dim, name)
|
||||
|
||||
data = [
|
||||
{
|
||||
"vector": embeds,
|
||||
"name": name,
|
||||
"key_name": key_name,
|
||||
"key": key,
|
||||
}
|
||||
]
|
||||
|
||||
self.client.insert(
|
||||
collection_name=self.collections[(dim, name)],
|
||||
data=data
|
||||
)
|
||||
|
||||
def search(self, embeds, name, fields=["key_name", "name"], limit=10):
|
||||
|
||||
dim = len(embeds)
|
||||
|
||||
if dim not in self.collections:
|
||||
self.init_collection(dim, name)
|
||||
|
||||
coll = self.collections[(dim, name)]
|
||||
|
||||
search_params = {
|
||||
"metric_type": "COSINE",
|
||||
"params": {
|
||||
"radius": 0.1,
|
||||
"range_filter": 0.8
|
||||
}
|
||||
}
|
||||
|
||||
print("Loading...")
|
||||
self.client.load_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
|
||||
print("Searching...")
|
||||
|
||||
res = self.client.search(
|
||||
collection_name=coll,
|
||||
data=[embeds],
|
||||
limit=limit,
|
||||
output_fields=fields,
|
||||
search_params=search_params,
|
||||
)[0]
|
||||
|
||||
|
||||
# If reload time has passed, unload collection
|
||||
if time.time() > self.next_reload:
|
||||
print("Unloading, reload at", self.next_reload)
|
||||
self.client.release_collection(
|
||||
collection_name=coll,
|
||||
)
|
||||
self.next_reload = time.time() + self.reload_time
|
||||
|
||||
return res
|
||||
|
||||
|
|
@ -1,132 +0,0 @@
|
|||
|
||||
from . clients.document_embeddings_client import DocumentEmbeddingsClient
|
||||
from . clients.triples_query_client import TriplesQueryClient
|
||||
from . clients.embeddings_client import EmbeddingsClient
|
||||
from . clients.prompt_client import PromptClient
|
||||
|
||||
from . schema import DocumentEmbeddingsRequest, DocumentEmbeddingsResponse
|
||||
from . schema import TriplesQueryRequest, TriplesQueryResponse
|
||||
from . schema import prompt_request_queue
|
||||
from . schema import prompt_response_queue
|
||||
from . schema import embeddings_request_queue
|
||||
from . schema import embeddings_response_queue
|
||||
from . schema import document_embeddings_request_queue
|
||||
from . schema import document_embeddings_response_queue
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
||||
|
||||
class DocumentRag:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
pr_request_queue=None,
|
||||
pr_response_queue=None,
|
||||
emb_request_queue=None,
|
||||
emb_response_queue=None,
|
||||
de_request_queue=None,
|
||||
de_response_queue=None,
|
||||
verbose=False,
|
||||
module="test",
|
||||
):
|
||||
|
||||
self.verbose=verbose
|
||||
|
||||
if pr_request_queue is None:
|
||||
pr_request_queue = prompt_request_queue
|
||||
|
||||
if pr_response_queue is None:
|
||||
pr_response_queue = prompt_response_queue
|
||||
|
||||
if emb_request_queue is None:
|
||||
emb_request_queue = embeddings_request_queue
|
||||
|
||||
if emb_response_queue is None:
|
||||
emb_response_queue = embeddings_response_queue
|
||||
|
||||
if de_request_queue is None:
|
||||
de_request_queue = document_embeddings_request_queue
|
||||
|
||||
if de_response_queue is None:
|
||||
de_response_queue = document_embeddings_response_queue
|
||||
|
||||
if self.verbose:
|
||||
print("Initialising...", flush=True)
|
||||
|
||||
# FIXME: Configurable
|
||||
self.entity_limit = 20
|
||||
|
||||
self.de_client = DocumentEmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
subscriber=module + "-de",
|
||||
input_queue=de_request_queue,
|
||||
output_queue=de_response_queue,
|
||||
)
|
||||
|
||||
self.embeddings = EmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=emb_request_queue,
|
||||
output_queue=emb_response_queue,
|
||||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
self.lang = PromptClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber=module + "-de-prompt",
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = self.embeddings.request(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
def get_docs(self, query):
|
||||
|
||||
vectors = self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get entities...", flush=True)
|
||||
|
||||
docs = self.de_client.request(
|
||||
vectors, self.entity_limit
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Docs:", flush=True)
|
||||
for doc in docs:
|
||||
print(doc, flush=True)
|
||||
|
||||
return docs
|
||||
|
||||
def query(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
docs = self.get_docs(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Invoke LLM...", flush=True)
|
||||
print(docs)
|
||||
print(query)
|
||||
|
||||
resp = self.lang.request_document_prompt(query, docs)
|
||||
|
||||
if self.verbose:
|
||||
print("Done", flush=True)
|
||||
|
||||
return resp
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . processor import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . write import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,85 +0,0 @@
|
|||
|
||||
"""
|
||||
Write graph embeddings to parquet files in a directory.
|
||||
"""
|
||||
|
||||
import pulsar
|
||||
import base64
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from .... schema import GraphEmbeddings
|
||||
from .... schema import graph_embeddings_store_queue
|
||||
from .... base import Consumer
|
||||
|
||||
from . writer import ParquetWriter
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
default_graph_host='localhost'
|
||||
default_directory = "."
|
||||
default_file_template = "graph-embeds-{id}.parquet"
|
||||
default_rotation_time = 60
|
||||
|
||||
class Processor(Consumer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
directory = params.get("directory", default_directory)
|
||||
file_template = params.get("file_template", default_file_template)
|
||||
rotation_time = params.get("rotation_time", default_rotation_time)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": GraphEmbeddings,
|
||||
}
|
||||
)
|
||||
|
||||
self.writer = ParquetWriter(directory, file_template, rotation_time)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "writer"):
|
||||
del self.writer
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
self.writer.write(v.vectors, v.entity.value)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-d', '--directory',
|
||||
default=default_directory,
|
||||
help=f'Directory to write to (default: {default_directory})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-f', '--file-template',
|
||||
default=default_file_template,
|
||||
help=f'Directory to write to (default: {default_file_template})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--rotation-time',
|
||||
type=int,
|
||||
default=default_rotation_time,
|
||||
help=f'Rotation time / seconds (default: {default_rotation_time})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
import uuid
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
class ParquetWriter:
|
||||
|
||||
def __init__(self, directory, file_template, rotation_time):
|
||||
self.directory = directory
|
||||
self.file_template = file_template
|
||||
self.rotation_time = rotation_time
|
||||
|
||||
self.q = queue.Queue()
|
||||
|
||||
self.running = True
|
||||
|
||||
self.thread = threading.Thread(target=(self.writer_thread))
|
||||
self.thread.start()
|
||||
|
||||
def writer_thread(self):
|
||||
|
||||
items = []
|
||||
|
||||
timeout = None
|
||||
|
||||
while self.running:
|
||||
|
||||
try:
|
||||
|
||||
item = self.q.get(timeout=1)
|
||||
|
||||
if timeout == None:
|
||||
timeout = time.time() + self.rotation_time
|
||||
|
||||
items.append(item)
|
||||
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
if timeout:
|
||||
if time.time() > timeout:
|
||||
|
||||
self.write_file(items)
|
||||
timeout = None
|
||||
items = []
|
||||
|
||||
def write_file(self, items):
|
||||
|
||||
try:
|
||||
|
||||
schema = pa.schema([
|
||||
pa.field('embeddings', pa.list_(pa.list_(pa.float64()))),
|
||||
pa.field('entity', pa.string()),
|
||||
])
|
||||
|
||||
fname = self.file_template.format(id=str(uuid.uuid4()))
|
||||
path = f"{self.directory}/{fname}"
|
||||
|
||||
writer = pq.ParquetWriter(path, schema)
|
||||
|
||||
batch = pa.record_batch(
|
||||
[
|
||||
[i[0] for i in items],
|
||||
[i[1] for i in items],
|
||||
],
|
||||
names=['embeddings', 'entity']
|
||||
)
|
||||
|
||||
writer.write_batch(batch)
|
||||
|
||||
writer.close()
|
||||
|
||||
print(f"Wrote {path}.")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print("Parquet write:", e)
|
||||
|
||||
def write(self, embeds, ent):
|
||||
self.q.put((embeds, ent))
|
||||
|
||||
def __del__(self):
|
||||
|
||||
self.running = False
|
||||
|
||||
if hasattr(self, "q"):
|
||||
self.thread.join()
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . processor import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . write import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,85 +0,0 @@
|
|||
|
||||
"""
|
||||
Write graphs triples to parquet files in a directory.
|
||||
"""
|
||||
|
||||
import pulsar
|
||||
import base64
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from .... schema import Triple
|
||||
from .... schema import triples_store_queue
|
||||
from .... base import Consumer
|
||||
|
||||
from . writer import ParquetWriter
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
default_graph_host='localhost'
|
||||
default_directory = "."
|
||||
default_file_template = "triples-{id}.parquet"
|
||||
default_rotation_time = 60
|
||||
|
||||
class Processor(Consumer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
directory = params.get("directory", default_directory)
|
||||
file_template = params.get("file_template", default_file_template)
|
||||
rotation_time = params.get("rotation_time", default_rotation_time)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Triple,
|
||||
}
|
||||
)
|
||||
|
||||
self.writer = ParquetWriter(directory, file_template, rotation_time)
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "writer"):
|
||||
del self.writer
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
self.writer.write(v.s.value, v.p.value, v.o.value)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-d', '--directory',
|
||||
default=default_directory,
|
||||
help=f'Directory to write to (default: {default_directory})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-f', '--file-template',
|
||||
default=default_file_template,
|
||||
help=f'Directory to write to (default: {default_file_template})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--rotation-time',
|
||||
type=int,
|
||||
default=default_rotation_time,
|
||||
help=f'Rotation time / seconds (default: {default_rotation_time})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
import uuid
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
class ParquetWriter:
|
||||
|
||||
def __init__(self, directory, file_template, rotation_time):
|
||||
self.directory = directory
|
||||
self.file_template = file_template
|
||||
self.rotation_time = rotation_time
|
||||
|
||||
self.q = queue.Queue()
|
||||
|
||||
self.running = True
|
||||
|
||||
self.thread = threading.Thread(target=(self.writer_thread))
|
||||
self.thread.start()
|
||||
|
||||
def writer_thread(self):
|
||||
|
||||
triples = []
|
||||
|
||||
timeout = None
|
||||
|
||||
while self.running:
|
||||
|
||||
try:
|
||||
|
||||
item = self.q.get(timeout=1)
|
||||
|
||||
if timeout == None:
|
||||
timeout = time.time() + self.rotation_time
|
||||
|
||||
triples.append(item)
|
||||
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
if timeout:
|
||||
if time.time() > timeout:
|
||||
|
||||
self.write_file(triples)
|
||||
timeout = None
|
||||
triples = []
|
||||
|
||||
def write_file(self, triples):
|
||||
|
||||
try:
|
||||
|
||||
schema = pa.schema([
|
||||
pa.field('s', pa.string()),
|
||||
pa.field('p', pa.string()),
|
||||
pa.field('o', pa.string()),
|
||||
])
|
||||
|
||||
fname = self.file_template.format(id=str(uuid.uuid4()))
|
||||
path = f"{self.directory}/{fname}"
|
||||
|
||||
writer = pq.ParquetWriter(path, schema)
|
||||
|
||||
batch = pa.record_batch(
|
||||
[
|
||||
[tpl[0] for tpl in triples],
|
||||
[tpl[1] for tpl in triples],
|
||||
[tpl[2] for tpl in triples],
|
||||
],
|
||||
names=['s', 'p', 'o']
|
||||
)
|
||||
|
||||
writer.write_batch(batch)
|
||||
|
||||
writer.close()
|
||||
|
||||
print(f"Wrote {path}.")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print("Parquet write:", e)
|
||||
|
||||
def write(self, s, p, o):
|
||||
self.q.put((s, p, o))
|
||||
|
||||
def __del__(self):
|
||||
|
||||
self.running = False
|
||||
|
||||
if hasattr(self, "q"):
|
||||
self.thread.join()
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . hf import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . hf import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,99 +0,0 @@
|
|||
|
||||
"""
|
||||
Embeddings service, applies an embeddings model selected from HuggingFace.
|
||||
Input is text, output is embeddings vector.
|
||||
"""
|
||||
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
from ... schema import EmbeddingsRequest, EmbeddingsResponse, Error
|
||||
from ... schema import embeddings_request_queue, embeddings_response_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = embeddings_request_queue
|
||||
default_output_queue = embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_model="all-MiniLM-L6-v2"
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": EmbeddingsRequest,
|
||||
"output_schema": EmbeddingsResponse,
|
||||
}
|
||||
)
|
||||
|
||||
self.embeddings = HuggingFaceEmbeddings(model_name=model)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
try:
|
||||
|
||||
text = v.text
|
||||
embeds = self.embeddings.embed_documents([text])
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = EmbeddingsResponse(vectors=embeds, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = EmbeddingsResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default="all-MiniLM-L6-v2",
|
||||
help=f'LLM model (default: all-MiniLM-L6-v2)'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . processor import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . processor import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
|
||||
"""
|
||||
Embeddings service, applies an embeddings model selected from HuggingFace.
|
||||
Input is text, output is embeddings vector.
|
||||
"""
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
|
||||
from ... schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from ... schema import embeddings_request_queue, embeddings_response_queue
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = embeddings_request_queue
|
||||
default_output_queue = embeddings_response_queue
|
||||
default_subscriber = module
|
||||
default_model="mxbai-embed-large"
|
||||
default_ollama = 'http://localhost:11434'
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": EmbeddingsRequest,
|
||||
"output_schema": EmbeddingsResponse,
|
||||
}
|
||||
)
|
||||
|
||||
self.embeddings = OllamaEmbeddings(base_url=ollama, model=model)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling input {id}...", flush=True)
|
||||
|
||||
text = v.text
|
||||
embeds = self.embeddings.embed_query([text])
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = EmbeddingsResponse(vectors=[embeds])
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default=default_model,
|
||||
help=f'Embeddings model (default: {default_model})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-r', '--ollama',
|
||||
default=default_ollama,
|
||||
help=f'ollama (default: {default_ollama})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . vectorize import *
|
||||
|
||||
|
|
@ -1,6 +0,0 @@
|
|||
|
||||
from . vectorize import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,103 +0,0 @@
|
|||
|
||||
"""
|
||||
Vectorizer, calls the embeddings service to get embeddings for a chunk.
|
||||
Input is text chunk, output is chunk and vectors.
|
||||
"""
|
||||
|
||||
from ... schema import Chunk, ChunkEmbeddings
|
||||
from ... schema import chunk_ingest_queue, chunk_embeddings_ingest_queue
|
||||
from ... schema import embeddings_request_queue, embeddings_response_queue
|
||||
from ... clients.embeddings_client import EmbeddingsClient
|
||||
from ... log_level import LogLevel
|
||||
from ... base import ConsumerProducer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_ingest_queue
|
||||
default_output_queue = chunk_embeddings_ingest_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
emb_request_queue = params.get(
|
||||
"embeddings_request_queue", embeddings_request_queue
|
||||
)
|
||||
emb_response_queue = params.get(
|
||||
"embeddings_response_queue", embeddings_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"embeddings_request_queue": emb_request_queue,
|
||||
"embeddings_response_queue": emb_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": Chunk,
|
||||
"output_schema": ChunkEmbeddings,
|
||||
}
|
||||
)
|
||||
|
||||
self.embeddings = EmbeddingsClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=emb_request_queue,
|
||||
output_queue=emb_response_queue,
|
||||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
def emit(self, source, chunk, vectors):
|
||||
|
||||
r = ChunkEmbeddings(source=source, chunk=chunk, vectors=vectors)
|
||||
self.producer.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
vectors = self.embeddings.request(chunk)
|
||||
|
||||
self.emit(
|
||||
source=v.source,
|
||||
chunk=chunk.encode("utf-8"),
|
||||
vectors=vectors
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception:", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-request-queue',
|
||||
default=embeddings_request_queue,
|
||||
help=f'Embeddings request queue (default: {embeddings_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embeddings-response-queue',
|
||||
default=embeddings_response_queue,
|
||||
help=f'Embeddings request queue (default: {embeddings_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
|
||||
class TooManyRequests(Exception):
|
||||
pass
|
||||
|
||||
class LlmError(Exception):
|
||||
pass
|
||||
|
||||
class ParseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . extract import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . extract import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
|
||||
get entity definitions which are output as graph edges.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import json
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Source, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"output_schema": Triple,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
def to_uri(self, text):
|
||||
|
||||
part = text.replace(" ", "-").lower().encode("utf-8")
|
||||
quoted = urllib.parse.quote(part)
|
||||
uri = TRUSTGRAPH_ENTITIES + quoted
|
||||
|
||||
return uri
|
||||
|
||||
def get_definitions(self, chunk):
|
||||
|
||||
return self.prompt.request_definitions(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
defs = self.get_definitions(chunk)
|
||||
|
||||
for defn in defs:
|
||||
|
||||
s = defn.name
|
||||
o = defn.definition
|
||||
|
||||
if s == "": continue
|
||||
if o == "": continue
|
||||
|
||||
if s is None: continue
|
||||
if o is None: continue
|
||||
|
||||
s_uri = self.to_uri(s)
|
||||
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-completion-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . extract import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . extract import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,208 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts vector+text chunks input, applies entity
|
||||
relationship analysis to get entity relationship edges which are output as
|
||||
graph edges.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, GraphEmbeddings, Source, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import graph_embeddings_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... rdf import RDF_LABEL, TRUSTGRAPH_ENTITIES
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
RDF_LABEL_VALUE = Value(value=RDF_LABEL, is_uri=True)
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_vector_queue = graph_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
vector_queue = params.get("vector_queue", default_vector_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"output_schema": Triple,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.vec_prod = self.client.create_producer(
|
||||
topic=vector_queue,
|
||||
schema=JsonSchema(GraphEmbeddings),
|
||||
)
|
||||
|
||||
__class__.pubsub_metric.info({
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"vector_queue": vector_queue,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings.__name__,
|
||||
"output_schema": Triple.__name__,
|
||||
"vector_schema": GraphEmbeddings.__name__,
|
||||
})
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
def to_uri(self, text):
|
||||
|
||||
part = text.replace(" ", "-").lower().encode("utf-8")
|
||||
quoted = urllib.parse.quote(part)
|
||||
uri = TRUSTGRAPH_ENTITIES + quoted
|
||||
|
||||
return uri
|
||||
|
||||
def get_relationships(self, chunk):
|
||||
|
||||
return self.prompt.request_relationships(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_vec(self, ent, vec):
|
||||
|
||||
r = GraphEmbeddings(entity=ent, vectors=vec)
|
||||
self.vec_prod.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
rels = self.get_relationships(chunk)
|
||||
|
||||
for rel in rels:
|
||||
|
||||
s = rel.s
|
||||
p = rel.p
|
||||
o = rel.o
|
||||
|
||||
if s == "": continue
|
||||
if p == "": continue
|
||||
if o == "": continue
|
||||
|
||||
if s is None: continue
|
||||
if p is None: continue
|
||||
if o is None: continue
|
||||
|
||||
s_uri = self.to_uri(s)
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
|
||||
p_uri = self.to_uri(p)
|
||||
p_value = Value(value=str(p_uri), is_uri=True)
|
||||
|
||||
if rel.o_entity:
|
||||
o_uri = self.to_uri(o)
|
||||
o_value = Value(value=str(o_uri), is_uri=True)
|
||||
else:
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(
|
||||
s_value,
|
||||
p_value,
|
||||
o_value
|
||||
)
|
||||
|
||||
# Label for s
|
||||
self.emit_edge(
|
||||
s_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(s), is_uri=False)
|
||||
)
|
||||
|
||||
# Label for p
|
||||
self.emit_edge(
|
||||
p_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(p), is_uri=False)
|
||||
)
|
||||
|
||||
if rel.o_entity:
|
||||
# Label for o
|
||||
self.emit_edge(
|
||||
o_value,
|
||||
RDF_LABEL_VALUE,
|
||||
Value(value=str(o), is_uri=False)
|
||||
)
|
||||
|
||||
self.emit_vec(s_value, v.vectors)
|
||||
self.emit_vec(p_value, v.vectors)
|
||||
if rel.o_entity:
|
||||
self.emit_vec(o_value, v.vectors)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--vector-queue',
|
||||
default=default_vector_queue,
|
||||
help=f'Vector output queue (default: {default_vector_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . extract import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . extract import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,134 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts embeddings+text chunks input, applies entity analysis to
|
||||
get entity definitions which are output as graph edges.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import json
|
||||
|
||||
from .... schema import ChunkEmbeddings, Triple, Source, Value
|
||||
from .... schema import chunk_embeddings_ingest_queue, triples_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... rdf import TRUSTGRAPH_ENTITIES, DEFINITION
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
DEFINITION_VALUE = Value(value=DEFINITION, is_uri=True)
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_output_queue = triples_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"output_schema": Triple,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
def to_uri(self, text):
|
||||
|
||||
part = text.replace(" ", "-").lower().encode("utf-8")
|
||||
quoted = urllib.parse.quote(part)
|
||||
uri = TRUSTGRAPH_ENTITIES + quoted
|
||||
|
||||
return uri
|
||||
|
||||
def get_topics(self, chunk):
|
||||
|
||||
return self.prompt.request_topics(chunk)
|
||||
|
||||
def emit_edge(self, s, p, o):
|
||||
|
||||
t = Triple(s=s, p=p, o=o)
|
||||
self.producer.send(t)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
defs = self.get_topics(chunk)
|
||||
|
||||
for defn in defs:
|
||||
|
||||
s = defn.name
|
||||
o = defn.definition
|
||||
|
||||
if s == "": continue
|
||||
if o == "": continue
|
||||
|
||||
if s is None: continue
|
||||
if o is None: continue
|
||||
|
||||
s_uri = self.to_uri(s)
|
||||
|
||||
s_value = Value(value=str(s_uri), is_uri=True)
|
||||
o_value = Value(value=str(o), is_uri=False)
|
||||
|
||||
self.emit_edge(s_value, DEFINITION_VALUE, o_value)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-completion-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . extract import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . extract import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,220 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple decoder, accepts vector+text chunks input, applies analysis to pull
|
||||
out a row of fields. Output as a vector plus object.
|
||||
"""
|
||||
|
||||
import urllib.parse
|
||||
import os
|
||||
from pulsar.schema import JsonSchema
|
||||
|
||||
from .... schema import ChunkEmbeddings, Rows, ObjectEmbeddings, Source
|
||||
from .... schema import RowSchema, Field
|
||||
from .... schema import chunk_embeddings_ingest_queue, rows_store_queue
|
||||
from .... schema import object_embeddings_store_queue
|
||||
from .... schema import prompt_request_queue
|
||||
from .... schema import prompt_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... clients.prompt_client import PromptClient
|
||||
from .... base import ConsumerProducer
|
||||
|
||||
from .... objects.field import Field as FieldParser
|
||||
from .... objects.object import Schema
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = chunk_embeddings_ingest_queue
|
||||
default_output_queue = rows_store_queue
|
||||
default_vector_queue = object_embeddings_store_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
vector_queue = params.get("vector_queue", default_vector_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
pr_request_queue = params.get(
|
||||
"prompt_request_queue", prompt_request_queue
|
||||
)
|
||||
pr_response_queue = params.get(
|
||||
"prompt_response_queue", prompt_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings,
|
||||
"output_schema": Rows,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.vec_prod = self.client.create_producer(
|
||||
topic=vector_queue,
|
||||
schema=JsonSchema(ObjectEmbeddings),
|
||||
)
|
||||
|
||||
__class__.pubsub_metric.info({
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"vector_queue": vector_queue,
|
||||
"prompt_request_queue": pr_request_queue,
|
||||
"prompt_response_queue": pr_response_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": ChunkEmbeddings.__name__,
|
||||
"output_schema": Rows.__name__,
|
||||
"vector_schema": ObjectEmbeddings.__name__,
|
||||
})
|
||||
|
||||
flds = __class__.parse_fields(params["field"])
|
||||
|
||||
for fld in flds:
|
||||
print(fld)
|
||||
|
||||
self.primary = None
|
||||
|
||||
for f in flds:
|
||||
if f.primary:
|
||||
if self.primary:
|
||||
raise RuntimeError(
|
||||
"Only one primary key field is supported"
|
||||
)
|
||||
self.primary = f
|
||||
|
||||
if self.primary == None:
|
||||
raise RuntimeError(
|
||||
"Must have exactly one primary key field"
|
||||
)
|
||||
|
||||
self.schema = Schema(
|
||||
name = params["name"],
|
||||
description = params["description"],
|
||||
fields = flds
|
||||
)
|
||||
|
||||
self.row_schema=RowSchema(
|
||||
name=self.schema.name,
|
||||
description=self.schema.description,
|
||||
fields=[
|
||||
Field(
|
||||
name=f.name, type=str(f.type), size=f.size,
|
||||
primary=f.primary, description=f.description,
|
||||
)
|
||||
for f in self.schema.fields
|
||||
]
|
||||
)
|
||||
|
||||
self.prompt = PromptClient(
|
||||
pulsar_host=self.pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber = module + "-prompt",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_fields(fields):
|
||||
return [ FieldParser.parse(f) for f in fields ]
|
||||
|
||||
def get_rows(self, chunk):
|
||||
return self.prompt.request_rows(self.schema, chunk)
|
||||
|
||||
def emit_rows(self, source, rows):
|
||||
|
||||
t = Rows(
|
||||
source=source, row_schema=self.row_schema, rows=rows
|
||||
)
|
||||
self.producer.send(t)
|
||||
|
||||
def emit_vec(self, source, name, vec, key_name, key):
|
||||
|
||||
r = ObjectEmbeddings(
|
||||
source=source, vectors=vec, name=name, key_name=key_name, id=key
|
||||
)
|
||||
self.vec_prod.send(r)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
print(f"Indexing {v.source.id}...", flush=True)
|
||||
|
||||
chunk = v.chunk.decode("utf-8")
|
||||
|
||||
try:
|
||||
|
||||
rows = self.get_rows(chunk)
|
||||
|
||||
self.emit_rows(
|
||||
source=v.source,
|
||||
rows=rows
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
self.emit_vec(
|
||||
source=v.source, vec=v.vectors,
|
||||
name=self.schema.name, key_name=self.primary.name,
|
||||
key=row[self.primary.name]
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
print(row)
|
||||
|
||||
except Exception as e:
|
||||
print("Exception: ", e, flush=True)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--vector-queue',
|
||||
default=default_vector_queue,
|
||||
help=f'Vector output queue (default: {default_vector_queue})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-request-queue',
|
||||
default=prompt_request_queue,
|
||||
help=f'Prompt request queue (default: {prompt_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--prompt-response-queue',
|
||||
default=prompt_response_queue,
|
||||
help=f'Prompt response queue (default: {prompt_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-f', '--field',
|
||||
required=True,
|
||||
action='append',
|
||||
help=f'Field definition, format name:type:size:pri:descriptionn',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-n', '--name',
|
||||
required=True,
|
||||
help=f'Name of row object',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-d', '--description',
|
||||
required=True,
|
||||
help=f'Description of object',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,250 +0,0 @@
|
|||
|
||||
from . clients.graph_embeddings_client import GraphEmbeddingsClient
|
||||
from . clients.triples_query_client import TriplesQueryClient
|
||||
from . clients.embeddings_client import EmbeddingsClient
|
||||
from . clients.prompt_client import PromptClient
|
||||
|
||||
from . schema import GraphEmbeddingsRequest, GraphEmbeddingsResponse
|
||||
from . schema import TriplesQueryRequest, TriplesQueryResponse
|
||||
from . schema import prompt_request_queue
|
||||
from . schema import prompt_response_queue
|
||||
from . schema import embeddings_request_queue
|
||||
from . schema import embeddings_response_queue
|
||||
from . schema import graph_embeddings_request_queue
|
||||
from . schema import graph_embeddings_response_queue
|
||||
from . schema import triples_request_queue
|
||||
from . schema import triples_response_queue
|
||||
|
||||
LABEL="http://www.w3.org/2000/01/rdf-schema#label"
|
||||
DEFINITION="http://www.w3.org/2004/02/skos/core#definition"
|
||||
|
||||
class GraphRag:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pulsar_host="pulsar://pulsar:6650",
|
||||
pr_request_queue=None,
|
||||
pr_response_queue=None,
|
||||
emb_request_queue=None,
|
||||
emb_response_queue=None,
|
||||
ge_request_queue=None,
|
||||
ge_response_queue=None,
|
||||
tpl_request_queue=None,
|
||||
tpl_response_queue=None,
|
||||
verbose=False,
|
||||
entity_limit=50,
|
||||
triple_limit=30,
|
||||
max_subgraph_size=3000,
|
||||
module="test",
|
||||
):
|
||||
|
||||
self.verbose=verbose
|
||||
|
||||
if pr_request_queue is None:
|
||||
pr_request_queue = prompt_request_queue
|
||||
|
||||
if pr_response_queue is None:
|
||||
pr_response_queue = prompt_response_queue
|
||||
|
||||
if emb_request_queue is None:
|
||||
emb_request_queue = embeddings_request_queue
|
||||
|
||||
if emb_response_queue is None:
|
||||
emb_response_queue = embeddings_response_queue
|
||||
|
||||
if ge_request_queue is None:
|
||||
ge_request_queue = graph_embeddings_request_queue
|
||||
|
||||
if ge_response_queue is None:
|
||||
ge_response_queue = graph_embeddings_response_queue
|
||||
|
||||
if tpl_request_queue is None:
|
||||
tpl_request_queue = triples_request_queue
|
||||
|
||||
if tpl_response_queue is None:
|
||||
tpl_response_queue = triples_response_queue
|
||||
|
||||
if self.verbose:
|
||||
print("Initialising...", flush=True)
|
||||
|
||||
self.ge_client = GraphEmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
subscriber=module + "-ge",
|
||||
input_queue=ge_request_queue,
|
||||
output_queue=ge_response_queue,
|
||||
)
|
||||
|
||||
self.triples_client = TriplesQueryClient(
|
||||
pulsar_host=pulsar_host,
|
||||
subscriber=module + "-tpl",
|
||||
input_queue=tpl_request_queue,
|
||||
output_queue=tpl_response_queue
|
||||
)
|
||||
|
||||
self.embeddings = EmbeddingsClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=emb_request_queue,
|
||||
output_queue=emb_response_queue,
|
||||
subscriber=module + "-emb",
|
||||
)
|
||||
|
||||
self.entity_limit=entity_limit
|
||||
self.query_limit=triple_limit
|
||||
self.max_subgraph_size=max_subgraph_size
|
||||
|
||||
self.label_cache = {}
|
||||
|
||||
self.lang = PromptClient(
|
||||
pulsar_host=pulsar_host,
|
||||
input_queue=pr_request_queue,
|
||||
output_queue=pr_response_queue,
|
||||
subscriber=module + "-prompt",
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def get_vector(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Compute embeddings...", flush=True)
|
||||
|
||||
qembeds = self.embeddings.request(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return qembeds
|
||||
|
||||
def get_entities(self, query):
|
||||
|
||||
vectors = self.get_vector(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Get entities...", flush=True)
|
||||
|
||||
entities = self.ge_client.request(
|
||||
vectors, self.entity_limit
|
||||
)
|
||||
|
||||
entities = [
|
||||
e.value
|
||||
for e in entities
|
||||
]
|
||||
|
||||
if self.verbose:
|
||||
print("Entities:", flush=True)
|
||||
for ent in entities:
|
||||
print(" ", ent, flush=True)
|
||||
|
||||
return entities
|
||||
|
||||
def maybe_label(self, e):
|
||||
|
||||
if e in self.label_cache:
|
||||
return self.label_cache[e]
|
||||
|
||||
res = self.triples_client.request(
|
||||
e, LABEL, None, limit=1
|
||||
)
|
||||
|
||||
if len(res) == 0:
|
||||
self.label_cache[e] = e
|
||||
return e
|
||||
|
||||
self.label_cache[e] = res[0].o.value
|
||||
return self.label_cache[e]
|
||||
|
||||
def get_subgraph(self, query):
|
||||
|
||||
entities = self.get_entities(query)
|
||||
|
||||
subgraph = set()
|
||||
|
||||
if self.verbose:
|
||||
print("Get subgraph...", flush=True)
|
||||
|
||||
for e in entities:
|
||||
|
||||
res = self.triples_client.request(
|
||||
e, None, None,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.triples_client.request(
|
||||
None, e, None,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
res = self.triples_client.request(
|
||||
None, None, e,
|
||||
limit=self.query_limit
|
||||
)
|
||||
|
||||
for triple in res:
|
||||
subgraph.add(
|
||||
(triple.s.value, triple.p.value, triple.o.value)
|
||||
)
|
||||
|
||||
subgraph = list(subgraph)
|
||||
|
||||
subgraph = subgraph[0:self.max_subgraph_size]
|
||||
|
||||
if self.verbose:
|
||||
print("Subgraph:", flush=True)
|
||||
for edge in subgraph:
|
||||
print(" ", str(edge), flush=True)
|
||||
|
||||
if self.verbose:
|
||||
print("Done.", flush=True)
|
||||
|
||||
return subgraph
|
||||
|
||||
def get_labelgraph(self, query):
|
||||
|
||||
subgraph = self.get_subgraph(query)
|
||||
|
||||
sg2 = []
|
||||
|
||||
for edge in subgraph:
|
||||
|
||||
if edge[1] == LABEL:
|
||||
continue
|
||||
|
||||
s = self.maybe_label(edge[0])
|
||||
p = self.maybe_label(edge[1])
|
||||
o = self.maybe_label(edge[2])
|
||||
|
||||
sg2.append((s, p, o))
|
||||
|
||||
return sg2
|
||||
|
||||
def query(self, query):
|
||||
|
||||
if self.verbose:
|
||||
print("Construct prompt...", flush=True)
|
||||
|
||||
kg = self.get_labelgraph(query)
|
||||
|
||||
if self.verbose:
|
||||
print("Invoke LLM...", flush=True)
|
||||
print(kg)
|
||||
print(query)
|
||||
|
||||
resp = self.lang.request_kg_prompt(query, kg)
|
||||
|
||||
if self.verbose:
|
||||
print("Done", flush=True)
|
||||
|
||||
return resp
|
||||
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
|
||||
from enum import Enum
|
||||
import _pulsar
|
||||
|
||||
class LogLevel(Enum):
|
||||
DEBUG = 'debug'
|
||||
INFO = 'info'
|
||||
WARN = 'warn'
|
||||
ERROR = 'error'
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
def to_pulsar(self):
|
||||
if self == LogLevel.DEBUG: return _pulsar.LoggerLevel.Debug
|
||||
if self == LogLevel.INFO: return _pulsar.LoggerLevel.Info
|
||||
if self == LogLevel.WARN: return _pulsar.LoggerLevel.Warn
|
||||
if self == LogLevel.ERROR: return _pulsar.LoggerLevel.Error
|
||||
raise RuntimeError("Log level mismatch")
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . counter import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . counter import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
"""
|
||||
Simple token counter for each LLM response.
|
||||
"""
|
||||
|
||||
from prometheus_client import Histogram, Info
|
||||
from . pricelist import price_list
|
||||
|
||||
from .. schema import TextCompletionResponse, Error
|
||||
from .. schema import text_completion_response_queue
|
||||
from .. log_level import LogLevel
|
||||
from .. base import Consumer
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
|
||||
class Processor(Consumer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionResponse,
|
||||
}
|
||||
)
|
||||
|
||||
def get_prices(self, prices, modelname):
|
||||
for model in prices["price_list"]:
|
||||
if model["model_name"] == modelname:
|
||||
return model["input_price"], model["output_price"]
|
||||
return None, None # Return None if model is not found
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
modelname = v.model
|
||||
|
||||
# Sender-produced ID
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling response {id}...", flush=True)
|
||||
|
||||
num_in = v.in_token
|
||||
num_out = v.out_token
|
||||
|
||||
model_input_price, model_output_price = self.get_prices(price_list, modelname)
|
||||
|
||||
if model_input_price == None:
|
||||
cost_per_call = f"Model Not Found in Price list"
|
||||
else:
|
||||
cost_in = num_in * model_input_price
|
||||
cost_out = num_out * model_output_price
|
||||
cost_per_call = round(cost_in + cost_out, 6)
|
||||
|
||||
print(f"Input Tokens: {num_in}", flush=True)
|
||||
print(f"Output Tokens: {num_out}", flush=True)
|
||||
print(f"Cost for call: ${cost_per_call}", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
Consumer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
price_list = {
|
||||
"price_list": [
|
||||
{
|
||||
"model_name": "mistral.mistral-large-2407-v1:0",
|
||||
"input_price": 0.000004,
|
||||
"output_price": 0.000012
|
||||
},
|
||||
{
|
||||
"model_name": "meta.llama3-1-405b-instruct-v1:0",
|
||||
"input_price": 0.00000532,
|
||||
"output_price": 0.000016
|
||||
},
|
||||
{
|
||||
"model_name": "mistral.mixtral-8x7b-instruct-v0:1",
|
||||
"input_price": 0.00000045,
|
||||
"output_price": 0.0000007
|
||||
},
|
||||
{
|
||||
"model_name": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"input_price": 0.00000099,
|
||||
"output_price": 0.00000099
|
||||
},
|
||||
{
|
||||
"model_name": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"input_price": 0.00000022,
|
||||
"output_price": 0.00000022
|
||||
},
|
||||
{
|
||||
"model_name": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"input_price": 0.00000025,
|
||||
"output_price": 0.00000125
|
||||
},
|
||||
{
|
||||
"model_name": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"input_price": 0.000003,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "cohere.command-r-plus-v1:0",
|
||||
"input_price": 0.0000030,
|
||||
"output_price": 0.0000150
|
||||
},
|
||||
{
|
||||
"model_name": "ollama",
|
||||
"input_price": 0,
|
||||
"output_price": 0
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-haiku-20240307",
|
||||
"input_price": 0.00000025,
|
||||
"output_price": 0.00000125
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-5-sonnet-20240620",
|
||||
"input_price": 0.000003,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-opus-20240229",
|
||||
"input_price": 0.000015,
|
||||
"output_price": 0.000075
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-sonnet-20240229",
|
||||
"input_price": 0.000003,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "command-r-08-202",
|
||||
"input_price": 0.0000025,
|
||||
"output_price": 0.000010
|
||||
},
|
||||
{
|
||||
"model_name": "c4ai-aya-23-8b",
|
||||
"input_price": 0,
|
||||
"output_price": 0
|
||||
},
|
||||
{
|
||||
"model_name": "llama.cpp",
|
||||
"input_price": 0,
|
||||
"output_price": 0
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o",
|
||||
"input_price": 0.000005,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o-2024-08-06",
|
||||
"input_price": 0.0000025,
|
||||
"output_price": 0.000010
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o-2024-05-13",
|
||||
"input_price": 0.000005,
|
||||
"output_price": 0.000015
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"input_price": 0.00000015,
|
||||
"output_price": 0.0000006
|
||||
},
|
||||
]
|
||||
}
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,176 +0,0 @@
|
|||
|
||||
def to_relationships(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.
|
||||
|
||||
Read the provided text. You will model the text as an information network for a RDF knowledge graph in JSON.
|
||||
|
||||
Information Network Rules:
|
||||
- An information network has subjects connected by predicates to objects.
|
||||
- A subject is a named-entity or a conceptual topic.
|
||||
- One subject can have many predicates and objects.
|
||||
- An object is a property or attribute of a subject.
|
||||
- A subject can be connected by a predicate to another subject.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Obey the information network rules.
|
||||
- Do not return special characters.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of JSON objects with keys "subject", "predicate", "object", and "object-entity".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"subject": string, "predicate": string, "object": string, "object-entity": boolean}}]
|
||||
```
|
||||
|
||||
- The key "object-entity" is TRUE only if the "object" is a subject.
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_topics(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify topics and their definitions in JSON.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Do not respond with special characters.
|
||||
- Return only topics that are concepts and unique to the provided text.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of objects with keys "topic" and "definition".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"topic": string, "definition": string}}]
|
||||
```
|
||||
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_definitions(text):
|
||||
|
||||
prompt = f"""You are a helpful assistant that performs information extraction tasks for a provided text.\nRead the provided text. You will identify entities and their definitions in JSON.
|
||||
|
||||
Reading Instructions:
|
||||
- Ignore document formatting in the provided text.
|
||||
- Study the provided text carefully.
|
||||
|
||||
Here is the text:
|
||||
{text}
|
||||
|
||||
Response Instructions:
|
||||
- Do not respond with special characters.
|
||||
- Return only entities that are named-entities such as: people, organizations, physical objects, locations, animals, products, commodotities, or substances.
|
||||
- Respond only with well-formed JSON.
|
||||
- The JSON response shall be an array of objects with keys "entity" and "definition".
|
||||
- The JSON response shall use the following structure:
|
||||
|
||||
```json
|
||||
[{{"entity": string, "definition": string}}]
|
||||
```
|
||||
|
||||
- Do not write any additional text or explanations.
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_rows(schema, text):
|
||||
|
||||
field_schema = [
|
||||
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
|
||||
for f in schema.fields
|
||||
]
|
||||
|
||||
field_schema = "\n".join(field_schema)
|
||||
|
||||
schema = f"""Object name: {schema.name}
|
||||
Description: {schema.description}
|
||||
|
||||
Fields:
|
||||
{field_schema}"""
|
||||
|
||||
prompt = f"""<instructions>
|
||||
Study the following text and derive objects which match the schema provided.
|
||||
|
||||
You must output an array of JSON objects for each object you discover
|
||||
which matches the schema. For each object, output a JSON object whose fields
|
||||
carry the name field specified in the schema.
|
||||
</instructions>
|
||||
|
||||
<schema>
|
||||
{schema}
|
||||
</schema>
|
||||
|
||||
<text>
|
||||
{text}
|
||||
</text>
|
||||
|
||||
<requirements>
|
||||
You will respond only with raw JSON format data. Do not provide
|
||||
explanations. Do not add markdown formatting or headers or prefixes.
|
||||
</requirements>"""
|
||||
|
||||
return prompt
|
||||
|
||||
def get_cypher(kg):
|
||||
|
||||
sg2 = []
|
||||
|
||||
for f in kg:
|
||||
|
||||
print(f)
|
||||
|
||||
sg2.append(f"({f.s})-[{f.p}]->({f.o})")
|
||||
|
||||
print(sg2)
|
||||
|
||||
kg = "\n".join(sg2)
|
||||
kg = kg.replace("\\", "-")
|
||||
|
||||
return kg
|
||||
|
||||
def to_kg_query(query, kg):
|
||||
|
||||
cypher = get_cypher(kg)
|
||||
|
||||
prompt=f"""Study the following set of knowledge statements. The statements are written in Cypher format that has been extracted from a knowledge graph. Use only the provided set of knowledge statements in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
|
||||
|
||||
Here's the knowledge statements:
|
||||
{cypher}
|
||||
|
||||
Use only the provided knowledge statements to respond to the following:
|
||||
{query}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
def to_document_query(query, documents):
|
||||
|
||||
documents = "\n\n".join(documents)
|
||||
|
||||
prompt=f"""Study the following context. Use only the information provided in the context in your response. Do not speculate if the answer is not found in the provided set of knowledge statements.
|
||||
|
||||
Here is the context:
|
||||
{documents}
|
||||
|
||||
Use only the provided knowledge statements to respond to the following:
|
||||
{query}
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
|
@ -1,473 +0,0 @@
|
|||
"""
|
||||
Language service abstracts prompt engineering from LLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .... schema import Definition, Relationship, Triple
|
||||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... schema import prompt_request_queue, prompt_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships, to_topics
|
||||
from . prompts import to_kg_query, to_document_query, to_rows
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
tc_request_queue = params.get(
|
||||
"text_completion_request_queue", text_completion_request_queue
|
||||
)
|
||||
tc_response_queue = params.get(
|
||||
"text_completion_response_queue", text_completion_response_queue
|
||||
)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": PromptRequest,
|
||||
"output_schema": PromptResponse,
|
||||
"text_completion_request_queue": tc_request_queue,
|
||||
"text_completion_response_queue": tc_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.llm = LlmClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=tc_request_queue,
|
||||
output_queue=tc_response_queue,
|
||||
pulsar_host = self.pulsar_host
|
||||
)
|
||||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
# If no delimiters, assume the entire output is JSON
|
||||
json_str = text.strip()
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
kind = v.kind
|
||||
|
||||
print(f"Handling kind {kind}...", flush=True)
|
||||
|
||||
if kind == "extract-definitions":
|
||||
|
||||
self.handle_extract_definitions(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-topics":
|
||||
|
||||
self.handle_extract_topics(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-relationships":
|
||||
|
||||
self.handle_extract_relationships(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-rows":
|
||||
|
||||
self.handle_extract_rows(id, v)
|
||||
return
|
||||
|
||||
elif kind == "kg-prompt":
|
||||
|
||||
self.handle_kg_prompt(id, v)
|
||||
return
|
||||
|
||||
elif kind == "document-prompt":
|
||||
|
||||
self.handle_document_prompt(id, v)
|
||||
return
|
||||
|
||||
else:
|
||||
|
||||
print("Invalid kind.", flush=True)
|
||||
return
|
||||
|
||||
def handle_extract_definitions(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_definitions(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["entity"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Definition(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(definitions=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_topics(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_topics(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["topic"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Topic(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(topics=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_relationships(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_relationships(v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
|
||||
s = defn["subject"]
|
||||
p = defn["predicate"]
|
||||
o = defn["object"]
|
||||
o_entity = defn["object-entity"]
|
||||
|
||||
if s == "": continue
|
||||
if s is None: continue
|
||||
|
||||
if p == "": continue
|
||||
if p is None: continue
|
||||
|
||||
if o == "": continue
|
||||
if o is None: continue
|
||||
|
||||
if o_entity == "" or o_entity is None:
|
||||
o_entity = False
|
||||
|
||||
output.append(
|
||||
Relationship(
|
||||
s = s,
|
||||
p = p,
|
||||
o = o,
|
||||
o_entity = o_entity,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("relationship fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(relationships=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_rows(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
fields = v.row_schema.fields
|
||||
|
||||
prompt = to_rows(v.row_schema, v.chunk)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
objs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
objs = []
|
||||
|
||||
output = []
|
||||
|
||||
for obj in objs:
|
||||
|
||||
try:
|
||||
|
||||
row = {}
|
||||
|
||||
for f in fields:
|
||||
|
||||
if f.name not in obj:
|
||||
print(f"Object ignored, missing field {f.name}")
|
||||
row = {}
|
||||
break
|
||||
|
||||
row[f.name] = obj[f.name]
|
||||
|
||||
if row == {}:
|
||||
continue
|
||||
|
||||
output.append(row)
|
||||
|
||||
except Exception as e:
|
||||
print("row fields missing, ignored", flush=True)
|
||||
|
||||
for row in output:
|
||||
print(row)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(rows=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_kg_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_kg_query(v.query, v.kg)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_document_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_document_query(v.query, v.documents)
|
||||
|
||||
print("prompt")
|
||||
print(prompt)
|
||||
|
||||
print("Call LLM...")
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-request-queue',
|
||||
default=text_completion_request_queue,
|
||||
help=f'Text completion request queue (default: {text_completion_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-response-queue',
|
||||
default=text_completion_response_queue,
|
||||
help=f'Text completion response queue (default: {text_completion_response_queue})',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . service import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . service import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
|
||||
def to_relationships(template, text):
|
||||
return template.format(text=text)
|
||||
|
||||
def to_definitions(template, text):
|
||||
return template.format(text=text)
|
||||
|
||||
def to_topics(template, text):
|
||||
return template.format(text=text)
|
||||
|
||||
def to_rows(template, schema, text):
|
||||
|
||||
field_schema = [
|
||||
f"- Name: {f.name}\n Type: {f.type}\n Definition: {f.description}"
|
||||
for f in schema.fields
|
||||
]
|
||||
|
||||
field_schema = "\n".join(field_schema)
|
||||
|
||||
return template.format(schema=schema, text=text)
|
||||
|
||||
schema = f"""Object name: {schema.name}
|
||||
Description: {schema.description}
|
||||
|
||||
Fields:
|
||||
{schema}"""
|
||||
|
||||
prompt = f""""""
|
||||
|
||||
return prompt
|
||||
|
||||
def get_cypher(kg):
|
||||
sg2 = []
|
||||
for f in kg:
|
||||
sg2.append(f"({f.s})-[{f.p}]->({f.o})")
|
||||
kg = "\n".join(sg2)
|
||||
kg = kg.replace("\\", "-")
|
||||
return kg
|
||||
|
||||
def to_kg_query(template, query, kg):
|
||||
cypher = get_cypher(kg)
|
||||
return template.format(query=query, graph=cypher)
|
||||
|
||||
def to_document_query(template, query, docs):
|
||||
docs = "\n\n".join(docs)
|
||||
return template.format(query=query, documents=docs)
|
||||
|
||||
|
|
@ -1,523 +0,0 @@
|
|||
|
||||
"""
|
||||
Language service abstracts prompt engineering from LLM.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from .... schema import Definition, Relationship, Triple
|
||||
from .... schema import Topic
|
||||
from .... schema import PromptRequest, PromptResponse, Error
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... schema import prompt_request_queue, prompt_response_queue
|
||||
from .... base import ConsumerProducer
|
||||
from .... clients.llm_client import LlmClient
|
||||
|
||||
from . prompts import to_definitions, to_relationships, to_rows
|
||||
from . prompts import to_kg_query, to_document_query, to_topics
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = prompt_request_queue
|
||||
default_output_queue = prompt_response_queue
|
||||
default_subscriber = module
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
tc_request_queue = params.get(
|
||||
"text_completion_request_queue", text_completion_request_queue
|
||||
)
|
||||
tc_response_queue = params.get(
|
||||
"text_completion_response_queue", text_completion_response_queue
|
||||
)
|
||||
definition_template = params.get("definition_template")
|
||||
relationship_template = params.get("relationship_template")
|
||||
topic_template = params.get("topic_template")
|
||||
rows_template = params.get("rows_template")
|
||||
knowledge_query_template = params.get("knowledge_query_template")
|
||||
document_query_template = params.get("document_query_template")
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": PromptRequest,
|
||||
"output_schema": PromptResponse,
|
||||
"text_completion_request_queue": tc_request_queue,
|
||||
"text_completion_response_queue": tc_response_queue,
|
||||
}
|
||||
)
|
||||
|
||||
self.llm = LlmClient(
|
||||
subscriber=subscriber,
|
||||
input_queue=tc_request_queue,
|
||||
output_queue=tc_response_queue,
|
||||
pulsar_host = self.pulsar_host
|
||||
)
|
||||
|
||||
self.definition_template = definition_template
|
||||
self.topic_template = topic_template
|
||||
self.relationship_template = relationship_template
|
||||
self.rows_template = rows_template
|
||||
self.knowledge_query_template = knowledge_query_template
|
||||
self.document_query_template = document_query_template
|
||||
|
||||
def parse_json(self, text):
|
||||
json_match = re.search(r'```(?:json)?(.*?)```', text, re.DOTALL)
|
||||
|
||||
if json_match:
|
||||
json_str = json_match.group(1).strip()
|
||||
else:
|
||||
# If no delimiters, assume the entire output is JSON
|
||||
json_str = text.strip()
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
kind = v.kind
|
||||
|
||||
print(f"Handling kind {kind}...", flush=True)
|
||||
|
||||
if kind == "extract-definitions":
|
||||
|
||||
self.handle_extract_definitions(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-topics":
|
||||
|
||||
self.handle_extract_topics(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-relationships":
|
||||
|
||||
self.handle_extract_relationships(id, v)
|
||||
return
|
||||
|
||||
elif kind == "extract-rows":
|
||||
|
||||
self.handle_extract_rows(id, v)
|
||||
return
|
||||
|
||||
elif kind == "kg-prompt":
|
||||
|
||||
self.handle_kg_prompt(id, v)
|
||||
return
|
||||
|
||||
elif kind == "document-prompt":
|
||||
|
||||
self.handle_document_prompt(id, v)
|
||||
return
|
||||
|
||||
else:
|
||||
|
||||
print("Invalid kind.", flush=True)
|
||||
return
|
||||
|
||||
def handle_extract_definitions(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_definitions(self.definition_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["entity"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Definition(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(definitions=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_topics(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_topics(self.topic_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
e = defn["topic"]
|
||||
d = defn["definition"]
|
||||
|
||||
if e == "": continue
|
||||
if e is None: continue
|
||||
if d == "": continue
|
||||
if d is None: continue
|
||||
|
||||
output.append(
|
||||
Topic(
|
||||
name=e, definition=d
|
||||
)
|
||||
)
|
||||
|
||||
except:
|
||||
print("definition fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(topics=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
|
||||
def handle_extract_relationships(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_relationships(self.relationship_template, v.chunk)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
defs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
defs = []
|
||||
|
||||
output = []
|
||||
|
||||
for defn in defs:
|
||||
|
||||
try:
|
||||
|
||||
s = defn["subject"]
|
||||
p = defn["predicate"]
|
||||
o = defn["object"]
|
||||
o_entity = defn["object-entity"]
|
||||
|
||||
if s == "": continue
|
||||
if s is None: continue
|
||||
|
||||
if p == "": continue
|
||||
if p is None: continue
|
||||
|
||||
if o == "": continue
|
||||
if o is None: continue
|
||||
|
||||
if o_entity == "" or o_entity is None:
|
||||
o_entity = False
|
||||
|
||||
output.append(
|
||||
Relationship(
|
||||
s = s,
|
||||
p = p,
|
||||
o = o,
|
||||
o_entity = o_entity,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print("relationship fields missing, ignored", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(relationships=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_extract_rows(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
fields = v.row_schema.fields
|
||||
|
||||
prompt = to_rows(self.rows_template, v.row_schema, v.chunk)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
# Silently ignore JSON parse error
|
||||
try:
|
||||
objs = self.parse_json(ans)
|
||||
except:
|
||||
print("JSON parse error, ignored", flush=True)
|
||||
objs = []
|
||||
|
||||
output = []
|
||||
|
||||
for obj in objs:
|
||||
|
||||
try:
|
||||
|
||||
row = {}
|
||||
|
||||
for f in fields:
|
||||
|
||||
if f.name not in obj:
|
||||
print(f"Object ignored, missing field {f.name}")
|
||||
row = {}
|
||||
break
|
||||
|
||||
row[f.name] = obj[f.name]
|
||||
|
||||
if row == {}:
|
||||
continue
|
||||
|
||||
output.append(row)
|
||||
|
||||
except Exception as e:
|
||||
print("row fields missing, ignored", flush=True)
|
||||
|
||||
for row in output:
|
||||
print(row)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(rows=output, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_kg_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_kg_query(self.knowledge_query_template, v.query, v.kg)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
def handle_document_prompt(self, id, v):
|
||||
|
||||
try:
|
||||
|
||||
prompt = to_document_query(
|
||||
self.document_query_template, v.query, v.documents
|
||||
)
|
||||
|
||||
print(prompt)
|
||||
|
||||
ans = self.llm.request(prompt)
|
||||
|
||||
print(ans)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = PromptResponse(answer=ans, error=None)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = PromptResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-request-queue',
|
||||
default=text_completion_request_queue,
|
||||
help=f'Text completion request queue (default: {text_completion_request_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--text-completion-response-queue',
|
||||
default=text_completion_response_queue,
|
||||
help=f'Text completion response queue (default: {text_completion_response_queue})',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--definition-template',
|
||||
required=True,
|
||||
help=f'Definition extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--topic-template',
|
||||
required=True,
|
||||
help=f'Topic extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--rows-template',
|
||||
required=True,
|
||||
help=f'Rows extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--relationship-template',
|
||||
required=True,
|
||||
help=f'Relationship extraction template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--knowledge-query-template',
|
||||
required=True,
|
||||
help=f'Knowledge query template',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--document-query-template',
|
||||
required=True,
|
||||
help=f'Document query template',
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,226 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using the Azure
|
||||
serverless endpoint service. Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import requests
|
||||
import json
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_temperature = 0.0
|
||||
default_max_output = 4192
|
||||
default_model = "AzureAI"
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
endpoint = params.get("endpoint")
|
||||
token = params.get("token")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
model = default_model
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
"model": model,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.token = token
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
self.model = model
|
||||
|
||||
def build_prompt(self, system, content):
|
||||
|
||||
data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system", "content": system
|
||||
},
|
||||
{
|
||||
"role": "user", "content": content
|
||||
}
|
||||
],
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 1
|
||||
}
|
||||
|
||||
body = json.dumps(data)
|
||||
|
||||
return body
|
||||
|
||||
def call_llm(self, body):
|
||||
|
||||
url = self.endpoint
|
||||
|
||||
# Replace this with the primary/secondary key, AMLToken, or
|
||||
# Microsoft Entra ID token for the endpoint
|
||||
api_key = self.token
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {api_key}'
|
||||
}
|
||||
|
||||
resp = requests.post(url, data=body, headers=headers)
|
||||
|
||||
if resp.status_code == 429:
|
||||
raise TooManyRequests()
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise RuntimeError("LLM failure")
|
||||
|
||||
result = resp.json()
|
||||
|
||||
return result
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
try:
|
||||
|
||||
prompt = self.build_prompt(
|
||||
"You are a helpful chatbot",
|
||||
v.prompt
|
||||
)
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
response = self.call_llm(prompt)
|
||||
|
||||
resp = response['choices'][0]['message']['content']
|
||||
inputtokens = response['usage']['prompt_tokens']
|
||||
outputtokens = response['usage']['completion_tokens']
|
||||
|
||||
print(resp, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
except TooManyRequests:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-e', '--endpoint',
|
||||
help=f'LLM model endpoint'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--token',
|
||||
help=f'LLM model token'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,323 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using AWS Bedrock.
|
||||
Input is prompt, output is response. Mistral is default.
|
||||
"""
|
||||
|
||||
import boto3
|
||||
import json
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'mistral.mistral-large-2407-v1:0'
|
||||
default_region = 'us-west-2'
|
||||
default_temperature = 0.0
|
||||
default_max_output = 2048
|
||||
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
aws_id = params.get("aws_id_key")
|
||||
aws_secret = params.get("aws_secret")
|
||||
aws_region = params.get("aws_region", default_region)
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
self.session = boto3.Session(
|
||||
aws_access_key_id=aws_id,
|
||||
aws_secret_access_key=aws_secret,
|
||||
region_name=aws_region
|
||||
)
|
||||
|
||||
self.bedrock = self.session.client(service_name='bedrock-runtime')
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
prompt = v.prompt
|
||||
|
||||
try:
|
||||
|
||||
# Mistral Input Format
|
||||
if self.model.startswith("mistral"):
|
||||
promptbody = json.dumps({
|
||||
"prompt": prompt,
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.99,
|
||||
"top_k": 40
|
||||
})
|
||||
|
||||
# Llama 3.1 Input Format
|
||||
elif self.model.startswith("meta"):
|
||||
promptbody = json.dumps({
|
||||
"prompt": prompt,
|
||||
"max_gen_len": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.95,
|
||||
})
|
||||
|
||||
# Anthropic Input Format
|
||||
elif self.model.startswith("anthropic"):
|
||||
promptbody = json.dumps({
|
||||
"anthropic_version": "bedrock-2023-05-31",
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.999,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# Jamba Input Format
|
||||
elif self.model.startswith("ai21"):
|
||||
promptbody = json.dumps({
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.9,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
# Cohere Input Format
|
||||
elif self.model.startswith("cohere"):
|
||||
promptbody = json.dumps({
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"message": prompt
|
||||
})
|
||||
|
||||
# Use Mistral format as defualt
|
||||
else:
|
||||
promptbody = json.dumps({
|
||||
"prompt": prompt,
|
||||
"max_tokens": self.max_output,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.99,
|
||||
"top_k": 40
|
||||
})
|
||||
|
||||
accept = 'application/json'
|
||||
contentType = 'application/json'
|
||||
|
||||
# FIXME: Consider catching request limits and raise TooManyRequests
|
||||
# See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
response = self.bedrock.invoke_model(
|
||||
body=promptbody, modelId=self.model, accept=accept,
|
||||
contentType=contentType
|
||||
)
|
||||
|
||||
# Mistral Response Structure
|
||||
if self.model.startswith("mistral"):
|
||||
response_body = json.loads(response.get("body").read())
|
||||
outputtext = response_body['outputs'][0]['text']
|
||||
|
||||
# Claude Response Structure
|
||||
elif self.model.startswith("anthropic"):
|
||||
model_response = json.loads(response["body"].read())
|
||||
outputtext = model_response['content'][0]['text']
|
||||
|
||||
# Llama 3.1 Response Structure
|
||||
elif self.model.startswith("meta"):
|
||||
model_response = json.loads(response["body"].read())
|
||||
outputtext = model_response["generation"]
|
||||
|
||||
# Jamba Response Structure
|
||||
elif self.model.startswith("ai21"):
|
||||
content = response['body'].read()
|
||||
content_str = content.decode('utf-8')
|
||||
content_json = json.loads(content_str)
|
||||
outputtext = content_json['choices'][0]['message']['content']
|
||||
|
||||
# Cohere Input Format
|
||||
elif self.model.startswith("cohere"):
|
||||
content = response['body'].read()
|
||||
content_str = content.decode('utf-8')
|
||||
content_json = json.loads(content_str)
|
||||
outputtext = content_json['text']
|
||||
|
||||
# Use Mistral as default
|
||||
else:
|
||||
response_body = json.loads(response.get("body").read())
|
||||
outputtext = response_body['outputs'][0]['text']
|
||||
|
||||
metadata = response['ResponseMetadata']['HTTPHeaders']
|
||||
inputtokens = int(metadata['x-amzn-bedrock-input-token-count'])
|
||||
outputtokens = int(metadata['x-amzn-bedrock-output-token-count'])
|
||||
|
||||
print(outputtext, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TextCompletionResponse(
|
||||
error=None,
|
||||
response=outputtext,
|
||||
in_token=inputtokens,
|
||||
out_token=outputtokens,
|
||||
model=str(self.model),
|
||||
)
|
||||
|
||||
self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
|
||||
# FIXME: Wrong exception, don't know what Bedrock throws
|
||||
# for a rate limit
|
||||
except TooManyRequests:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default="mistral.mistral-large-2407-v1:0",
|
||||
help=f'Bedrock model (default: Mistral-Large-2407)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-z', '--aws-id-key',
|
||||
help=f'AWS ID Key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--aws-secret',
|
||||
help=f'AWS Secret Key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-r', '--aws-region',
|
||||
help=f'AWS Region (default: us-west-2)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,199 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using Claude.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import anthropic
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'claude-3-5-sonnet-20240620'
|
||||
default_temperature = 0.0
|
||||
default_max_output = 8192
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_output": max_output,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.claude = anthropic.Anthropic(api_key=api_key)
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
prompt = v.prompt
|
||||
|
||||
try:
|
||||
|
||||
# FIXME: Rate limits?
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
|
||||
response = message = self.claude.messages.create(
|
||||
model=self.model,
|
||||
max_tokens=self.max_output,
|
||||
temperature=self.temperature,
|
||||
system = "You are a helpful chatbot.",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
resp = response.content[0].text
|
||||
inputtokens = response.usage.input_tokens
|
||||
outputtokens = response.usage.output_tokens
|
||||
print(resp, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
|
||||
self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
# FIXME: Wrong exception, don't know what this LLM throws
|
||||
# for a rate limit
|
||||
except TooManyRequests:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default="claude-3-5-sonnet-20240620",
|
||||
help=f'LLM model (default: claude-3-5-sonnet-20240620)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
help=f'Claude API key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-x', '--max-output',
|
||||
type=int,
|
||||
default=default_max_output,
|
||||
help=f'LLM max output tokens (default: {default_max_output})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from . llm import run
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
||||
|
|
@ -1,179 +0,0 @@
|
|||
|
||||
"""
|
||||
Simple LLM service, performs text prompt completion using Cohere.
|
||||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
import cohere
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from .... schema import TextCompletionRequest, TextCompletionResponse, Error
|
||||
from .... schema import text_completion_request_queue
|
||||
from .... schema import text_completion_response_queue
|
||||
from .... log_level import LogLevel
|
||||
from .... base import ConsumerProducer
|
||||
from .... exceptions import TooManyRequests
|
||||
|
||||
module = ".".join(__name__.split(".")[1:-1])
|
||||
|
||||
default_input_queue = text_completion_request_queue
|
||||
default_output_queue = text_completion_response_queue
|
||||
default_subscriber = module
|
||||
default_model = 'c4ai-aya-23-8b'
|
||||
default_temperature = 0.0
|
||||
|
||||
class Processor(ConsumerProducer):
|
||||
|
||||
def __init__(self, **params):
|
||||
|
||||
input_queue = params.get("input_queue", default_input_queue)
|
||||
output_queue = params.get("output_queue", default_output_queue)
|
||||
subscriber = params.get("subscriber", default_subscriber)
|
||||
model = params.get("model", default_model)
|
||||
api_key = params.get("api_key")
|
||||
temperature = params.get("temperature", default_temperature)
|
||||
|
||||
super(Processor, self).__init__(
|
||||
**params | {
|
||||
"input_queue": input_queue,
|
||||
"output_queue": output_queue,
|
||||
"subscriber": subscriber,
|
||||
"input_schema": TextCompletionRequest,
|
||||
"output_schema": TextCompletionResponse,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(__class__, "text_completion_metric"):
|
||||
__class__.text_completion_metric = Histogram(
|
||||
'text_completion_duration',
|
||||
'Text completion duration (seconds)',
|
||||
buckets=[
|
||||
0.25, 0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
|
||||
30.0, 35.0, 40.0, 45.0, 50.0, 60.0, 80.0, 100.0,
|
||||
120.0
|
||||
]
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.cohere = cohere.Client(api_key=api_key)
|
||||
|
||||
print("Initialised", flush=True)
|
||||
|
||||
def handle(self, msg):
|
||||
|
||||
v = msg.value()
|
||||
|
||||
# Sender-produced ID
|
||||
|
||||
id = msg.properties()["id"]
|
||||
|
||||
print(f"Handling prompt {id}...", flush=True)
|
||||
|
||||
prompt = v.prompt
|
||||
|
||||
try:
|
||||
|
||||
with __class__.text_completion_metric.time():
|
||||
|
||||
output = self.cohere.chat(
|
||||
model=self.model,
|
||||
message=prompt,
|
||||
preamble = "You are a helpful AI-assistant.",
|
||||
temperature=self.temperature,
|
||||
chat_history=[],
|
||||
prompt_truncation='auto',
|
||||
connectors=[]
|
||||
)
|
||||
|
||||
resp = output.text
|
||||
inputtokens = int(output.meta.billed_units.input_tokens)
|
||||
outputtokens = int(output.meta.billed_units.output_tokens)
|
||||
|
||||
print(resp, flush=True)
|
||||
print(f"Input Tokens: {inputtokens}", flush=True)
|
||||
print(f"Output Tokens: {outputtokens}", flush=True)
|
||||
|
||||
print("Send response...", flush=True)
|
||||
r = TextCompletionResponse(response=resp, error=None, in_token=inputtokens, out_token=outputtokens, model=self.model)
|
||||
self.send(r, properties={"id": id})
|
||||
|
||||
print("Done.", flush=True)
|
||||
|
||||
# FIXME: Wrong exception, don't know what this LLM throws
|
||||
# for a rate limit
|
||||
except TooManyRequests:
|
||||
|
||||
print("Send rate limit response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "rate-limit",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
print(f"Exception: {e}")
|
||||
|
||||
print("Send error response...", flush=True)
|
||||
|
||||
r = TextCompletionResponse(
|
||||
error=Error(
|
||||
type = "llm-error",
|
||||
message = str(e),
|
||||
),
|
||||
response=None,
|
||||
in_token=None,
|
||||
out_token=None,
|
||||
model=None,
|
||||
)
|
||||
|
||||
self.producer.send(r, properties={"id": id})
|
||||
|
||||
self.consumer.acknowledge(msg)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
|
||||
ConsumerProducer.add_args(
|
||||
parser, default_input_queue, default_subscriber,
|
||||
default_output_queue,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-m', '--model',
|
||||
default="c4ai-aya-23-8b",
|
||||
help=f'Cohere model (default: c4ai-aya-23-8b)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-k', '--api-key',
|
||||
help=f'Cohere API key'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-t', '--temperature',
|
||||
type=float,
|
||||
default=default_temperature,
|
||||
help=f'LLM temperature parameter (default: {default_temperature})'
|
||||
)
|
||||
|
||||
def run():
|
||||
|
||||
Processor.start(module, __doc__)
|
||||
|
||||
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
|
||||
from . llm import *
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue