mirror of https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-02 12:22:39 +02:00

Merge branch 'fixbug/issues/1016' into HEAD

commit a6f31bf3e6
16 changed files with 178 additions and 93 deletions
@@ -18,7 +18,7 @@ from metagpt.prompts.di.write_analysis_code import (
     STRUCTUAL_PROMPT,
 )
 from metagpt.schema import Message, Plan
-from metagpt.utils.common import CodeParser, process_message, remove_comments
+from metagpt.utils.common import CodeParser, remove_comments


 class WriteAnalysisCode(Action):
@@ -50,7 +50,7 @@ class WriteAnalysisCode(Action):
         )

         working_memory = working_memory or []
-        context = process_message([Message(content=structual_prompt, role="user")] + working_memory)
+        context = self.llm.format_msg([Message(content=structual_prompt, role="user")] + working_memory)

         # LLM call
         if use_reflection:
@@ -73,6 +73,28 @@ class BaseLLM(ABC):
     def _system_msg(self, msg: str) -> dict[str, str]:
         return {"role": "system", "content": msg}

+    def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
+        """convert messages to list[dict]."""
+        from metagpt.schema import Message
+
+        if not isinstance(messages, list):
+            messages = [messages]
+
+        processed_messages = []
+        for msg in messages:
+            if isinstance(msg, str):
+                processed_messages.append({"role": "user", "content": msg})
+            elif isinstance(msg, dict):
+                assert set(msg.keys()) == set(["role", "content"])
+                processed_messages.append(msg)
+            elif isinstance(msg, Message):
+                processed_messages.append(msg.to_dict())
+            else:
+                raise ValueError(
+                    f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
+                )
+        return processed_messages
+
     def _system_msgs(self, msgs: list[str]) -> list[dict[str, str]]:
         return [self._system_msg(msg) for msg in msgs]
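For context: format_msg replaces the module-level process_message helper and normalizes str / dict / Message / mixed-list inputs into the provider-ready list[dict] shape, so call sites go through self.llm and each provider can override the wire format. A minimal self-contained sketch of the same normalization (the Msg dataclass is a hypothetical stand-in for metagpt.schema.Message, which carries more fields):

from dataclasses import dataclass
from typing import Union

@dataclass
class Msg:
    # hypothetical stand-in for metagpt.schema.Message
    role: str
    content: str

    def to_dict(self) -> dict:
        return {"role": self.role, "content": self.content}

def format_msg(messages: Union[str, Msg, dict, list]) -> list[dict]:
    """Normalize str / Msg / dict / list inputs into [{'role': ..., 'content': ...}]."""
    if not isinstance(messages, list):
        messages = [messages]
    out = []
    for m in messages:
        if isinstance(m, str):
            out.append({"role": "user", "content": m})  # bare strings default to the user role
        elif isinstance(m, dict):
            assert set(m.keys()) == {"role", "content"}
            out.append(m)
        elif isinstance(m, Msg):
            out.append(m.to_dict())
        else:
            raise ValueError(f"unsupported message type: {type(m).__name__}")
    return out

# Mixed inputs all collapse to the same wire format:
print(format_msg(["hi", Msg("assistant", "hello"), {"role": "user", "content": "bye"}]))
# [{'role': 'user', 'content': 'hi'}, {'role': 'assistant', 'content': 'hello'}, {'role': 'user', 'content': 'bye'}]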
@@ -18,6 +18,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider
+from metagpt.schema import Message


 class GeminiGenerativeModel(GenerativeModel):
@@ -61,6 +62,35 @@ class GeminiLLM(BaseLLM):
     def _assistant_msg(self, msg: str) -> dict[str, str]:
         return {"role": "model", "parts": [msg]}

+    def _system_msg(self, msg: str) -> dict[str, str]:
+        return {"role": "user", "parts": [msg]}
+
+    def format_msg(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
+        """convert messages to list[dict]."""
+        from metagpt.schema import Message
+
+        if not isinstance(messages, list):
+            messages = [messages]
+
+        # REF: https://ai.google.dev/tutorials/python_quickstart
+        # As a dictionary, the message requires `role` and `parts` keys.
+        # The role in a conversation can either be the `user`, which provides the prompts,
+        # or `model`, which provides the responses.
+        processed_messages = []
+        for msg in messages:
+            if isinstance(msg, str):
+                processed_messages.append({"role": "user", "parts": [msg]})
+            elif isinstance(msg, dict):
+                assert set(msg.keys()) == set(["role", "parts"])
+                processed_messages.append(msg)
+            elif isinstance(msg, Message):
+                processed_messages.append({"role": "user" if msg.role == "user" else "model", "parts": [msg.content]})
+            else:
+                raise ValueError(
+                    f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
+                )
+        return processed_messages
+
     def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
         kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
         return kwargs
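The Gemini override exists because Google's API speaks a different wire schema: `parts` instead of `content`, and only `user`/`model` roles (per the linked quickstart). A small sketch of that role folding, independent of MetaGPT's classes:

def to_gemini(msgs: list[dict]) -> list[dict]:
    """Map OpenAI-style {'role', 'content'} dicts onto Gemini's {'role', 'parts'} schema."""
    out = []
    for m in msgs:
        # Gemini only accepts `user` and `model`; fold `system` and `assistant` accordingly.
        role = "user" if m["role"] in ("user", "system") else "model"
        out.append({"role": role, "parts": [m["content"]]})
    return out

print(to_gemini([
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
]))
# [{'role': 'user', 'parts': ['You are terse.']},
#  {'role': 'user', 'parts': ['Hi']},
#  {'role': 'model', 'parts': ['Hello']}]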
@@ -29,12 +29,7 @@ from metagpt.logs import log_llm_stream, logger
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.utils.common import (
-    CodeParser,
-    decode_image,
-    log_and_reraise,
-    process_message,
-)
+from metagpt.utils.common import CodeParser, decode_image, log_and_reraise
 from metagpt.utils.cost_manager import CostManager
 from metagpt.utils.exceptions import handle_exception
 from metagpt.utils.token_counter import (
@@ -150,7 +145,7 @@ class OpenAILLM(BaseLLM):
     async def _achat_completion_function(
         self, messages: list[dict], timeout: int = 3, **chat_configs
     ) -> ChatCompletion:
-        messages = process_message(messages)
+        messages = self.format_msg(messages)
         kwargs = self._cons_kwargs(messages=messages, timeout=timeout, **chat_configs)
         rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
         self._update_costs(rsp.usage)
@@ -240,8 +240,8 @@ class Engineer(Role):
     async def _think(self) -> Action | None:
         if not self.src_workspace:
             self.src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
-        write_plan_and_change_filters = any_to_str_set([WriteTasks])
-        write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode, FixBug])
+        write_plan_and_change_filters = any_to_str_set([WriteTasks, FixBug])
+        write_code_filters = any_to_str_set([WriteTasks, WriteCodePlanAndChange, SummarizeCode])
         summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
         if not self.rc.news:
             return None
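These filter sets hold fully qualified class-name strings that _think matches against each news message's cause_by; moving FixBug from write_code_filters into write_plan_and_change_filters means a bug-fix trigger now takes the plan-and-change path before any code is written. A rough sketch of this routing pattern (an approximation, not the actual Role internals):

def any_to_str_set(classes) -> set[str]:
    # approximates MetaGPT's helper: fully qualified class names as strings
    return {f"{c.__module__}.{c.__qualname__}" for c in classes}

class WriteTasks: ...
class FixBug: ...

write_plan_and_change_filters = any_to_str_set([WriteTasks, FixBug])

incoming_cause_by = f"{FixBug.__module__}.{FixBug.__qualname__}"  # set on the triggering message
if incoming_cause_by in write_plan_and_change_filters:
    print("route to WriteCodePlanAndChange")  # FixBug now takes the planning path first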
@@ -164,8 +164,9 @@ class Planner(BaseModel):
         code_written = "\n\n".join(code_written)
         task_results = [task.result for task in finished_tasks]
         task_results = "\n\n".join(task_results)
-        task_type_name = self.current_task.task_type.upper()
-        guidance = TaskType[task_type_name].value.guidance if hasattr(TaskType, task_type_name) else ""
+        task_type_name = self.current_task.task_type
+        task_type = TaskType.get_type(task_type_name)
+        guidance = task_type.guidance if task_type else ""

         # combine components in a prompt
         prompt = PLAN_STATUS.format(
@@ -71,3 +71,10 @@ class TaskType(Enum):
     @property
     def type_name(self):
         return self.value.name
+
+    @classmethod
+    def get_type(cls, type_name):
+        for member in cls:
+            if member.type_name == type_name:
+                return member.value
+        return None
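This classmethod is what makes the Planner change above work: a task stores the value's human-readable name, which .upper() cannot turn into an enum member name, so the old hasattr lookup silently returned "". A hedged sketch of the difference with a made-up two-member enum (TypeDef and the member definitions here are illustrative; the real TaskType values carry more fields):

from dataclasses import dataclass
from enum import Enum

@dataclass
class TypeDef:
    # hypothetical value object; the real TaskType values also carry guidance prompts etc.
    name: str
    guidance: str = ""

class TaskType(Enum):
    EDA = TypeDef(name="eda", guidance="explore the data first")
    DATA_PREPROCESS = TypeDef(name="data preprocessing", guidance="clean before modeling")

    @property
    def type_name(self):
        return self.value.name

    @classmethod
    def get_type(cls, type_name):
        for member in cls:
            if member.type_name == type_name:
                return member.value
        return None

stored = "data preprocessing"  # what a task records as task_type
# Old lookup: stored.upper() == "DATA PREPROCESSING" is not a member name, so
# hasattr(TaskType, ...) was False and guidance silently fell back to "".
print(TaskType.get_type(stored).guidance)  # clean before modeling
print(TaskType.get_type("unknown"))        # None -> caller uses ""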
@@ -802,29 +802,6 @@ def decode_image(img_url_or_b64: str) -> Image:
     return img


-def process_message(messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
-    """convert messages to list[dict]."""
-    from metagpt.schema import Message
-
-    # convert everything to a list
-    if not isinstance(messages, list):
-        messages = [messages]
-
-    # convert to list[dict]
-    processed_messages = []
-    for msg in messages:
-        if isinstance(msg, str):
-            processed_messages.append({"role": "user", "content": msg})
-        elif isinstance(msg, dict):
-            assert set(msg.keys()) == set(["role", "content"])
-            processed_messages.append(msg)
-        elif isinstance(msg, Message):
-            processed_messages.append(msg.to_dict())
-        else:
-            raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!")
-    return processed_messages
-
-
 def log_and_reraise(retry_state: RetryCallState):
     logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
     logger.warning(
@@ -229,7 +229,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"):
     else:
         raise NotImplementedError(
             f"num_tokens_from_messages() is not implemented for model {model}. "
-            f"See https://github.com/openai/openai-python/blob/main/chatml.md "
+            f"See https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken "
             f"for information on how messages are converted to tokens."
         )
     num_tokens = 0