From f6a11d508904e6a56a9e35895abfeb439f5c4110 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 20 Mar 2024 17:34:30 +0800 Subject: [PATCH] fixbug: #1016 --- metagpt/actions/di/write_analysis_code.py | 4 +-- metagpt/provider/base_llm.py | 22 +++++++++++++++++ metagpt/provider/google_gemini_api.py | 30 +++++++++++++++++++++++ metagpt/provider/openai_api.py | 9 ++----- metagpt/utils/common.py | 23 ----------------- tests/mock/mock_llm.py | 3 +-- 6 files changed, 57 insertions(+), 34 deletions(-) diff --git a/metagpt/actions/di/write_analysis_code.py b/metagpt/actions/di/write_analysis_code.py index 185926e31..711e56d39 100644 --- a/metagpt/actions/di/write_analysis_code.py +++ b/metagpt/actions/di/write_analysis_code.py @@ -18,7 +18,7 @@ from metagpt.prompts.di.write_analysis_code import ( STRUCTUAL_PROMPT, ) from metagpt.schema import Message, Plan -from metagpt.utils.common import CodeParser, process_message, remove_comments +from metagpt.utils.common import CodeParser, remove_comments class WriteAnalysisCode(Action): @@ -50,7 +50,7 @@ class WriteAnalysisCode(Action): ) working_memory = working_memory or [] - context = process_message([Message(content=structual_prompt, role="user")] + working_memory) + context = self.llm.format_msg([Message(content=structual_prompt, role="user")] + working_memory) # LLM call if use_reflection: diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 71308930a..601980d5e 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -73,6 +73,28 @@ class BaseLLM(ABC): def _system_msg(self, msg: str) -> dict[str, str]: return {"role": "system", "content": msg} + def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: + """convert messages to list[dict].""" + from metagpt.schema import Message + + if not isinstance(messages, list): + messages = [messages] + + processed_messages = [] + for msg in messages: + if 
isinstance(msg, str): + processed_messages.append({"role": "user", "content": msg}) + elif isinstance(msg, dict): + assert set(msg.keys()) == set(["role", "content"]) + processed_messages.append(msg) + elif isinstance(msg, Message): + processed_messages.append(msg.to_dict()) + else: + raise ValueError( + f"Only supported message types are: str, Message, dict, but got {type(msg).__name__}!" + ) + return processed_messages + def _system_msgs(self, msgs: list[str]) -> list[dict[str, str]]: return [self._system_msg(msg) for msg in msgs] diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 09e554205..7370747a5 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -18,6 +18,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import log_llm_stream from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider +from metagpt.schema import Message class GeminiGenerativeModel(GenerativeModel): @@ -61,6 +62,35 @@ class GeminiLLM(BaseLLM): def _assistant_msg(self, msg: str) -> dict[str, str]: return {"role": "model", "parts": [msg]} + def _system_msg(self, msg: str) -> dict[str, str]: + return {"role": "user", "parts": [msg]} + + def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: + """convert messages to list[dict].""" + from metagpt.schema import Message + + if not isinstance(messages, list): + messages = [messages] + + # REF: https://ai.google.dev/tutorials/python_quickstart + # As a dictionary, the message requires `role` and `parts` keys. + # The role in a conversation can either be the `user`, which provides the prompts, + # or `model`, which provides the responses. 
+ processed_messages = [] + for msg in messages: + if isinstance(msg, str): + processed_messages.append({"role": "user", "parts": [msg]}) + elif isinstance(msg, dict): + assert set(msg.keys()) == set(["role", "parts"]) + processed_messages.append(msg) + elif isinstance(msg, Message): + processed_messages.append({"role": "user" if msg.role == "user" else "model", "parts": [msg.content]}) + else: + raise ValueError( + f"Only supported message types are: str, Message, dict, but got {type(msg).__name__}!" + ) + return processed_messages + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream} return kwargs diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index b4f99e69f..2fb64dc85 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -29,12 +29,7 @@ from metagpt.logs import log_llm_stream, logger from metagpt.provider.base_llm import BaseLLM from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.llm_provider_registry import register_provider -from metagpt.utils.common import ( - CodeParser, - decode_image, - log_and_reraise, - process_message, -) +from metagpt.utils.common import CodeParser, decode_image, log_and_reraise from metagpt.utils.cost_manager import CostManager from metagpt.utils.exceptions import handle_exception from metagpt.utils.token_counter import ( @@ -150,7 +145,7 @@ class OpenAILLM(BaseLLM): async def _achat_completion_function( self, messages: list[dict], timeout: int = 3, **chat_configs ) -> ChatCompletion: - messages = process_message(messages) + messages = self.format_msg(messages) kwargs = self._cons_kwargs(messages=messages, timeout=timeout, **chat_configs) rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs) self._update_costs(rsp.usage) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 
e9cef69a4..7493712c2 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -802,29 +802,6 @@ def decode_image(img_url_or_b64: str) -> Image: return img -def process_message(messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]: - """convert messages to list[dict].""" - from metagpt.schema import Message - - # 全部转成list - if not isinstance(messages, list): - messages = [messages] - - # 转成list[dict] - processed_messages = [] - for msg in messages: - if isinstance(msg, str): - processed_messages.append({"role": "user", "content": msg}) - elif isinstance(msg, dict): - assert set(msg.keys()) == set(["role", "content"]) - processed_messages.append(msg) - elif isinstance(msg, Message): - processed_messages.append(msg.to_dict()) - else: - raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!") - return processed_messages - - def log_and_reraise(retry_state: RetryCallState): logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") logger.warning( diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index b4cdfa0cf..c4262e080 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -8,7 +8,6 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.openai_api import OpenAILLM from metagpt.schema import Message -from metagpt.utils.common import process_message OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM @@ -105,7 +104,7 @@ class MockLLM(OriginalLLM): return rsp async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict: - msg_key = json.dumps(process_message(messages), ensure_ascii=False) + msg_key = json.dumps(self.format_msg(messages), ensure_ascii=False) rsp = await self._mock_rsp(msg_key, self.original_aask_code, messages, **kwargs) return rsp