From f2dbb51094484fbbc226c6609e4614a1dd903d20 Mon Sep 17 00:00:00 2001 From: better629 Date: Sat, 3 Feb 2024 13:33:02 +0800 Subject: [PATCH] update code due to failed unittests --- metagpt/llm.py | 2 +- metagpt/provider/google_gemini_api.py | 4 +++- tests/metagpt/actions/test_action_node.py | 8 +++++++- tests/metagpt/utils/test_repair_llm_raw_output.py | 6 +++--- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/metagpt/llm.py b/metagpt/llm.py index a3fc5613a..465e419a1 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -16,5 +16,5 @@ def LLM(llm_config: Optional[LLMConfig] = None, context: Context = None) -> Base """get the default llm provider if name is None""" ctx = context or Context() if llm_config is not None: - ctx.llm_with_cost_manager_from_llm_config(llm_config) + return ctx.llm_with_cost_manager_from_llm_config(llm_config) return ctx.llm() diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 6df814b55..2647ab16b 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- # @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart +from typing import Optional, Union + import google.generativeai as genai from google.ai import generativelanguage as glm from google.generativeai.generative_models import GenerativeModel @@ -58,7 +60,7 @@ class GeminiLLM(BaseLLM): def __init_gemini(self, config: LLMConfig): genai.configure(api_key=config.api_key) - def _user_msg(self, msg: str) -> dict[str, str]: + def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, str]: # Not to change BaseLLM default functions but update with Gemini's conversation format. # You should follow the format. return {"role": "user", "parts": [msg]} diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 589282879..989e2249c 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -244,13 +244,19 @@ def test_create_model_class_with_mapping(): @pytest.mark.asyncio -async def test_action_node_with_image(): +async def test_action_node_with_image(mocker): + # add a mock to update model in unittest, due to the gloabl MockLLM + def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict: + kwargs = {"messages": messages, "temperature": 0.3, "model": "gpt-4-vision-preview"} + return kwargs + invoice = ActionNode( key="invoice", expected_type=bool, instruction="if it's a invoice file, return True else False", example="False" ) invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png") img_base64 = encode_image(invoice_path) + mocker.patch("metagpt.provider.openai_api.OpenAILLM._cons_kwargs", _cons_kwargs) node = await invoice.fill(context="", llm=LLM(), images=[img_base64]) assert node.instruct_content.invoice diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index 9eec24727..e28423b91 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -135,7 +135,7 @@ def test_repair_json_format(): } """ target_output = """{ - "Language": "en_us", + "Language": "en_us", "Programming Language": "Python" }""" output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) @@ -148,7 +148,7 @@ def test_repair_json_format(): } """ target_output = """{ - "Language": "en_us", + "Language": "en_us", "Programming Language": "Python" }""" output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON) @@ -161,7 +161,7 @@ def test_repair_json_format(): } """ target_output = """{ - "Language": "#en_us#", + "Language": "#en_us#", "Programming Language": "//Python # Code // Language//" }""" output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)