Mirror of https://github.com/trustgraph-ai/trustgraph.git, synced 2026-04-29 18:36:22 +02:00
Fix Ollama async issue (#854)
* Fix Ollama sync issues - replaced with async
* Fix tests
This commit is contained in:
parent fad005e030
commit a8fdf547db
4 changed files with 54 additions and 54 deletions
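The pattern is identical in both changed services: `ollama.Client` becomes `ollama.AsyncClient`, and every call on it gains an `await`, so Ollama round trips no longer block the event loop. A minimal sketch of the same calls outside TrustGraph (the host and model names here are illustrative, not taken from the diff):

```python
import asyncio
from ollama import AsyncClient

async def main():
    # AsyncClient mirrors the sync Client's API, but every
    # method returns an awaitable instead of blocking.
    client = AsyncClient(host="http://localhost:11434")

    # Embeddings call, as in the embeddings service below
    embeds = await client.embed(model="all-minilm", input=["hello world"])
    print(len(embeds["embeddings"][0]))

    # Completion call, as in the LLM service below
    response = await client.generate("gemma2", "Say hi",
                                     options={"temperature": 0.0})
    print(response["response"])

asyncio.run(main())
```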
```diff
@@ -5,7 +5,7 @@ Input is text, output is embeddings vector.
 """
 
 from ... base import EmbeddingsService
 
-from ollama import Client
+from ollama import AsyncClient
 import os
 import logging
```
```diff
@@ -30,24 +30,24 @@ class Processor(EmbeddingsService):
             }
         )
 
-        self.client = Client(host=ollama)
+        self.client = AsyncClient(host=ollama)
         self.default_model = model
         self._checked_models = set()
 
-    def _ensure_model(self, model_name):
+    async def _ensure_model(self, model_name):
         """Check if model exists locally, pull it if not."""
         if model_name in self._checked_models:
             return
 
         try:
-            self.client.show(model_name)
+            await self.client.show(model_name)
             self._checked_models.add(model_name)
         except Exception as e:
             status_code = getattr(e, 'status_code', None)
             if status_code == 404 or "not found" in str(e).lower():
                 logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
                 try:
-                    self.client.pull(model_name)
+                    await self.client.pull(model_name)
                     self._checked_models.add(model_name)
                     logger.info(f"Successfully pulled Ollama model '{model_name}'.")
                 except Exception as pull_e:
```
```diff
@@ -63,10 +63,10 @@ class Processor(EmbeddingsService):
         use_model = model or self.default_model
 
         # Ensure the model exists/is pulled
-        self._ensure_model(use_model)
+        await self._ensure_model(use_model)
 
         # Ollama handles batch input efficiently
-        embeds = self.client.embed(
+        embeds = await self.client.embed(
             model = use_model,
             input = texts
         )
```
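The `_ensure_model` guard converted above is a check-then-pull: `show()` raises for a missing model, a 404 (or a "not found" message) triggers `pull()`, and the name is cached in `_checked_models` so later requests skip the round trip. A standalone sketch of the same guard, assuming a reachable Ollama server (model name illustrative):

```python
import asyncio
from ollama import AsyncClient

async def ensure_model(client: AsyncClient, model_name: str, checked: set) -> None:
    # Skip the network round trip once a model has been verified
    if model_name in checked:
        return
    try:
        await client.show(model_name)   # raises if the model is absent
    except Exception as e:
        not_found = (getattr(e, "status_code", None) == 404
                     or "not found" in str(e).lower())
        if not not_found:
            raise
        await client.pull(model_name)   # download it once, without blocking
    checked.add(model_name)

asyncio.run(ensure_model(AsyncClient(), "all-minilm", set()))
```

The same conversion repeats below for the LLM service, with `self.client` spelled `self.llm`.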
```diff
@@ -4,7 +4,7 @@ Simple LLM service, performs text prompt completion using an Ollama service.
 Input is prompt, output is response.
 """
 
-from ollama import Client
+from ollama import AsyncClient
 import os
 import logging
 
```
```diff
@@ -38,23 +38,23 @@ class Processor(LlmService):
 
         self.default_model = model
         self.temperature = temperature
-        self.llm = Client(host=ollama)
+        self.llm = AsyncClient(host=ollama)
         self._checked_models = set()
 
-    def _ensure_model(self, model_name):
+    async def _ensure_model(self, model_name):
         """Check if model exists locally, pull it if not."""
         if model_name in self._checked_models:
             return
 
         try:
-            self.llm.show(model_name)
+            await self.llm.show(model_name)
             self._checked_models.add(model_name)
         except Exception as e:
             status_code = getattr(e, 'status_code', None)
             if status_code == 404 or "not found" in str(e).lower():
                 logger.info(f"Ollama model '{model_name}' not found locally. Pulling, this may take a while...")
                 try:
-                    self.llm.pull(model_name)
+                    await self.llm.pull(model_name)
                     self._checked_models.add(model_name)
                     logger.info(f"Successfully pulled Ollama model '{model_name}'.")
                 except Exception as pull_e:
```
```diff
@@ -66,9 +66,9 @@ class Processor(LlmService):
 
         # Use provided model or fall back to default
         model_name = model or self.default_model
 
 
         # Ensure the model exists/is pulled
-        self._ensure_model(model_name)
+        await self._ensure_model(model_name)
         # Use provided temperature or fall back to default
         effective_temperature = temperature if temperature is not None else self.temperature
```
```diff
@@ -79,7 +79,7 @@ class Processor(LlmService):
 
         try:
 
-            response = self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
+            response = await self.llm.generate(model_name, prompt, options={'temperature': effective_temperature})
 
             response_text = response['response']
             logger.debug("Sending response...")
```
```diff
@@ -113,7 +113,7 @@ class Processor(LlmService):
         model_name = model or self.default_model
 
         # Ensure the model exists/is pulled
-        self._ensure_model(model_name)
+        await self._ensure_model(model_name)
 
         effective_temperature = temperature if temperature is not None else self.temperature
 
```
```diff
@@ -123,7 +123,7 @@ class Processor(LlmService):
             prompt = system + "\n\n" + prompt
 
         try:
-            stream = self.llm.generate(
+            stream = await self.llm.generate(
                 model_name,
                 prompt,
                 options={'temperature': effective_temperature},
```
```diff
@@ -133,7 +133,7 @@ class Processor(LlmService):
         total_input_tokens = 0
         total_output_tokens = 0
 
-        for chunk in stream:
+        async for chunk in stream:
             if 'response' in chunk and chunk['response']:
                 yield LlmChunk(
                     text=chunk['response'],
```
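The streaming path needs both halves of the change: the `generate(...)` call must itself be awaited to obtain the async iterator, and the loop becomes `async for`. The hunk above truncates the full call, so `stream=True` is assumed from the `async for` that follows. A minimal consumption sketch (model name illustrative):

```python
import asyncio
from ollama import AsyncClient

async def stream_completion(prompt: str):
    client = AsyncClient()
    # With stream=True, awaiting generate() returns an async
    # iterator of incremental response chunks
    stream = await client.generate("gemma2", prompt, stream=True)
    async for chunk in stream:
        # Mirror the diff's guard: print only non-empty fragments
        if 'response' in chunk and chunk['response']:
            print(chunk['response'], end='', flush=True)

asyncio.run(stream_completion("Tell me a joke"))
```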