Merge branch 'dev_ldy' into 'dev'

Dev ldy

See merge request agents/data_agents_opt!13
This commit is contained in:
林义章 2023-12-01 07:57:54 +00:00
commit 20a918bf39
5 changed files with 165 additions and 21 deletions

View file

@ -4,10 +4,10 @@
@Author : orange-crow
@File : write_code_v2.py
"""
import json
from typing import Dict, List, Union
from typing import Dict, List, Union, Tuple
from metagpt.actions import Action
from metagpt.logs import logger
from metagpt.prompts.ml_engineer import (
TOOL_RECOMMENDATION_PROMPT,
SELECT_FUNCTION_TOOLS,
@ -99,24 +99,31 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
"""Write code with help of local available tools. Choose tools first, then generate code to use the tools"""
@staticmethod
def _parse_recommend_tools(module: str, recommend_tools: list) -> str:
def _parse_recommend_tools(module: str, recommend_tools: list) -> Tuple[Dict, List[Dict]]:
"""
Converts recommended tools to a JSON string and checks tool availability in the registry.
Parses and validates a list of recommended tools, and retrieves their schema from registry.
Args:
module (str): The module name for querying tools in the registry.
recommend_tools (list): A list of lists of recommended tools for each step.
Returns:
str: A JSON string with available tools and their schemas for each step.
Tuple[Dict, List[Dict]]:
- valid_tools: A dict of lists of valid tools for each step.
- tool_catalog: A list of dicts of unique tool schemas.
"""
valid_tools = {}
available_tools = registry.get_all_by_module(module).keys()
for index, tools in enumerate(recommend_tools):
key = f"Step {index + 1}"
tools = [tool for tool in tools if tool in available_tools]
valid_tools[key] = registry.get_schemas(module, tools)
return json.dumps(valid_tools)
valid_tools[key] = tools
unique_tools = set()
for tools in valid_tools.values():
unique_tools.update(tools)
tool_catalog = registry.get_schemas(module, unique_tools)
return valid_tools, tool_catalog
async def _tool_recommendation(
self, task: str, data_desc: str, code_steps: str, available_tools: list
@ -165,7 +172,8 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
recommend_tools = await self._tool_recommendation(
task, task_guide, available_tools
)
recommend_tools = self._parse_recommend_tools(task_type, recommend_tools)
recommend_tools, tool_catalog = self._parse_recommend_tools(task_type, recommend_tools)
logger.info(f"Recommended tools for every steps: {recommend_tools}")
special_prompt = ML_SPECIFIC_PROMPT.get(task_type, "")
module_name = ML_MODULE_MAP[task_type]
@ -190,6 +198,7 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
module_name=module_name,
output_desc=output_desc,
available_tools=recommend_tools,
tool_catalog=tool_catalog,
)
tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS)
rsp = await self.llm.aask_code(prompt, **tool_config)

View file

@ -4,12 +4,14 @@
@Author : orange-crow
@File : plan.py
"""
from typing import List
from typing import List, Dict
import json
from metagpt.actions import Action
from metagpt.prompts.ml_engineer import ASSIGN_TASK_TYPE_PROMPT, ASSIGN_TASK_TYPE
from metagpt.schema import Message, Task
from metagpt.utils.common import CodeParser
from metagpt.utils.common import CodeParser, create_func_config
class WritePlan(Action):
PROMPT_TEMPLATE = """
@ -30,7 +32,30 @@ class WritePlan(Action):
]
```
"""
async def run(self, context: List[Message], max_tasks: int = 5) -> str:
async def assign_task_type(self, tasks: List[Dict]) -> str:
"""Assign task type to each task in tasks
Args:
tasks (List[Dict]): tasks to be assigned task type
Returns:
List[Dict]: tasks with task type assigned
"""
task_list = "\n".join(
[f"Task {task['task_id']}: {task['instruction']}" for task in tasks]
)
prompt = ASSIGN_TASK_TYPE_PROMPT.format(task_list=task_list)
tool_config = create_func_config(ASSIGN_TASK_TYPE)
rsp = await self.llm.aask_code(prompt, **tool_config)
task_type_list = rsp["task_type"]
for task, task_type in zip(tasks, task_type_list):
task["task_type"] = task_type
return json.dumps(tasks)
async def run(
self, context: List[Message], max_tasks: int = 5, use_tools: bool = False
) -> str:
prompt = (
self.PROMPT_TEMPLATE.replace("__context__", "\n".join([str(ct) for ct in context]))
# .replace("__current_plan__", current_plan)
@ -38,6 +63,8 @@ class WritePlan(Action):
)
rsp = await self._aask(prompt)
rsp = CodeParser.parse_code(block=None, text=rsp)
if use_tools:
rsp = await self.assign_task_type(json.loads(rsp))
return rsp
@staticmethod