diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 8f8d45d73..11f429601 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -262,10 +262,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self._update_costs(rsp.get("usage")) return rsp - def _process_message(self, messages: Union[str, Message, list[dict]]) -> list[dict]: + def _process_message(self, messages: Union[str, Message, list[dict], list[Message]]) -> list[dict]: """convert messages to list[dict].""" if isinstance(messages, list): - return messages + return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages] if isinstance(messages, Message): messages = [messages.to_dict()] diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index d6d9f4f9d..93ac6623c 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -59,3 +59,12 @@ def test_ask_code_Message(): assert "language" in rsp assert "code" in rsp assert len(rsp["code"]) > 0 + + +def test_ask_code_list_Message(): + llm = OpenAIGPTAPI() + msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")] + rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} + assert "language" in rsp + assert "code" in rsp + assert len(rsp["code"]) > 0