Merge branch 'fixbug/issues/1016' into HEAD

This commit is contained in:
莘权 马 2024-03-20 17:46:48 +08:00
commit a6f31bf3e6
16 changed files with 178 additions and 93 deletions

View file

@ -18,7 +18,7 @@ from metagpt.prompts.di.write_analysis_code import (
STRUCTUAL_PROMPT,
)
from metagpt.schema import Message, Plan
from metagpt.utils.common import CodeParser, process_message, remove_comments
from metagpt.utils.common import CodeParser, remove_comments
class WriteAnalysisCode(Action):
@ -50,7 +50,7 @@ class WriteAnalysisCode(Action):
)
working_memory = working_memory or []
context = process_message([Message(content=structual_prompt, role="user")] + working_memory)
context = self.llm.format_msg([Message(content=structual_prompt, role="user")] + working_memory)
# LLM call
if use_reflection:

View file

@ -73,6 +73,28 @@ class BaseLLM(ABC):
def _system_msg(self, msg: str) -> dict[str, str]:
return {"role": "system", "content": msg}
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
"""convert messages to list[dict]."""
from metagpt.schema import Message
if not isinstance(messages, list):
messages = [messages]
processed_messages = []
for msg in messages:
if isinstance(msg, str):
processed_messages.append({"role": "user", "content": msg})
elif isinstance(msg, dict):
assert set(msg.keys()) == set(["role", "content"])
processed_messages.append(msg)
elif isinstance(msg, Message):
processed_messages.append(msg.to_dict())
else:
raise ValueError(
f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
)
return processed_messages
def _system_msgs(self, msgs: list[str]) -> list[dict[str, str]]:
return [self._system_msg(msg) for msg in msgs]

View file

@ -18,6 +18,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
class GeminiGenerativeModel(GenerativeModel):
@ -61,6 +62,35 @@ class GeminiLLM(BaseLLM):
def _assistant_msg(self, msg: str) -> dict[str, str]:
return {"role": "model", "parts": [msg]}
def _system_msg(self, msg: str) -> dict[str, str]:
return {"role": "user", "parts": [msg]}
def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
"""convert messages to list[dict]."""
from metagpt.schema import Message
if not isinstance(messages, list):
messages = [messages]
# REF: https://ai.google.dev/tutorials/python_quickstart
# As a dictionary, the message requires `role` and `parts` keys.
# The role in a conversation can either be the `user`, which provides the prompts,
# or `model`, which provides the responses.
processed_messages = []
for msg in messages:
if isinstance(msg, str):
processed_messages.append({"role": "user", "parts": [msg]})
elif isinstance(msg, dict):
assert set(msg.keys()) == set(["role", "parts"])
processed_messages.append(msg)
elif isinstance(msg, Message):
processed_messages.append({"role": "user" if msg.role == "user" else "model", "parts": [msg.content]})
else:
raise ValueError(
f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
)
return processed_messages
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
return kwargs

View file

@ -29,12 +29,7 @@ from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.utils.common import (
CodeParser,
decode_image,
log_and_reraise,
process_message,
)
from metagpt.utils.common import CodeParser, decode_image, log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
@ -150,7 +145,7 @@ class OpenAILLM(BaseLLM):
async def _achat_completion_function(
self, messages: list[dict], timeout: int = 3, **chat_configs
) -> ChatCompletion:
messages = process_message(messages)
messages = self.format_msg(messages)
kwargs = self._cons_kwargs(messages=messages, timeout=timeout, **chat_configs)
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)

View file

@ -240,8 +240,8 @@ class Engineer(Role):
async def _think(self) -> Action | None:
if not self.src_workspace:
self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
write_plan_and_change_filters = any_to_str_set([WriteTasks])
write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode, FixBug])
write_plan_and_change_filters = any_to_str_set([WriteTasks, FixBug])
write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode])
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
if not self.rc.news:
return None

View file

@ -164,8 +164,9 @@ class Planner(BaseModel):
code_written = "\n\n".join(code_written)
task_results = [task.result for task in finished_tasks]
task_results = "\n\n".join(task_results)
task_type_name = self.current_task.task_type.upper()
guidance = TaskType[task_type_name].value.guidance if hasattr(TaskType, task_type_name) else ""
task_type_name = self.current_task.task_type
task_type = TaskType.get_type(task_type_name)
guidance = task_type.guidance if task_type else ""
# combine components in a prompt
prompt = PLAN_STATUS.format(

View file

@ -71,3 +71,10 @@ class TaskType(Enum):
@property
def type_name(self):
return self.value.name
@classmethod
def get_type(cls, type_name):
for member in cls:
if member.type_name == type_name:
return member.value
return None

View file

@ -802,29 +802,6 @@ def decode_image(img_url_or_b64: str) -> Image:
return img
def process_message(messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
"""convert messages to list[dict]."""
from metagpt.schema import Message
# 全部转成list
if not isinstance(messages, list):
messages = [messages]
# 转成list[dict]
processed_messages = []
for msg in messages:
if isinstance(msg, str):
processed_messages.append({"role": "user", "content": msg})
elif isinstance(msg, dict):
assert set(msg.keys()) == set(["role", "content"])
processed_messages.append(msg)
elif isinstance(msg, Message):
processed_messages.append(msg.to_dict())
else:
raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!")
return processed_messages
def log_and_reraise(retry_state: RetryCallState):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning(

View file

@ -229,7 +229,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
else:
raise NotImplementedError(
f"num_tokens_from_messages() is not implemented for model {model}. "
f"See https://github.com/openai/openai-python/blob/main/chatml.md "
f"See https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken "
f"for information on how messages are converted to tokens."
)
num_tokens = 0