From 322974a917a73325d02ec9593e740e83d8cdf1cd Mon Sep 17 00:00:00 2001 From: EvensXia Date: Tue, 29 Oct 2024 16:10:56 +0800 Subject: [PATCH] fix for test --- metagpt/provider/ollama_api.py | 4 +++- tests/metagpt/provider/test_ollama_api.py | 16 +++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index ab067340c..3f8467843 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -48,7 +48,7 @@ class OllamaMessageBase: if tpe == "text": return msg["text"], None elif tpe == "image_url": - return None, msg["image_url"]["url"][self._image_b64_rms :] + return None, msg["image_url"]["url"][self._image_b64_rms:] else: raise ValueError else: @@ -211,6 +211,8 @@ class OllamaLLM(BaseLLM): else: raise ValueError + def get_choice_text(self, rsp): return self.ollama_message.get_choice(rsp) + async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: return await self._achat_completion(messages, timeout=self.get_timeout(timeout)) diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index af2e929e9..75cfa86d5 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -3,11 +3,11 @@ # @Desc : the unittest of ollama api import json -from typing import Any, Tuple +from typing import Any, AsyncGenerator, Tuple import pytest -from metagpt.provider.ollama_api import OllamaLLM +from metagpt.provider.ollama_api import OllamaLLM, OpenAIResponse from tests.metagpt.provider.mock_llm_config import mock_llm_config from tests.metagpt.provider.req_resp_const import ( llm_general_chat_funcs_test, @@ -23,21 +23,19 @@ default_resp = {"message": {"role": "assistant", "content": resp_cont}} async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]: if stream: - class Iterator(object): + async def async_event_generator() -> AsyncGenerator[Any, None]: events = [ b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}', b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}', ] + for event in events: + yield OpenAIResponse(event, {}) - async def __aiter__(self): - for event in self.events: - yield event - - return Iterator(), None, None + return async_event_generator(), None, None else: raw_default_resp = default_resp.copy() raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20}) - return json.dumps(raw_default_resp).encode(), None, None + return OpenAIResponse(json.dumps(raw_default_resp).encode(), {}), None, None @pytest.mark.asyncio