refine ml prompt

This commit is contained in:
lidanyang 2023-12-07 20:48:00 +08:00
parent 824e285cc7
commit fe2b79fedc
4 changed files with 192 additions and 129 deletions

View file

@ -15,15 +15,11 @@ from metagpt.prompts.ml_engineer import (
TOO_ORGANIZATION_PROMPT,
ML_SPECIFIC_PROMPT,
ML_MODULE_MAP,
TOOL_OUTPUT_DESC,
TOOL_USAGE_PROMPT,
TOOL_OUTPUT_DESC, DATA_PROCESS_PROMPT,
)
from metagpt.schema import Message, Plan
from metagpt.tools.functions import registry
from metagpt.utils.common import create_func_config
from metagpt.prompts.ml_engineer import GEN_DATA_DESC_PROMPT, GENERATE_CODE_PROMPT
from metagpt.utils.common import CodeParser
from metagpt.actions.execute_code import ExecutePyCode
from metagpt.utils.common import create_func_config, remove_comments
class BaseWriteAnalysisCode(Action):
@ -51,13 +47,13 @@ class BaseWriteAnalysisCode(Action):
# 添加默认的提示词
if (
default_system_msg not in messages[0]["content"]
and messages[0]["role"] != "system"
default_system_msg not in messages[0]["content"]
and messages[0]["role"] != "system"
):
messages.insert(0, {"role": "system", "content": default_system_msg})
elif (
default_system_msg not in messages[0]["content"]
and messages[0]["role"] == "system"
default_system_msg not in messages[0]["content"]
and messages[0]["role"] == "system"
):
messages[0] = {
"role": "system",
@ -66,7 +62,7 @@ class BaseWriteAnalysisCode(Action):
return messages
async def run(
self, context: List[Message], plan: Plan = None, code_steps: str = ""
self, context: List[Message], plan: Plan = None, code_steps: str = ""
) -> str:
"""Run of a code writing action, used in data analysis or modeling
@ -87,12 +83,12 @@ class WriteCodeByGenerate(BaseWriteAnalysisCode):
super().__init__(name, context, llm)
async def run(
self,
context: [List[Message]],
plan: Plan = None,
code_steps: str = "",
system_msg: str = None,
**kwargs,
self,
context: [List[Message]],
plan: Plan = None,
code_steps: str = "",
system_msg: str = None,
**kwargs,
) -> str:
context.append(Message(content=self.REUSE_CODE_INSTRUCTION, role="user"))
prompt = self.process_msg(context, system_msg)
@ -102,7 +98,6 @@ class WriteCodeByGenerate(BaseWriteAnalysisCode):
class WriteCodeWithTools(BaseWriteAnalysisCode):
"""Write code with help of local available tools. Choose tools first, then generate code to use the tools"""
execute_code = ExecutePyCode()
@staticmethod
def _parse_recommend_tools(module: str, recommend_tools: list) -> List[Dict]:
@ -126,10 +121,10 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
return tool_catalog
async def _tool_recommendation(
self,
context: [List[Message]],
code_steps: str,
available_tools: list
self,
task: str,
code_steps: str,
available_tools: list
) -> list:
"""
Recommend tools for the specified task.
@ -142,86 +137,63 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
Returns:
list: recommended tools for the specified task
"""
system_prompt = TOOL_RECOMMENDATION_PROMPT.format(
prompt = TOOL_RECOMMENDATION_PROMPT.format(
current_task=task,
code_steps=code_steps,
available_tools=available_tools,
)
prompt = self.process_msg(context, system_prompt)
tool_config = create_func_config(SELECT_FUNCTION_TOOLS)
rsp = await self.llm.aask_code(prompt, **tool_config)
recommend_tools = rsp["recommend_tools"]
return recommend_tools
async def run(
self,
context: List[Message],
plan: Plan = None,
code_steps: str = "",
**kwargs,
self,
context: List[Message],
plan: Plan = None,
code_steps: str = "",
column_info: str = "",
) -> str:
task_type = plan.current_task.task_type
logger.info(f"task_type is: {task_type}")
available_tools = registry.get_all_schema_by_module(task_type)
# special_prompt = ML_SPECIFIC_PROMPT.get(task_type, "")
special_prompt = ML_SPECIFIC_PROMPT.get(task_type, "")
finished_tasks = plan.get_finished_tasks()
code_context = [task.code for task in finished_tasks]
code_context = [remove_comments(task.code) for task in finished_tasks]
code_context = "\n\n".join(code_context)
### add runtime info
result, success = await self.execute_code.run(code_context)
logger.info(result)
if len(available_tools) > 0:
available_tools = [
{k: tool[k] for k in ["name", "description"] if k in tool}
for tool in available_tools
]
final_code = code_context
recommend_tools = await self._tool_recommendation(context, code_steps, available_tools)
recommend_tools = await self._tool_recommendation(
plan.current_task.instruction,
code_steps,
available_tools
)
tool_catalog = self._parse_recommend_tools(task_type, recommend_tools)
logger.info(f"Recommended tools: \n{recommend_tools}")
module_name = ML_MODULE_MAP[task_type]
output_desc = TOOL_OUTPUT_DESC.get(task_type, "")
hist_info = f"Previous finished code is \n\n ```Python {final_code} ``` \n\n " \
f"Runtime result is {result} \n\n"
prompt = TOOL_USAGE_PROMPT.format(
goal=plan.current_task.instruction,
context=hist_info,
prompt = DATA_PROCESS_PROMPT.format(
user_requirement=plan.goal,
history_code=code_context,
current_task=plan.current_task.instruction,
column_info=column_info,
special_prompt=special_prompt,
code_steps=code_steps,
module_name=module_name,
output_desc=output_desc,
function_catalog=tool_catalog,
)
tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS)
rsp = await self.llm.aask_code(prompt, **tool_config)
logger.info(f"rsp is: {rsp}")
final_code = final_code + "\n\n" + rsp["code"]
return final_code
else:
hist_info = f"Previous finished code is \n\n ```Python {code_context} ``` \n\n " \
f"runtime result is {result} \n\n"
context.append(Message(content=self.REUSE_CODE_INSTRUCTION, role="user"))
context.append(Message(content=special_prompt, role="user"))
prompt = self.process_msg(context)
prompt = GENERATE_CODE_PROMPT.format(
goal=plan.current_task.instruction,
context=hist_info,
)
tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS)
logger.info(f"prompt is: {prompt}")
rsp = await self.llm.aask_code(prompt, **tool_config)
logger.info(f"rsp is: {rsp}")
return rsp["code"]
tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS)
rsp = await self.llm.aask_code(prompt, **tool_config)
return rsp['code']