mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 19:36:24 +02:00
chore: messages supports more types.
This commit is contained in:
parent
2a88158933
commit
218eeef4b8
2 changed files with 59 additions and 2 deletions
|
|
@ -21,7 +21,8 @@ from tenacity import (
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.utils.function_schema import general_function_schema, general_tool_choice
|
||||
from metagpt.provider.constant import general_function_schema, general_tool_choice
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.singleton import Singleton
|
||||
from metagpt.utils.token_counter import (
|
||||
TOKEN_COSTS,
|
||||
|
|
@ -261,7 +262,22 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
self._update_costs(rsp.get("usage"))
|
||||
return rsp
|
||||
|
||||
def ask_code(self, messages: list[dict], **kwargs) -> dict:
|
||||
def _process_message(self, messages: Union[str, Message, list[dict]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
if isinstance(messages, list):
|
||||
return messages
|
||||
|
||||
if isinstance(messages, Message):
|
||||
messages = [messages.to_dict()]
|
||||
elif isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Only support messages type are: str, Message, list[dict], but got {type(messages).__name__}!"
|
||||
)
|
||||
return messages
|
||||
|
||||
def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
"""Use function of tools to ask a code.
|
||||
https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools
|
||||
|
||||
|
|
@ -272,6 +288,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
>>> llm.ask_code(msg)
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
messages = self._process_message(messages)
|
||||
rsp = self._chat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
|
|
@ -285,6 +302,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
|
||||
>>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
messages = self._process_message(messages)
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.schema import UserMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -13,6 +14,26 @@ async def test_aask_code():
|
|||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
|
|
@ -20,3 +41,21 @@ def test_ask_code():
|
|||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue