embeddings-ollama

This commit is contained in:
Cyber MacGeddon 2024-07-17 16:08:58 +01:00
parent de283f2582
commit 48386bc7f3

View file

@ -3,31 +3,23 @@
Embeddings service, applies an embeddings model selected from HuggingFace. Embeddings service, applies an embeddings model selected from HuggingFace.
Input is text, output is embeddings vector. Input is text, output is embeddings vector.
""" """
import pulsar
from pulsar.schema import JsonSchema
import tempfile
import base64
import os
import argparse
from langchain_community.embeddings import OllamaEmbeddings from langchain_community.embeddings import OllamaEmbeddings
import time
from ... schema import EmbeddingsRequest, EmbeddingsResponse from ... schema import EmbeddingsRequest, EmbeddingsResponse
from ... log_level import LogLevel from ... log_level import LogLevel
from ... base import ConsumerProducer
default_pulsar_host = os.getenv("PULSAR_HOST", 'pulsar://pulsar:6650')
default_input_queue = 'embeddings' default_input_queue = 'embeddings'
default_output_queue = 'embeddings-response' default_output_queue = 'embeddings-response'
default_subscriber = 'embeddings-ollama' default_subscriber = 'embeddings-ollama'
default_model="mxbai-embed-large" default_model="mxbai-embed-large"
default_ollama = 'http://localhost:11434' default_ollama = 'http://localhost:11434'
class Processor: class Processor(ConsumerProducer):
def __init__( def __init__(
self, self,
pulsar_host=default_pulsar_host, pulsar_host=None,
input_queue=default_input_queue, input_queue=default_input_queue,
output_queue=default_output_queue, output_queue=default_output_queue,
subscriber=default_subscriber, subscriber=default_subscriber,
@ -36,32 +28,19 @@ class Processor:
ollama=default_ollama, ollama=default_ollama,
): ):
self.client = None super(Processor, self).__init__(
pulsar_host=pulsar_host,
self.client = pulsar.Client( log_level=log_level,
pulsar_host, input_queue=input_queue,
logger=pulsar.ConsoleLogger(log_level.to_pulsar()) output_queue=output_queue,
) subscriber=subscriber,
input_schema=EmbeddingsRequest,
self.consumer = self.client.subscribe( output_schema=EmbeddingsResponse,
input_queue, subscriber,
schema=JsonSchema(EmbeddingsRequest),
)
self.producer = self.client.create_producer(
topic=output_queue,
schema=JsonSchema(EmbeddingsResponse),
) )
self.embeddings = OllamaEmbeddings(base_url=ollama, model=model) self.embeddings = OllamaEmbeddings(base_url=ollama, model=model)
def run(self): def handle(self, msg):
while True:
msg = self.consumer.receive()
try:
v = msg.value() v = msg.value()
@ -81,58 +60,12 @@ class Processor:
print("Done.", flush=True) print("Done.", flush=True)
# Acknowledge successful processing of the message @staticmethod
self.consumer.acknowledge(msg) def add_args(parser):
except Exception as e: ConsumerProducer.add_args(
parser, default_input_queue, default_subscriber,
print("Exception:", e, flush=True) default_output_queue,
# Message failed to be processed
self.consumer.negative_acknowledge(msg)
def __del__(self):
if self.client:
self.client.close()
def run():
parser = argparse.ArgumentParser(
prog='embeddings-ollama',
description=__doc__,
)
parser.add_argument(
'-p', '--pulsar-host',
default=default_pulsar_host,
help=f'Pulsar host (default: {default_pulsar_host})',
)
parser.add_argument(
'-i', '--input-queue',
default=default_input_queue,
help=f'Input queue (default: {default_input_queue})'
)
parser.add_argument(
'-s', '--subscriber',
default=default_subscriber,
help=f'Queue subscriber name (default: {default_subscriber})'
)
parser.add_argument(
'-o', '--output-queue',
default=default_output_queue,
help=f'Output queue (default: {default_output_queue})'
)
parser.add_argument(
'-l', '--log-level',
type=LogLevel,
default=LogLevel.INFO,
choices=list(LogLevel),
help=f'Output queue (default: info)'
) )
parser.add_argument( parser.add_argument(
@ -147,29 +80,7 @@ def run():
help=f'ollama (default: {default_ollama})' help=f'ollama (default: {default_ollama})'
) )
args = parser.parse_args() def run():
Processor.start('embeddings-ollama', __doc__)
while True:
try:
p = Processor(
pulsar_host=args.pulsar_host,
input_queue=args.input_queue,
output_queue=args.output_queue,
subscriber=args.subscriber,
log_level=args.log_level,
model=args.model,
ollama=args.ollama,
)
p.run()
except Exception as e:
print("Exception:", e, flush=True)
print("Will retry...", flush=True)
time.sleep(10)