From 4418fe8283635b5bdf64489b916c073d97695f80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Tue, 21 Nov 2023 11:40:15 +0800 Subject: [PATCH] feat: ask_code message support type list[Message]. --- metagpt/provider/openai_api.py | 4 ++-- tests/metagpt/provider/test_openai.py | 9 +++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 8f8d45d73..11f429601 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -262,10 +262,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self._update_costs(rsp.get("usage")) return rsp - def _process_message(self, messages: Union[str, Message, list[dict]]) -> list[dict]: + def _process_message(self, messages: Union[str, Message, list[dict], list[Message]]) -> list[dict]: """convert messages to list[dict].""" if isinstance(messages, list): - return messages + return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages] if isinstance(messages, Message): messages = [messages.to_dict()] diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index d6d9f4f9d..93ac6623c 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -59,3 +59,12 @@ def test_ask_code_Message(): assert "language" in rsp assert "code" in rsp assert len(rsp["code"]) > 0 + + +def test_ask_code_list_Message(): + llm = OpenAIGPTAPI() + msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")] + rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} + assert "language" in rsp + assert "code" in rsp + assert len(rsp["code"]) > 0