update code due to failed unittests

This commit is contained in:
better629 2024-02-03 13:33:02 +08:00
parent 2eeb9556f5
commit f2dbb51094
4 changed files with 14 additions and 6 deletions

View file

@ -16,5 +16,5 @@ def LLM(llm_config: Optional[LLMConfig] = None, context: Context = None) -> Base
"""get the default llm provider if name is None"""
ctx = context or Context()
if llm_config is not None:
ctx.llm_with_cost_manager_from_llm_config(llm_config)
return ctx.llm_with_cost_manager_from_llm_config(llm_config)
return ctx.llm()

View file

@ -2,6 +2,8 @@
# -*- coding: utf-8 -*-
# @Desc : Google Gemini LLM from https://ai.google.dev/tutorials/python_quickstart
from typing import Optional, Union
import google.generativeai as genai
from google.ai import generativelanguage as glm
from google.generativeai.generative_models import GenerativeModel
@ -58,7 +60,7 @@ class GeminiLLM(BaseLLM):
def __init_gemini(self, config: LLMConfig):
genai.configure(api_key=config.api_key)
def _user_msg(self, msg: str) -> dict[str, str]:
def _user_msg(self, msg: str, images: Optional[Union[str, list[str]]] = None) -> dict[str, str]:
# Not to change BaseLLM default functions but update with Gemini's conversation format.
# You should follow the format.
return {"role": "user", "parts": [msg]}

View file

@ -244,13 +244,19 @@ def test_create_model_class_with_mapping():
@pytest.mark.asyncio
async def test_action_node_with_image():
async def test_action_node_with_image(mocker):
# add a mock to update model in unittest, due to the gloabl MockLLM
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
kwargs = {"messages": messages, "temperature": 0.3, "model": "gpt-4-vision-preview"}
return kwargs
invoice = ActionNode(
key="invoice", expected_type=bool, instruction="if it's a invoice file, return True else False", example="False"
)
invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png")
img_base64 = encode_image(invoice_path)
mocker.patch("metagpt.provider.openai_api.OpenAILLM._cons_kwargs", _cons_kwargs)
node = await invoice.fill(context="", llm=LLM(), images=[img_base64])
assert node.instruct_content.invoice

View file

@ -135,7 +135,7 @@ def test_repair_json_format():
}
"""
target_output = """{
"Language": "en_us",
"Language": "en_us",
"Programming Language": "Python"
}"""
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
@ -148,7 +148,7 @@ def test_repair_json_format():
}
"""
target_output = """{
"Language": "en_us",
"Language": "en_us",
"Programming Language": "Python"
}"""
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
@ -161,7 +161,7 @@ def test_repair_json_format():
}
"""
target_output = """{
"Language": "#en_us#",
"Language": "#en_us#",
"Programming Language": "//Python # Code // Language//"
}"""
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)