From 657571a192bab2dd092ac4bf5860411489dc9e2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Tue, 21 Nov 2023 12:01:28 +0800 Subject: [PATCH] feat: ask_code message support type list[str]. --- metagpt/provider/openai_api.py | 3 ++- tests/metagpt/provider/test_openai.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 11f429601..ef9b790ce 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -262,9 +262,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self._update_costs(rsp.get("usage")) return rsp - def _process_message(self, messages: Union[str, Message, list[dict], list[Message]]) -> list[dict]: + def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: """convert messages to list[dict].""" if isinstance(messages, list): + messages = [Message(msg) if isinstance(msg, str) else msg for msg in messages] return [msg if isinstance(msg, dict) else msg.to_dict() for msg in messages] if isinstance(messages, Message): diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 93ac6623c..2b0af37b5 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -68,3 +68,13 @@ def test_ask_code_list_Message(): assert "language" in rsp assert "code" in rsp assert len(rsp["code"]) > 0 + + +def test_ask_code_list_str(): + llm = OpenAIGPTAPI() + msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"] + rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} + print(rsp) + assert "language" in rsp + assert "code" in rsp + assert len(rsp["code"]) > 0