Fix ollama embeddings client to work (#285)

cybermaggedon, 2025-01-27 23:47:15 +00:00, committed by GitHub
parent 552637c1f7, commit 75a72b0d2d
2 changed files with 19 additions and 5 deletions
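
In short, this commit drops the langchain_community OllamaEmbeddings wrapper in favour of the official ollama Python client, threads the ollama host and model through the processor's configuration, and lets the OLLAMA_HOST environment variable override the default endpoint. A minimal sketch of the new call path, assuming a reachable Ollama server and the diff's default model (mxbai-embed-large) already pulled:

    from ollama import Client

    # Host and model defaults mirror the diff; a running server at this
    # address with the model pulled is an assumption of this sketch.
    client = Client(host="http://localhost:11434")
    resp = client.embed(model="mxbai-embed-large", input="some text")
    vectors = resp.embeddings   # a list of vectors, one per input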


@@ -49,6 +49,7 @@ setuptools.setup(
         "langchain-core",
         "langchain-text-splitters",
         "neo4j",
+        "ollama",
         "openai",
         "pinecone[grpc]",
         "prometheus-client",


@@ -3,12 +3,13 @@
 Embeddings service, applies an embeddings model selected from HuggingFace.
 Input is text, output is embeddings vector.
 """

-from langchain_community.embeddings import OllamaEmbeddings
 from ... schema import EmbeddingsRequest, EmbeddingsResponse
 from ... schema import embeddings_request_queue, embeddings_response_queue
 from ... log_level import LogLevel
 from ... base import ConsumerProducer
+from ollama import Client
+import os

 module = ".".join(__name__.split(".")[1:-1])
@@ -16,7 +17,7 @@ default_input_queue = embeddings_request_queue
 default_output_queue = embeddings_response_queue
 default_subscriber = module
 default_model="mxbai-embed-large"
-default_ollama = 'http://localhost:11434'
+default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434')

 class Processor(ConsumerProducer):
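
A note on the hunk above: the default endpoint now honours the OLLAMA_HOST environment variable and only falls back to the local default when it is unset. A quick illustration of the fallback, using a hypothetical host name:

    import os

    # Unset: the local default applies.
    os.environ.pop("OLLAMA_HOST", None)
    assert os.getenv("OLLAMA_HOST", "http://localhost:11434") \
        == "http://localhost:11434"

    # Set (hypothetical host): the environment value wins.
    os.environ["OLLAMA_HOST"] = "http://ollama.example:11434"
    assert os.getenv("OLLAMA_HOST", "http://localhost:11434") \
        == "http://ollama.example:11434"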
@@ -26,6 +27,9 @@ class Processor(ConsumerProducer):
         output_queue = params.get("output_queue", default_output_queue)
         subscriber = params.get("subscriber", default_subscriber)
+        ollama = params.get("ollama", default_ollama)
+        model = params.get("model", default_model)

         super(Processor, self).__init__(
             **params | {
                 "input_queue": input_queue,
@@ -33,10 +37,13 @@ class Processor(ConsumerProducer):
                 "subscriber": subscriber,
                 "input_schema": EmbeddingsRequest,
                 "output_schema": EmbeddingsResponse,
+                "ollama": ollama,
+                "model": model,
             }
         )

-        self.embeddings = OllamaEmbeddings(base_url=ollama, model=model)
+        self.client = Client(host=ollama)
+        self.model = model

     def handle(self, msg):
@@ -49,10 +56,16 @@ class Processor(ConsumerProducer):
         print(f"Handling input {id}...", flush=True)

         text = v.text

-        embeds = self.embeddings.embed_query([text])
+        embeds = self.client.embed(
+            model = self.model,
+            input = text
+        )

         print("Send response...", flush=True)

-        r = EmbeddingsResponse(vectors=[embeds])
+        r = EmbeddingsResponse(
+            vectors=embeds.embeddings,
+            error=None,
+        )
         self.producer.send(r, properties={"id": id})
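
The last hunk also corrects the response shape: Client.embed returns an object whose embeddings field is already a list of vectors (one per input), so it is passed through as vectors=embeds.embeddings rather than wrapping a single vector the way the old vectors=[embeds] call did. A quick way to confirm the shape, under the same server assumptions as above:

    from ollama import Client

    client = Client(host="http://localhost:11434")
    resp = client.embed(model="mxbai-embed-large", input="hello")

    print(len(resp.embeddings))      # 1 vector for a single string input
    print(len(resp.embeddings[0]))   # the embedding dimension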