From 218eeef4b8616a47b190531db34cbd5cee26e57a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Sat, 18 Nov 2023 18:15:50 +0800 Subject: [PATCH] chore: messages supports more types. --- metagpt/provider/openai_api.py | 22 +++++++++++++-- tests/metagpt/provider/test_openai.py | 39 +++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index b9698e77d..287668c83 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -21,7 +21,8 @@ from tenacity import ( from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI -from metagpt.utils.function_schema import general_function_schema, general_tool_choice +from metagpt.provider.constant import general_function_schema, general_tool_choice +from metagpt.schema import Message from metagpt.utils.singleton import Singleton from metagpt.utils.token_counter import ( TOKEN_COSTS, @@ -261,7 +262,22 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self._update_costs(rsp.get("usage")) return rsp - def ask_code(self, messages: list[dict], **kwargs) -> dict: + def _process_message(self, messages: Union[str, Message, list[dict]]) -> list[dict]: + """convert messages to list[dict].""" + if isinstance(messages, list): + return messages + + if isinstance(messages, Message): + messages = [messages.to_dict()] + elif isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + else: + raise ValueError( + f"Only support messages type are: str, Message, list[dict], but got {type(messages).__name__}!" + ) + return messages + + def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict: """Use function of tools to ask a code. https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools @@ -272,6 +288,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): >>> llm.ask_code(msg) {'language': 'python', 'code': "print('Hello, World!')"} """ + messages = self._process_message(messages) rsp = self._chat_completion_function(messages, **kwargs) return self.get_choice_function_arguments(rsp) @@ -285,6 +302,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}] >>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} """ + messages = self._process_message(messages) rsp = await self._achat_completion_function(messages, **kwargs) return self.get_choice_function_arguments(rsp) diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 4cbc896e0..d6d9f4f9d 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -1,6 +1,7 @@ import pytest from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.schema import UserMessage @pytest.mark.asyncio @@ -13,6 +14,26 @@ async def test_aask_code(): assert len(rsp["code"]) > 0 +@pytest.mark.asyncio +async def test_aask_code_str(): + llm = OpenAIGPTAPI() + msg = "Write a python hello world code." + rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} + assert "language" in rsp + assert "code" in rsp + assert len(rsp["code"]) > 0 + + +@pytest.mark.asyncio +async def test_aask_code_Message(): + llm = OpenAIGPTAPI() + msg = UserMessage("Write a python hello world code.") + rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} + assert "language" in rsp + assert "code" in rsp + assert len(rsp["code"]) > 0 + + def test_ask_code(): llm = OpenAIGPTAPI() msg = [{"role": "user", "content": "Write a python hello world code."}] @@ -20,3 +41,21 @@ def test_ask_code(): assert "language" in rsp assert "code" in rsp assert len(rsp["code"]) > 0 + + +def test_ask_code_str(): + llm = OpenAIGPTAPI() + msg = "Write a python hello world code." + rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} + assert "language" in rsp + assert "code" in rsp + assert len(rsp["code"]) > 0 + + +def test_ask_code_Message(): + llm = OpenAIGPTAPI() + msg = UserMessage("Write a python hello world code.") + rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} + assert "language" in rsp + assert "code" in rsp + assert len(rsp["code"]) > 0