mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-07-02 16:01:04 +02:00
generalize write code with tools, simplify ml_engineer
This commit is contained in:
parent
3a312007c2
commit
e12ab25b7c
6 changed files with 221 additions and 162 deletions
|
|
@ -12,14 +12,16 @@ import yaml
|
|||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.ml_engineer import (
|
||||
CODE_GENERATOR_WITH_TOOLS,
|
||||
GENERATE_CODE_PROMPT,
|
||||
ML_MODULE_MAP,
|
||||
ML_SPECIFIC_PROMPT,
|
||||
ML_TOOL_USAGE_PROMPT,
|
||||
SELECT_FUNCTION_TOOLS,
|
||||
TASK_MODULE_MAP,
|
||||
TASK_SPECIFIC_PROMPT,
|
||||
TOOL_RECOMMENDATION_PROMPT,
|
||||
TOOL_USAGE_PROMPT,
|
||||
)
|
||||
|
|
@ -60,13 +62,12 @@ class BaseWriteAnalysisCode(Action):
|
|||
}
|
||||
return messages
|
||||
|
||||
async def run(self, context: List[Message], plan: Plan = None, code_steps: str = "") -> str:
|
||||
async def run(self, context: List[Message], plan: Plan = None) -> str:
|
||||
"""Run of a code writing action, used in data analysis or modeling
|
||||
|
||||
Args:
|
||||
context (List[Message]): Action output history, source action denoted by Message.cause_by
|
||||
plan (Plan, optional): Overall plan. Defaults to None.
|
||||
code_steps (str, optional): suggested step breakdown for the current task. Defaults to "".
|
||||
|
||||
Returns:
|
||||
str: The code string.
|
||||
|
|
@ -92,15 +93,12 @@ class WriteCodeByGenerate(BaseWriteAnalysisCode):
|
|||
class WriteCodeWithTools(BaseWriteAnalysisCode):
|
||||
"""Write code with help of local available tools. Choose tools first, then generate code to use the tools"""
|
||||
|
||||
schema_path: str = ""
|
||||
schema_path: Union[Path, str] = METAGPT_ROOT / "metagpt/tools/functions/schemas"
|
||||
available_tools: dict = {}
|
||||
|
||||
def __init__(self, schema_path="", **kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.schema_path = schema_path
|
||||
|
||||
if schema_path:
|
||||
self._load_tools(schema_path)
|
||||
self._load_tools(self.schema_path)
|
||||
|
||||
def _load_tools(self, schema_path, schema_module=None):
|
||||
"""Load tools from yaml file"""
|
||||
|
|
@ -171,12 +169,11 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
|
|||
self,
|
||||
context: List[Message],
|
||||
plan: Plan = None,
|
||||
column_info: str = "",
|
||||
**kwargs,
|
||||
) -> Tuple[List[Message], str]:
|
||||
) -> str:
|
||||
task_type = plan.current_task.task_type
|
||||
available_tools = self.available_tools.get(task_type, {})
|
||||
special_prompt = ML_SPECIFIC_PROMPT.get(task_type, "")
|
||||
special_prompt = TASK_SPECIFIC_PROMPT.get(task_type, "")
|
||||
code_steps = plan.current_task.code_steps
|
||||
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
|
|
@ -192,9 +189,54 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
|
|||
tool_catalog = self._parse_recommend_tools(task_type, recommend_tools)
|
||||
logger.info(f"Recommended tools: \n{recommend_tools}")
|
||||
|
||||
module_name = ML_MODULE_MAP[task_type]
|
||||
module_name = TASK_MODULE_MAP[task_type]
|
||||
|
||||
prompt = TOOL_USAGE_PROMPT.format(
|
||||
else:
|
||||
tool_catalog = {}
|
||||
module_name = ""
|
||||
|
||||
tools_instruction = TOOL_USAGE_PROMPT.format(
|
||||
special_prompt=special_prompt, module_name=module_name, tool_catalog=tool_catalog
|
||||
)
|
||||
|
||||
context.append(Message(content=tools_instruction, role="user"))
|
||||
|
||||
prompt = self.process_msg(context)
|
||||
|
||||
tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
return rsp["code"]
|
||||
|
||||
|
||||
class WriteCodeWithToolsML(WriteCodeWithTools):
|
||||
async def run(
|
||||
self,
|
||||
context: List[Message],
|
||||
plan: Plan = None,
|
||||
column_info: str = "",
|
||||
**kwargs,
|
||||
) -> Tuple[List[Message], str]:
|
||||
task_type = plan.current_task.task_type
|
||||
available_tools = self.available_tools.get(task_type, {})
|
||||
special_prompt = TASK_SPECIFIC_PROMPT.get(task_type, "")
|
||||
code_steps = plan.current_task.code_steps
|
||||
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
code_context = [remove_comments(task.code) for task in finished_tasks]
|
||||
code_context = "\n\n".join(code_context)
|
||||
|
||||
if len(available_tools) > 0:
|
||||
available_tools = {k: v["description"] for k, v in available_tools.items()}
|
||||
|
||||
recommend_tools = await self._tool_recommendation(
|
||||
plan.current_task.instruction, code_steps, available_tools
|
||||
)
|
||||
tool_catalog = self._parse_recommend_tools(task_type, recommend_tools)
|
||||
logger.info(f"Recommended tools: \n{recommend_tools}")
|
||||
|
||||
module_name = TASK_MODULE_MAP[task_type]
|
||||
|
||||
prompt = ML_TOOL_USAGE_PROMPT.format(
|
||||
user_requirement=plan.goal,
|
||||
history_code=code_context,
|
||||
current_task=plan.current_task.instruction,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue