diff --git a/trustgraph-flow/setup.py b/trustgraph-flow/setup.py index bea1d496..f5a0bc3c 100644 --- a/trustgraph-flow/setup.py +++ b/trustgraph-flow/setup.py @@ -49,6 +49,7 @@ setuptools.setup( "langchain-core", "langchain-text-splitters", "neo4j", + "ollama", "openai", "pinecone[grpc]", "prometheus-client", diff --git a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py index 6682a79f..fc54cbb8 100755 --- a/trustgraph-flow/trustgraph/embeddings/ollama/processor.py +++ b/trustgraph-flow/trustgraph/embeddings/ollama/processor.py @@ -3,12 +3,13 @@ Embeddings service, applies an embeddings model selected from HuggingFace. Input is text, output is embeddings vector. """ -from langchain_community.embeddings import OllamaEmbeddings from ... schema import EmbeddingsRequest, EmbeddingsResponse from ... schema import embeddings_request_queue, embeddings_response_queue from ... log_level import LogLevel from ... base import ConsumerProducer +from ollama import Client +import os module = ".".join(__name__.split(".")[1:-1]) @@ -16,7 +17,7 @@ default_input_queue = embeddings_request_queue default_output_queue = embeddings_response_queue default_subscriber = module default_model="mxbai-embed-large" -default_ollama = 'http://localhost:11434' +default_ollama = os.getenv("OLLAMA_HOST", 'http://localhost:11434') class Processor(ConsumerProducer): @@ -26,6 +27,9 @@ class Processor(ConsumerProducer): output_queue = params.get("output_queue", default_output_queue) subscriber = params.get("subscriber", default_subscriber) + ollama = params.get("ollama", default_ollama) + model = params.get("model", default_model) + super(Processor, self).__init__( **params | { "input_queue": input_queue, @@ -33,10 +37,13 @@ class Processor(ConsumerProducer): "subscriber": subscriber, "input_schema": EmbeddingsRequest, "output_schema": EmbeddingsResponse, + "ollama": ollama, + "model": model, } ) - self.embeddings = OllamaEmbeddings(base_url=ollama, model=model) + self.client = Client(host=ollama) + self.model = model def handle(self, msg): @@ -49,10 +56,16 @@ class Processor(ConsumerProducer): print(f"Handling input {id}...", flush=True) text = v.text - embeds = self.embeddings.embed_query([text]) + embeds = self.client.embed( + model = self.model, + input = text + ) print("Send response...", flush=True) - r = EmbeddingsResponse(vectors=[embeds]) + r = EmbeddingsResponse( + vectors=embeds.embeddings, + error=None, + ) self.producer.send(r, properties={"id": id})