mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
commit
ea71ab879a
31 changed files with 866 additions and 1086 deletions
|
|
@ -23,7 +23,7 @@ from metagpt.actions.write_prd import WritePRD
|
|||
from metagpt.actions.write_prd_review import WritePRDReview
|
||||
from metagpt.actions.write_test import WriteTest
|
||||
from metagpt.actions.mi.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.mi.write_analysis_code import WriteCodeWithoutTools, WriteCodeWithTools
|
||||
from metagpt.actions.mi.write_analysis_code import WriteCodeWithTools
|
||||
from metagpt.actions.mi.write_plan import WritePlan
|
||||
|
||||
|
||||
|
|
@ -46,7 +46,6 @@ class ActionType(Enum):
|
|||
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
|
||||
CONDUCT_RESEARCH = ConductResearch
|
||||
EXECUTE_NB_CODE = ExecuteNbCode
|
||||
WRITE_CODE_WITHOUT_TOOLS = WriteCodeWithoutTools
|
||||
WRITE_CODE_WITH_TOOLS = WriteCodeWithTools
|
||||
WRITE_PLAN = WritePlan
|
||||
|
||||
|
|
|
|||
|
|
@ -1,109 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from metagpt.actions.mi.write_analysis_code import BaseWriteAnalysisCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import create_func_call_config
|
||||
|
||||
DEBUG_REFLECTION_EXAMPLE = '''
|
||||
Example 1:
|
||||
[previous impl]:
|
||||
```python
|
||||
def add(a: int, b: int) -> int:
|
||||
"""
|
||||
Given integers a and b, return the total value of a and b.
|
||||
"""
|
||||
return a - b
|
||||
```
|
||||
|
||||
[runtime Error]:
|
||||
Tested passed:
|
||||
|
||||
Tests failed:
|
||||
assert add(1, 2) == 3 # output: -1
|
||||
assert add(1, 2) == 4 # output: -1
|
||||
|
||||
[reflection on previous impl]:
|
||||
The implementation failed the test cases where the input integers are 1 and 2. The issue arises because the code does not add the two integers together, but instead subtracts the second integer from the first. To fix this issue, we should change the operator from `-` to `+` in the return statement. This will ensure that the function returns the correct output for the given input.
|
||||
|
||||
[improved impl]:
|
||||
```python
|
||||
def add(a: int, b: int) -> int:
|
||||
"""
|
||||
Given integers a and b, return the total value of a and b.
|
||||
"""
|
||||
return a + b
|
||||
```
|
||||
'''
|
||||
|
||||
REFLECTION_PROMPT = """
|
||||
Here is an example for you.
|
||||
{debug_example}
|
||||
[context]
|
||||
{context}
|
||||
|
||||
[previous impl]
|
||||
{code}
|
||||
[runtime Error]
|
||||
{runtime_result}
|
||||
|
||||
Analysis the error step by step, provide me improve method and code. Remember to follow [context] requirement. Don't forget write code for steps behind the error step.
|
||||
[reflection on previous impl]:
|
||||
xxx
|
||||
"""
|
||||
|
||||
CODE_REFLECTION = {
|
||||
"name": "execute_reflection_code",
|
||||
"description": "Execute reflection code.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"reflection": {
|
||||
"type": "string",
|
||||
"description": "Reflection on previous impl.",
|
||||
},
|
||||
"improved_impl": {
|
||||
"type": "string",
|
||||
"description": "Refined code after reflection.",
|
||||
},
|
||||
},
|
||||
"required": ["reflection", "improved_impl"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class DebugCode(BaseWriteAnalysisCode):
|
||||
async def run(
|
||||
self,
|
||||
context: list[Message] = None,
|
||||
code: str = "",
|
||||
runtime_result: str = "",
|
||||
) -> str:
|
||||
"""
|
||||
Execute the debugging process based on the provided context, code, and runtime_result.
|
||||
|
||||
Args:
|
||||
context (list[Message]): A list of Message objects representing the context.
|
||||
code (str): The code to be debugged.
|
||||
runtime_result (str): The result of the code execution.
|
||||
|
||||
Returns:
|
||||
str: The improved implementation based on the debugging process.
|
||||
"""
|
||||
|
||||
info = []
|
||||
reflection_prompt = REFLECTION_PROMPT.format(
|
||||
debug_example=DEBUG_REFLECTION_EXAMPLE,
|
||||
context=context,
|
||||
code=code,
|
||||
runtime_result=runtime_result,
|
||||
)
|
||||
system_prompt = "You are an AI Python assistant. You will be given your previous implementation code of a task, runtime error results, and a hint to change the implementation appropriately. Write your full implementation "
|
||||
info.append(Message(role="system", content=system_prompt))
|
||||
info.append(Message(role="user", content=reflection_prompt))
|
||||
|
||||
tool_config = create_func_call_config(CODE_REFLECTION)
|
||||
reflection = await self.llm.aask_code(messages=info, **tool_config)
|
||||
logger.info(f"reflection is {reflection}")
|
||||
|
||||
return {"code": reflection["improved_impl"]}
|
||||
|
|
@ -1,70 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.mi.write_analysis_code import WriteCodeWithTools
|
||||
from metagpt.prompts.mi.ml_action import (
|
||||
ML_GENERATE_CODE_PROMPT,
|
||||
ML_TOOL_USAGE_PROMPT,
|
||||
PRINT_DATA_COLUMNS,
|
||||
UPDATE_DATA_COLUMNS,
|
||||
)
|
||||
from metagpt.prompts.mi.write_analysis_code import CODE_GENERATOR_WITH_TOOLS
|
||||
from metagpt.schema import Message, Plan
|
||||
from metagpt.utils.common import create_func_call_config, remove_comments
|
||||
|
||||
|
||||
class WriteCodeWithToolsML(WriteCodeWithTools):
|
||||
async def run(
|
||||
self,
|
||||
context: list[Message],
|
||||
plan: Plan = None,
|
||||
column_info: str = "",
|
||||
**kwargs,
|
||||
) -> Tuple[list[Message], str]:
|
||||
# prepare tool schemas and tool-type-specific instruction
|
||||
tool_schemas, tool_type_usage_prompt = await self._prepare_tools(plan=plan)
|
||||
|
||||
# ML-specific variables to be used in prompt
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
code_context = [remove_comments(task.code) for task in finished_tasks]
|
||||
code_context = "\n\n".join(code_context)
|
||||
|
||||
# prepare prompt depending on tool availability & LLM call
|
||||
if tool_schemas:
|
||||
prompt = ML_TOOL_USAGE_PROMPT.format(
|
||||
user_requirement=plan.goal,
|
||||
history_code=code_context,
|
||||
current_task=plan.current_task.instruction,
|
||||
column_info=column_info,
|
||||
tool_type_usage_prompt=tool_type_usage_prompt,
|
||||
tool_schemas=tool_schemas,
|
||||
)
|
||||
|
||||
else:
|
||||
prompt = ML_GENERATE_CODE_PROMPT.format(
|
||||
user_requirement=plan.goal,
|
||||
history_code=code_context,
|
||||
current_task=plan.current_task.instruction,
|
||||
column_info=column_info,
|
||||
tool_type_usage_prompt=tool_type_usage_prompt,
|
||||
)
|
||||
tool_config = create_func_call_config(CODE_GENERATOR_WITH_TOOLS)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
|
||||
# Extra output to be used for potential debugging
|
||||
context = [Message(content=prompt, role="user")]
|
||||
|
||||
return context, rsp
|
||||
|
||||
|
||||
class UpdateDataColumns(Action):
|
||||
async def run(self, plan: Plan = None) -> dict:
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
code_context = [remove_comments(task.code) for task in finished_tasks]
|
||||
code_context = "\n\n".join(code_context)
|
||||
prompt = UPDATE_DATA_COLUMNS.format(history_code=code_context)
|
||||
tool_config = create_func_call_config(PRINT_DATA_COLUMNS)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
return rsp
|
||||
|
|
@ -6,150 +6,72 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
import json
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.mi.write_analysis_code import (
|
||||
CODE_GENERATOR_WITH_TOOLS,
|
||||
SELECT_FUNCTION_TOOLS,
|
||||
TOOL_RECOMMENDATION_PROMPT,
|
||||
TOOL_USAGE_PROMPT,
|
||||
CHECK_DATA_PROMPT,
|
||||
DEBUG_REFLECTION_EXAMPLE,
|
||||
INTERPRETER_SYSTEM_MSG,
|
||||
REFLECTION_PROMPT,
|
||||
REFLECTION_SYSTEM_MSG,
|
||||
STRUCTUAL_PROMPT,
|
||||
)
|
||||
from metagpt.schema import Message, Plan, SystemMessage
|
||||
from metagpt.tools import TOOL_REGISTRY
|
||||
from metagpt.tools.tool_registry import validate_tool_names
|
||||
from metagpt.utils.common import create_func_call_config
|
||||
from metagpt.schema import Message, Plan
|
||||
from metagpt.utils.common import CodeParser, process_message, remove_comments
|
||||
|
||||
|
||||
class BaseWriteAnalysisCode(Action):
|
||||
DEFAULT_SYSTEM_MSG: str = """You are Code Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: The code for the next step depends on the code for the previous step. Must reuse variables in the lastest other code directly, dont creat it again, it is very import for you. Use !pip install in a standalone block to install missing packages.Usually the libraries you need are already installed.Dont check if packages already imported.**""" # prompt reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.4/interpreter/system_message.txt
|
||||
# REUSE_CODE_INSTRUCTION = """ATTENTION: DONT include codes from previous tasks in your current code block, include new codes only, DONT repeat codes!"""
|
||||
|
||||
def insert_system_message(self, context: list[Message], system_msg: str = None):
|
||||
system_msg = system_msg or self.DEFAULT_SYSTEM_MSG
|
||||
context.insert(0, SystemMessage(content=system_msg)) if context[0].role != "system" else None
|
||||
return context
|
||||
|
||||
async def run(self, context: list[Message], plan: Plan = None) -> dict:
|
||||
"""Run of a code writing action, used in data analysis or modeling
|
||||
|
||||
Args:
|
||||
context (list[Message]): Action output history, source action denoted by Message.cause_by
|
||||
plan (Plan, optional): Overall plan. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: code result in the format of {"code": "print('hello world')", "language": "python"}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WriteCodeWithoutTools(BaseWriteAnalysisCode):
|
||||
"""Ask LLM to generate codes purely by itself without local user-defined tools"""
|
||||
|
||||
async def run(self, context: list[Message], plan: Plan = None, system_msg: str = None, **kwargs) -> dict:
|
||||
messages = self.insert_system_message(context, system_msg)
|
||||
rsp = await self.llm.aask_code(messages, **kwargs)
|
||||
return rsp
|
||||
|
||||
|
||||
class WriteCodeWithTools(BaseWriteAnalysisCode):
|
||||
class WriteCodeWithTools(Action):
|
||||
"""Write code with help of local available tools. Choose tools first, then generate code to use the tools"""
|
||||
|
||||
# selected tools to choose from, listed by their names. An empty list means selection from all tools.
|
||||
selected_tools: list[str] = []
|
||||
|
||||
def _get_tools_by_type(self, tool_type: str) -> dict:
|
||||
"""
|
||||
Retreive tools by tool type from registry, but filtered by pre-selected tool list
|
||||
|
||||
Args:
|
||||
tool_type (str): Tool type to retrieve from the registry
|
||||
|
||||
Returns:
|
||||
dict: A dict of tool name to Tool object, representing available tools under the type
|
||||
"""
|
||||
candidate_tools = TOOL_REGISTRY.get_tools_by_type(tool_type)
|
||||
if self.selected_tools:
|
||||
candidate_tool_names = set(self.selected_tools) & candidate_tools.keys()
|
||||
candidate_tools = {tool_name: candidate_tools[tool_name] for tool_name in candidate_tool_names}
|
||||
return candidate_tools
|
||||
|
||||
async def _recommend_tool(
|
||||
self,
|
||||
task: str,
|
||||
available_tools: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Recommend tools for the specified task.
|
||||
|
||||
Args:
|
||||
task (str): the task to recommend tools for
|
||||
available_tools (dict): the available tools description
|
||||
|
||||
Returns:
|
||||
dict: schemas of recommended tools for the specified task
|
||||
"""
|
||||
prompt = TOOL_RECOMMENDATION_PROMPT.format(
|
||||
current_task=task,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
tool_config = create_func_call_config(SELECT_FUNCTION_TOOLS)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
recommend_tools = rsp["recommend_tools"]
|
||||
logger.info(f"Recommended tools: \n{recommend_tools}")
|
||||
|
||||
# Parses and validates the recommended tools, for LLM might hallucinate and recommend non-existing tools
|
||||
valid_tools = validate_tool_names(recommend_tools, return_tool_object=True)
|
||||
|
||||
tool_schemas = {tool.name: tool.schemas for tool in valid_tools}
|
||||
|
||||
return tool_schemas
|
||||
|
||||
async def _prepare_tools(self, plan: Plan) -> Tuple[dict, str]:
|
||||
"""Prepare tool schemas and usage instructions according to current task
|
||||
|
||||
Args:
|
||||
plan (Plan): The overall plan containing task information.
|
||||
|
||||
Returns:
|
||||
Tuple[dict, str]: A tool schemas ({tool_name: tool_schema_dict}) and a usage prompt for the type of tools selected
|
||||
"""
|
||||
# find tool type from task type through exact match, can extend to retrieval in the future
|
||||
tool_type = plan.current_task.task_type
|
||||
|
||||
# prepare tool-type-specific instruction
|
||||
tool_type_usage_prompt = (
|
||||
TOOL_REGISTRY.get_tool_type(tool_type).usage_prompt if TOOL_REGISTRY.has_tool_type(tool_type) else ""
|
||||
async def _debug_with_reflection(self, context: list[Message], working_memory: list[Message]):
|
||||
reflection_prompt = REFLECTION_PROMPT.format(
|
||||
debug_example=DEBUG_REFLECTION_EXAMPLE,
|
||||
context=context,
|
||||
previous_impl=working_memory,
|
||||
)
|
||||
|
||||
# prepare schemas of available tools
|
||||
tool_schemas = {}
|
||||
available_tools = self._get_tools_by_type(tool_type)
|
||||
if available_tools:
|
||||
available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()}
|
||||
tool_schemas = await self._recommend_tool(plan.current_task.instruction, available_tools)
|
||||
rsp = await self._aask(reflection_prompt, system_msgs=[REFLECTION_SYSTEM_MSG])
|
||||
reflection = json.loads(CodeParser.parse_code(block=None, text=rsp))
|
||||
|
||||
return tool_schemas, tool_type_usage_prompt
|
||||
return reflection["improved_impl"]
|
||||
|
||||
async def run(
|
||||
self,
|
||||
context: list[Message],
|
||||
plan: Plan,
|
||||
user_requirement: str,
|
||||
plan_status: str = "",
|
||||
tool_info: str = "",
|
||||
working_memory: list[Message] = None,
|
||||
use_reflection: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
# prepare tool schemas and tool-type-specific instruction
|
||||
tool_schemas, tool_type_usage_prompt = await self._prepare_tools(plan=plan)
|
||||
|
||||
# form a complete tool usage instruction and include it as a message in context
|
||||
tools_instruction = TOOL_USAGE_PROMPT.format(
|
||||
tool_schemas=tool_schemas, tool_type_usage_prompt=tool_type_usage_prompt
|
||||
structual_prompt = STRUCTUAL_PROMPT.format(
|
||||
user_requirement=user_requirement,
|
||||
plan_status=plan_status,
|
||||
tool_info=tool_info,
|
||||
)
|
||||
context.append(Message(content=tools_instruction, role="user"))
|
||||
|
||||
# prepare prompt & LLM call
|
||||
prompt = self.insert_system_message(context)
|
||||
tool_config = create_func_call_config(CODE_GENERATOR_WITH_TOOLS)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
working_memory = working_memory or []
|
||||
context = [Message(content=structual_prompt, role="user")] + working_memory
|
||||
context = process_message(context)
|
||||
|
||||
return rsp
|
||||
# LLM call
|
||||
if not use_reflection:
|
||||
rsp = await self.llm.aask(context, system_msgs=[INTERPRETER_SYSTEM_MSG], **kwargs)
|
||||
code = CodeParser.parse_code(block=None, text=rsp)
|
||||
|
||||
else:
|
||||
code = await self._debug_with_reflection(context=context, working_memory=working_memory)
|
||||
|
||||
return code
|
||||
|
||||
|
||||
class CheckData(Action):
|
||||
async def run(self, plan: Plan = None) -> dict:
|
||||
finished_tasks = plan.get_finished_tasks()
|
||||
code_written = [remove_comments(task.code) for task in finished_tasks]
|
||||
code_written = "\n\n".join(code_written)
|
||||
prompt = CHECK_DATA_PROMPT.format(code_written=code_written)
|
||||
rsp = await self._aask(prompt)
|
||||
code = CodeParser.parse_code(block=None, text=rsp)
|
||||
return code
|
||||
|
|
|
|||
|
|
@ -12,70 +12,42 @@ from typing import Tuple
|
|||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.mi.write_analysis_code import (
|
||||
ASSIGN_TASK_TYPE_CONFIG,
|
||||
ASSIGN_TASK_TYPE_PROMPT,
|
||||
)
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
from metagpt.tools import TOOL_REGISTRY
|
||||
from metagpt.utils.common import CodeParser, create_func_call_config
|
||||
from metagpt.strategy.task_type import TaskType
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
||||
class WritePlan(Action):
|
||||
PROMPT_TEMPLATE: str = """
|
||||
# Context:
|
||||
__context__
|
||||
{context}
|
||||
# Available Task Types:
|
||||
{task_type_desc}
|
||||
# Task:
|
||||
Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to __max_tasks__ tasks.
|
||||
Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to {max_tasks} tasks.
|
||||
If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. Give the whole plan unless instructed to modify only one task of the plan.
|
||||
If you encounter errors on the current task, revise and output the current single task only.
|
||||
Output a list of jsons following the format:
|
||||
```json
|
||||
[
|
||||
{
|
||||
{{
|
||||
"task_id": str = "unique identifier for a task in plan, can be an ordinal",
|
||||
"dependent_task_ids": list[str] = "ids of tasks prerequisite to this task",
|
||||
"instruction": "what you should do in this task, one short phrase or sentence",
|
||||
},
|
||||
"task_type": "type of this task, should be one of Available Task Types",
|
||||
}},
|
||||
...
|
||||
]
|
||||
```
|
||||
"""
|
||||
|
||||
async def assign_task_type(self, tasks: list[dict]) -> str:
|
||||
"""Assign task type to each task in tasks
|
||||
|
||||
Args:
|
||||
tasks (list[dict]): tasks to be assigned task type
|
||||
|
||||
Returns:
|
||||
str: tasks with task type assigned in a json string
|
||||
"""
|
||||
task_info = "\n".join([f"Task {task['task_id']}: {task['instruction']}" for task in tasks])
|
||||
task_type_desc = "\n".join(
|
||||
[f"- **{tool_type.name}**: {tool_type.desc}" for tool_type in TOOL_REGISTRY.get_tool_types().values()]
|
||||
) # task type are binded with tool type now, should be improved in the future
|
||||
prompt = ASSIGN_TASK_TYPE_PROMPT.format(
|
||||
task_info=task_info, task_type_desc=task_type_desc
|
||||
) # task types are set to be the same as tool types, for now
|
||||
tool_config = create_func_call_config(ASSIGN_TASK_TYPE_CONFIG)
|
||||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
task_type_list = rsp["task_type"]
|
||||
logger.info(f"assigned task types: {task_type_list}")
|
||||
for task, task_type in zip(tasks, task_type_list):
|
||||
task["task_type"] = task_type
|
||||
return json.dumps(tasks)
|
||||
|
||||
async def run(self, context: list[Message], max_tasks: int = 5, use_tools: bool = False) -> str:
|
||||
prompt = (
|
||||
self.PROMPT_TEMPLATE.replace("__context__", "\n".join([str(ct) for ct in context]))
|
||||
# .replace("__current_plan__", current_plan)
|
||||
.replace("__max_tasks__", str(max_tasks))
|
||||
task_type_desc = "\n".join([f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType])
|
||||
prompt = self.PROMPT_TEMPLATE.format(
|
||||
context="\n".join([str(ct) for ct in context]), max_tasks=max_tasks, task_type_desc=task_type_desc
|
||||
)
|
||||
rsp = await self._aask(prompt)
|
||||
rsp = CodeParser.parse_code(block=None, text=rsp)
|
||||
if use_tools:
|
||||
rsp = await self.assign_task_type(json.loads(rsp))
|
||||
return rsp
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,128 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Time : 2023/11/24 15:43
|
||||
# @Author : lidanyang
|
||||
# @File : ml_action
|
||||
# @Desc :
|
||||
UPDATE_DATA_COLUMNS = """
|
||||
# Background
|
||||
Keep dataset column information updated before model train.
|
||||
## Done Tasks
|
||||
```python
|
||||
{history_code}
|
||||
```end
|
||||
|
||||
# Task
|
||||
Update and print the dataset's column information only if the train or test data has changed. Use the following code:
|
||||
```python
|
||||
from metagpt.tools.libs.data_preprocess import get_column_info
|
||||
|
||||
column_info = get_column_info(df)
|
||||
print("column_info")
|
||||
print(column_info)
|
||||
```end
|
||||
|
||||
# Constraints:
|
||||
- Use the DataFrame variable from 'Done Tasks' in place of df.
|
||||
- Import `get_column_info` only if it's not already imported.
|
||||
"""
|
||||
|
||||
PRINT_DATA_COLUMNS = {
|
||||
"name": "print_column_info",
|
||||
"description": "Print the latest column information after 'Done Tasks' code if first read or data changed.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The code to be added to a new cell in jupyter.",
|
||||
},
|
||||
},
|
||||
"required": ["code"],
|
||||
},
|
||||
}
|
||||
|
||||
ML_COMMON_PROMPT = """
|
||||
# Background
|
||||
As a data scientist, you need to help user to achieve their goal [{user_requirement}] step-by-step in an continuous Jupyter notebook.
|
||||
|
||||
## Done Tasks
|
||||
```python
|
||||
{history_code}
|
||||
```end
|
||||
|
||||
## Current Task
|
||||
{current_task}
|
||||
|
||||
# Latest Data Info
|
||||
Latest data info after previous tasks:
|
||||
{column_info}
|
||||
|
||||
# Task
|
||||
Write complete code for 'Current Task'. And avoid duplicating code from 'Done Tasks', such as repeated import of packages, reading data, etc.
|
||||
Specifically, {tool_type_usage_prompt}
|
||||
"""
|
||||
|
||||
USE_NO_TOOLS_EXAMPLE = """
|
||||
# Output Example:
|
||||
when current task is "train a lightgbm model on training data", the code can be like:
|
||||
```python
|
||||
# Step 1: check data type and convert to numeric
|
||||
obj_cols = train.select_dtypes(include='object').columns.tolist()
|
||||
|
||||
for col in obj_cols:
|
||||
encoder = LabelEncoder()
|
||||
train[col] = encoder.fit_transform(train[col].unique().tolist() + ['unknown'])
|
||||
test[col] = test[col].apply(lambda x: x if x in encoder.classes_ else 'unknown')
|
||||
test[col] = encoder.transform(test[col])
|
||||
|
||||
# Step 2: train lightgbm model
|
||||
model = LGBMClassifier()
|
||||
model.fit(train, y_train)
|
||||
```end
|
||||
|
||||
# Constraints:
|
||||
- Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed.
|
||||
"""
|
||||
|
||||
USE_TOOLS_EXAMPLE = """
|
||||
# Capabilities
|
||||
- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python Class.
|
||||
- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
|
||||
|
||||
# Available Tools:
|
||||
Each Class tool is described in JSON format. When you call a tool, import the tool from its path first.
|
||||
{tool_schemas}
|
||||
|
||||
# Output Example:
|
||||
when current task is "do data preprocess, like fill missing value, handle outliers, etc.", the code can be like:
|
||||
```python
|
||||
# Step 1: fill missing value
|
||||
# Tools used: ['FillMissingValue']
|
||||
from metagpt.tools.libs.data_preprocess import FillMissingValue
|
||||
|
||||
train_processed = train.copy()
|
||||
test_processed = test.copy()
|
||||
num_cols = train_processed.select_dtypes(include='number').columns.tolist()
|
||||
if 'label' in num_cols:
|
||||
num_cols.remove('label')
|
||||
fill_missing_value = FillMissingValue(features=num_cols, strategy='mean')
|
||||
fill_missing_value.fit(train_processed)
|
||||
train_processed = fill_missing_value.transform(train_processed)
|
||||
test_processed = fill_missing_value.transform(test_processed)
|
||||
|
||||
# Step 2: handle outliers
|
||||
for col in num_cols:
|
||||
low, high = train_processed[col].quantile([0.01, 0.99])
|
||||
train_processed[col] = train_processed[col].clip(low, high)
|
||||
test_processed[col] = test_processed[col].clip(low, high)
|
||||
```end
|
||||
|
||||
# Constraints:
|
||||
- Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed.
|
||||
- Always prioritize using pre-defined tools for the same functionality.
|
||||
- Always copy the DataFrame before processing it and use the copy to process.
|
||||
"""
|
||||
|
||||
ML_GENERATE_CODE_PROMPT = ML_COMMON_PROMPT + USE_NO_TOOLS_EXAMPLE
|
||||
ML_TOOL_USAGE_PROMPT = ML_COMMON_PROMPT + USE_TOOLS_EXAMPLE
|
||||
|
|
@ -1,93 +1,112 @@
|
|||
ASSIGN_TASK_TYPE_PROMPT = """
|
||||
Please assign a task type to each task in the list below from the given categories:
|
||||
{task_info}
|
||||
INTERPRETER_SYSTEM_MSG = """As a data scientist, you need to help user to achieve their goal step by step in a continuous Jupyter notebook. Since it is a notebook environment, don't use asyncio.run. Instead, use await if you need to call an async function."""
|
||||
|
||||
## All Task Type:
|
||||
{task_type_desc}
|
||||
STRUCTUAL_PROMPT = """
|
||||
# User Requirement
|
||||
{user_requirement}
|
||||
|
||||
# Plan Status
|
||||
{plan_status}
|
||||
|
||||
# Tool Info
|
||||
{tool_info}
|
||||
|
||||
# Constraints
|
||||
- Take on Current Task if it is in Plan Status, otherwise, tackle User Requirement directly.
|
||||
- Ensure the output new code is executable in the same Jupyter notebook as the previous executed code.
|
||||
- Always prioritize using pre-defined tools for the same functionality.
|
||||
|
||||
# Output
|
||||
While some concise thoughts are helpful, code is absolutely required. Always output one and only one code block in your response. Output code in the following format:
|
||||
```python
|
||||
your code
|
||||
```
|
||||
"""
|
||||
|
||||
ASSIGN_TASK_TYPE_CONFIG = {
|
||||
"name": "assign_task_type",
|
||||
"description": "Assign task type to each task by order.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task_type": {
|
||||
"type": "array",
|
||||
"description": "List of task type. The length should as long as task list",
|
||||
"items": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["task_type"],
|
||||
},
|
||||
}
|
||||
REFLECTION_SYSTEM_MSG = """You are an AI Python assistant. You will be given your previous implementation code of a task, runtime error results, and a hint to change the implementation appropriately. Write your full implementation."""
|
||||
|
||||
TOOL_RECOMMENDATION_PROMPT = """
|
||||
## User Requirement:
|
||||
{current_task}
|
||||
DEBUG_REFLECTION_EXAMPLE = '''
|
||||
[previous impl]:
|
||||
assistant:
|
||||
```python
|
||||
def add(a: int, b: int) -> int:
|
||||
"""
|
||||
Given integers a and b, return the total value of a and b.
|
||||
"""
|
||||
return a - b
|
||||
```
|
||||
|
||||
## Task
|
||||
Recommend up to five tools from 'Available Tools' that can help solve the 'User Requirement'.
|
||||
user:
|
||||
Tests failed:
|
||||
assert add(1, 2) == 3 # output: -1
|
||||
assert add(1, 2) == 4 # output: -1
|
||||
|
||||
## Available Tools:
|
||||
{available_tools}
|
||||
[reflection on previous impl]:
|
||||
The implementation failed the test cases where the input integers are 1 and 2. The issue arises because the code does not add the two integers together, but instead subtracts the second integer from the first. To fix this issue, we should change the operator from `-` to `+` in the return statement. This will ensure that the function returns the correct output for the given input.
|
||||
|
||||
## Tool Selection and Instructions:
|
||||
- Select tools most relevant to completing the 'User Requirement'.
|
||||
- If you believe that no tools are suitable, indicate with an empty list.
|
||||
- Only list the names of the tools, not the full schema of each tool.
|
||||
- Ensure selected tools are listed in 'Available Tools'.
|
||||
[improved impl]:
|
||||
def add(a: int, b: int) -> int:
|
||||
"""
|
||||
Given integers a and b, return the total value of a and b.
|
||||
"""
|
||||
return a + b
|
||||
'''
|
||||
|
||||
REFLECTION_PROMPT = """
|
||||
[example]
|
||||
Here is an example of debugging with reflection.
|
||||
{debug_example}
|
||||
[/example]
|
||||
|
||||
[context]
|
||||
{context}
|
||||
|
||||
[previous impl]:
|
||||
{previous_impl}
|
||||
|
||||
[instruction]
|
||||
Analyze your previous code and error in [context] step by step, provide me with improved method and code. Remember to follow [context] requirement. Don't forget to write code for steps behind the error step.
|
||||
Output a json following the format:
|
||||
```json
|
||||
{{
|
||||
"reflection": str = "Reflection on previous implementation",
|
||||
"improved_impl": str = "Refined code after reflection.",
|
||||
}}
|
||||
```
|
||||
"""
|
||||
|
||||
SELECT_FUNCTION_TOOLS = {
|
||||
"name": "select_function_tools",
|
||||
"description": "For current task, select suitable tools for it.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"recommend_tools": {
|
||||
"type": "array",
|
||||
"description": "List of tool names. Empty list if no tool is suitable.",
|
||||
"items": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["recommend_tools"],
|
||||
},
|
||||
}
|
||||
CHECK_DATA_PROMPT = """
|
||||
# Background
|
||||
Check latest data info to guide subsequent tasks.
|
||||
|
||||
CODE_GENERATOR_WITH_TOOLS = {
|
||||
"name": "add_subtask_code",
|
||||
"description": "Add new code cell of current task to the end of an active Jupyter notebook.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "The code to be added to a new cell in jupyter.",
|
||||
},
|
||||
},
|
||||
"required": ["code"],
|
||||
},
|
||||
}
|
||||
## Finished Tasks
|
||||
```python
|
||||
{code_written}
|
||||
```end
|
||||
|
||||
TOOL_USAGE_PROMPT = """
|
||||
# Instruction
|
||||
Write complete code for 'Current Task'. And avoid duplicating code from finished tasks, such as repeated import of packages, reading data, etc.
|
||||
Specifically, {tool_type_usage_prompt}
|
||||
# Task
|
||||
Check code in finished tasks, print key variables to guide your following actions.
|
||||
Specifically, if it is a data analysis or machine learning task, print the the latest column information using the following code, with DataFrame variable from 'Finished Tasks' in place of df:
|
||||
```python
|
||||
from metagpt.tools.libs.data_preprocess import get_column_info
|
||||
|
||||
# Capabilities
|
||||
- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python Class.
|
||||
- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
|
||||
|
||||
# Available Tools (can be empty):
|
||||
Each Class tool is described in JSON format. When you call a tool, import the tool first.
|
||||
{tool_schemas}
|
||||
column_info = get_column_info(df)
|
||||
print("column_info")
|
||||
print(column_info)
|
||||
```end
|
||||
Otherwise, print out any key variables you see fit. Return an empty string if you think there is no important data to check.
|
||||
|
||||
# Constraints:
|
||||
- Ensure the output new code is executable in the same Jupyter notebook with previous tasks code have been executed.
|
||||
- Always prioritize using pre-defined tools for the same functionality.
|
||||
- Your code is to be added to a new cell in jupyter.
|
||||
|
||||
# Instruction
|
||||
Output code following the format:
|
||||
```python
|
||||
your code
|
||||
```
|
||||
"""
|
||||
|
||||
DATA_INFO = """
|
||||
# Latest Data Info
|
||||
Latest data info after previous tasks:
|
||||
{info}
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
# Prompt for using tools of "eda" type
|
||||
# Prompt for taking on "eda" tasks
|
||||
EDA_PROMPT = """
|
||||
The current task is about exploratory data analysis, please note the following:
|
||||
- Distinguish column types with `select_dtypes` for tailored analysis and visualization, such as correlation.
|
||||
- Remember to `import numpy as np` before using Numpy functions.
|
||||
"""
|
||||
|
||||
# Prompt for using tools of "data_preprocess" type
|
||||
# Prompt for taking on "data_preprocess" tasks
|
||||
DATA_PREPROCESS_PROMPT = """
|
||||
The current task is about data preprocessing, please note the following:
|
||||
- Monitor data types per column, applying appropriate methods.
|
||||
|
|
@ -15,9 +15,10 @@ The current task is about data preprocessing, please note the following:
|
|||
- Prefer alternatives to one-hot encoding for categorical data.
|
||||
- Only encode or scale necessary columns to allow for potential feature-specific engineering tasks (like time_extract, binning, extraction, etc.) later.
|
||||
- Each step do data preprocessing to train, must do same for test separately at the same time.
|
||||
- Always copy the DataFrame before processing it and use the copy to process.
|
||||
"""
|
||||
|
||||
# Prompt for using tools of "feature_engineering" type
|
||||
# Prompt for taking on "feature_engineering" tasks
|
||||
FEATURE_ENGINEERING_PROMPT = """
|
||||
The current task is about feature engineering. when performing it, please adhere to the following principles:
|
||||
- Generate as diverse features as possible to improve the model's performance step-by-step.
|
||||
|
|
@ -27,9 +28,10 @@ The current task is about feature engineering. when performing it, please adhere
|
|||
- Each feature engineering operation performed on the train set must also applies to the test separately at the same time.
|
||||
- Avoid using the label column to create features, except for cat encoding.
|
||||
- Use the data from previous task result if exist, do not mock or reload data yourself.
|
||||
- Always copy the DataFrame before processing it and use the copy to process.
|
||||
"""
|
||||
|
||||
# Prompt for using tools of "model_train" type
|
||||
# Prompt for taking on "model_train" tasks
|
||||
MODEL_TRAIN_PROMPT = """
|
||||
The current task is about training a model, please ensure high performance:
|
||||
- Keep in mind that your user prioritizes results and is highly focused on model performance. So, when needed, feel free to use models of any complexity to improve effectiveness, such as XGBoost, CatBoost, etc.
|
||||
|
|
@ -38,14 +40,14 @@ The current task is about training a model, please ensure high performance:
|
|||
- Set suitable hyperparameters for the model, make metrics as high as possible.
|
||||
"""
|
||||
|
||||
# Prompt for using tools of "model_evaluate" type
|
||||
# Prompt for taking on "model_evaluate" tasks
|
||||
MODEL_EVALUATE_PROMPT = """
|
||||
The current task is about evaluating a model, please note the following:
|
||||
- Ensure that the evaluated data is same processed as the training data. If not, remember use object in 'Done Tasks' to transform the data.
|
||||
- Use trained model from previous task result directly, do not mock or reload model yourself.
|
||||
"""
|
||||
|
||||
# Prompt for using tools of "vision" type
|
||||
# Prompt for taking on "image2webpage" tasks
|
||||
IMAGE2WEBPAGE_PROMPT = """
|
||||
The current task is about converting image into webpage code. please note the following:
|
||||
- Single-Step Code Generation: Execute the entire code generation process in a single step, encompassing HTML, CSS, and JavaScript. Avoid fragmenting the code generation into multiple separate steps to maintain consistency and simplify the development workflow.
|
||||
|
|
@ -69,7 +69,7 @@ class BaseLLM(ABC):
|
|||
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
msg: Union[str, list[dict[str, str]]],
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
|
|
@ -84,7 +84,10 @@ class BaseLLM(ABC):
|
|||
message = []
|
||||
if format_msgs:
|
||||
message.extend(format_msgs)
|
||||
message.append(self._user_msg(msg, images=images))
|
||||
if isinstance(msg, str):
|
||||
message.append(self._user_msg(msg, images=images))
|
||||
else:
|
||||
message.extend(msg)
|
||||
logger.debug(message)
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
return rsp
|
||||
|
|
@ -102,8 +105,7 @@ class BaseLLM(ABC):
|
|||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3) -> dict:
|
||||
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], timeout=3, **kwargs) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ GENERAL_FUNCTION_SCHEMA = {
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
# tool_choice value for general_function_schema
|
||||
# https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
GENERAL_TOOL_CHOICE = {"type": "function", "function": {"name": "execute"}}
|
||||
|
|
|
|||
|
|
@ -28,8 +28,7 @@ from metagpt.logs import log_llm_stream, logger
|
|||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import CodeParser, decode_image
|
||||
from metagpt.utils.common import CodeParser, decode_image, process_message
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
|
|
@ -145,44 +144,16 @@ class OpenAILLM(BaseLLM):
|
|||
rsp = await self._achat_completion(messages, timeout=timeout)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
|
||||
"""Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create"""
|
||||
if "tools" not in kwargs:
|
||||
configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]}
|
||||
kwargs.update(configs)
|
||||
|
||||
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
|
||||
|
||||
def _process_message(self, messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
# 全部转成list
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
# 转成list[dict]
|
||||
processed_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, str):
|
||||
processed_messages.append({"role": "user", "content": msg})
|
||||
elif isinstance(msg, dict):
|
||||
assert set(msg.keys()) == set(["role", "content"])
|
||||
processed_messages.append(msg)
|
||||
elif isinstance(msg, Message):
|
||||
processed_messages.append(msg.to_dict())
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!"
|
||||
)
|
||||
return processed_messages
|
||||
|
||||
async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
|
||||
messages = self._process_message(messages)
|
||||
kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
|
||||
async def _achat_completion_function(
|
||||
self, messages: list[dict], timeout: int = 3, **chat_configs
|
||||
) -> ChatCompletion:
|
||||
messages = process_message(messages)
|
||||
kwargs = self._cons_kwargs(messages=messages, timeout=timeout, **chat_configs)
|
||||
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
async def aask_code(self, messages: list[dict], **kwargs) -> dict:
|
||||
async def aask_code(self, messages: list[dict], timeout: int = 3, **kwargs) -> dict:
|
||||
"""Use function of tools to ask a code.
|
||||
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
|
|
@ -192,12 +163,15 @@ class OpenAILLM(BaseLLM):
|
|||
>>> rsp = await llm.aask_code(msg)
|
||||
# -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
if "tools" not in kwargs:
|
||||
configs = {"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}]}
|
||||
kwargs.update(configs)
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
def _parse_arguments(self, arguments: str) -> dict:
|
||||
"""parse arguments in openai function call"""
|
||||
if "langugae" not in arguments and "code" not in arguments:
|
||||
if "language" not in arguments and "code" not in arguments:
|
||||
logger.warning(f"Not found `code`, `language`, We assume it is pure code:\n {arguments}\n. ")
|
||||
return {"language": "python", "code": arguments}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,52 +1,95 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pydantic import Field
|
||||
import json
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.actions.mi.ask_review import ReviewConst
|
||||
from metagpt.actions.mi.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.mi.write_analysis_code import (
|
||||
WriteCodeWithoutTools,
|
||||
WriteCodeWithTools,
|
||||
)
|
||||
from metagpt.actions.mi.write_analysis_code import CheckData, WriteCodeWithTools
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.mi.write_analysis_code import DATA_INFO
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message, Task, TaskResult
|
||||
from metagpt.strategy.task_type import TaskType
|
||||
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
REACT_THINK_PROMPT = """
|
||||
# User Requirement
|
||||
{user_requirement}
|
||||
# Context
|
||||
{context}
|
||||
|
||||
Output a json following the format:
|
||||
```json
|
||||
{{
|
||||
"thoughts": str = "Thoughts on current situation, reflect on how you should proceed to fulfill the user requirement",
|
||||
"state": bool = "Decide whether you need to take more actions to complete the user requirement. Return true if you think so. Return false if you think the requirement has been completely fulfilled."
|
||||
}}
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class Interpreter(Role):
|
||||
name: str = "Ivy"
|
||||
profile: str = "Interpreter"
|
||||
auto_run: bool = True
|
||||
use_tools: bool = False
|
||||
use_plan: bool = True
|
||||
use_reflection: bool = False
|
||||
execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True)
|
||||
tools: list[str] = []
|
||||
tools: Union[str, list[str]] = []
|
||||
tool_recommender: ToolRecommender = None
|
||||
react_mode: Literal["plan_and_act", "react"] = "plan_and_act"
|
||||
max_react_loop: int = 10 # used for react mode
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auto_run=True,
|
||||
use_tools=False,
|
||||
tools=[],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(auto_run=auto_run, use_tools=use_tools, tools=tools, **kwargs)
|
||||
self._set_react_mode(react_mode="plan_and_act", auto_run=auto_run, use_tools=use_tools)
|
||||
if use_tools and tools:
|
||||
from metagpt.tools.tool_registry import (
|
||||
validate_tool_names, # import upon use
|
||||
)
|
||||
|
||||
self.tools = validate_tool_names(tools)
|
||||
logger.info(f"will only use {self.tools} as tools")
|
||||
@model_validator(mode="after")
|
||||
def set_plan_and_tool(self) -> "Interpreter":
|
||||
self._set_react_mode(react_mode=self.react_mode, max_react_loop=self.max_react_loop, auto_run=self.auto_run)
|
||||
self.use_plan = (
|
||||
self.react_mode == "plan_and_act"
|
||||
) # create a flag for convenience, overwrite any passed-in value
|
||||
if self.tools:
|
||||
self.tool_recommender = BM25ToolRecommender(tools=self.tools)
|
||||
self.set_actions([WriteCodeWithTools])
|
||||
return self
|
||||
|
||||
@property
|
||||
def working_memory(self):
|
||||
return self.rc.working_memory
|
||||
|
||||
async def _think(self) -> bool:
|
||||
"""Useful in 'react' mode. Use LLM to decide whether and what to do next."""
|
||||
user_requirement = self.get_memories()[0].content
|
||||
context = self.working_memory.get()
|
||||
|
||||
if not context:
|
||||
# just started the run, we need action certainly
|
||||
self.working_memory.add(self.get_memories()[0]) # add user requirement to working memory
|
||||
self._set_state(0)
|
||||
return True
|
||||
|
||||
prompt = REACT_THINK_PROMPT.format(user_requirement=user_requirement, context=context)
|
||||
rsp = await self.llm.aask(prompt)
|
||||
rsp_dict = json.loads(CodeParser.parse_code(block=None, text=rsp))
|
||||
self.working_memory.add(Message(content=rsp_dict["thoughts"], role="assistant"))
|
||||
need_action = rsp_dict["state"]
|
||||
self._set_state(0) if need_action else self._set_state(-1)
|
||||
|
||||
return need_action
|
||||
|
||||
async def _act(self) -> Message:
|
||||
"""Useful in 'react' mode. Return a Message conforming to Role._act interface."""
|
||||
code, _, _ = await self._write_and_exec_code()
|
||||
return Message(content=code, role="assistant", cause_by=WriteCodeWithTools)
|
||||
|
||||
async def _plan_and_act(self) -> Message:
|
||||
await super()._plan_and_act()
|
||||
await self.execute_code.terminate()
|
||||
|
||||
async def _act_on_task(self, current_task: Task) -> TaskResult:
|
||||
"""Useful in 'plan_and_act' mode. Wrap the output in a TaskResult for review and confirmation."""
|
||||
code, result, is_success = await self._write_and_exec_code()
|
||||
task_result = TaskResult(code=code, result=result, is_success=is_success)
|
||||
return task_result
|
||||
|
|
@ -55,14 +98,30 @@ class Interpreter(Role):
|
|||
counter = 0
|
||||
success = False
|
||||
|
||||
# plan info
|
||||
plan_status = self.planner.get_plan_status() if self.use_plan else ""
|
||||
|
||||
# tool info
|
||||
if self.tools:
|
||||
context = (
|
||||
self.working_memory.get()[-1].content if self.working_memory.get() else ""
|
||||
) # thoughts from _think stage in 'react' mode
|
||||
plan = self.planner.plan if self.use_plan else None
|
||||
tool_info = await self.tool_recommender.get_recommended_tool_info(context=context, plan=plan)
|
||||
else:
|
||||
tool_info = ""
|
||||
|
||||
# data info
|
||||
await self._check_data()
|
||||
|
||||
while not success and counter < max_retry:
|
||||
### write code ###
|
||||
code, cause_by = await self._write_code()
|
||||
code, cause_by = await self._write_code(counter, plan_status, tool_info)
|
||||
|
||||
self.working_memory.add(Message(content=code["code"], role="assistant", cause_by=cause_by))
|
||||
self.working_memory.add(Message(content=code, role="assistant", cause_by=cause_by))
|
||||
|
||||
### execute code ###
|
||||
result, success = await self.execute_code.run(**code)
|
||||
result, success = await self.execute_code.run(code)
|
||||
print(result)
|
||||
|
||||
self.working_memory.add(Message(content=result, role="user", cause_by=ExecuteNbCode))
|
||||
|
|
@ -76,14 +135,49 @@ class Interpreter(Role):
|
|||
if ReviewConst.CHANGE_WORDS[0] in review:
|
||||
counter = 0 # redo the task again with help of human suggestions
|
||||
|
||||
return code["code"], result, success
|
||||
return code, result, success
|
||||
|
||||
async def _write_code(self):
|
||||
todo = WriteCodeWithoutTools() if not self.use_tools else WriteCodeWithTools(selected_tools=self.tools)
|
||||
async def _write_code(
|
||||
self,
|
||||
counter,
|
||||
plan_status="",
|
||||
tool_info="",
|
||||
):
|
||||
todo = WriteCodeWithTools()
|
||||
logger.info(f"ready to {todo.name}")
|
||||
use_reflection = counter > 0 and self.use_reflection
|
||||
|
||||
context = self.planner.get_useful_memories()
|
||||
# print(*context, sep="\n***\n")
|
||||
code = await todo.run(context=context, plan=self.planner.plan, temperature=0.0)
|
||||
user_requirement = self.get_memories()[0].content
|
||||
|
||||
code = await todo.run(
|
||||
user_requirement=user_requirement,
|
||||
plan_status=plan_status,
|
||||
tool_info=tool_info,
|
||||
working_memory=self.working_memory.get(),
|
||||
use_reflection=use_reflection,
|
||||
)
|
||||
|
||||
return code, todo
|
||||
|
||||
async def _check_data(self):
|
||||
if (
|
||||
not self.use_plan
|
||||
or not self.planner.plan.get_finished_tasks()
|
||||
or self.planner.plan.current_task.task_type
|
||||
not in [
|
||||
TaskType.DATA_PREPROCESS.type_name,
|
||||
TaskType.FEATURE_ENGINEERING.type_name,
|
||||
TaskType.MODEL_TRAIN.type_name,
|
||||
]
|
||||
):
|
||||
return
|
||||
logger.info("Check updated data")
|
||||
code = await CheckData().run(self.planner.plan)
|
||||
if not code.strip():
|
||||
return
|
||||
success = False
|
||||
result, success = await self.execute_code.run(code)
|
||||
if success:
|
||||
print(result)
|
||||
data_info = DATA_INFO.format(info=result)
|
||||
self.working_memory.add(Message(content=data_info, role="user", cause_by=CheckData))
|
||||
|
|
|
|||
|
|
@ -1,64 +0,0 @@
|
|||
from metagpt.actions.mi.debug_code import DebugCode
|
||||
from metagpt.actions.mi.execute_nb_code import ExecuteNbCode
|
||||
from metagpt.actions.mi.ml_action import UpdateDataColumns, WriteCodeWithToolsML
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.mi.interpreter import Interpreter
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
class MLEngineer(Interpreter):
|
||||
name: str = "Mark"
|
||||
profile: str = "MLEngineer"
|
||||
debug_context: list = []
|
||||
latest_code: str = ""
|
||||
|
||||
async def _write_code(self):
|
||||
if not self.use_tools:
|
||||
return await super()._write_code()
|
||||
|
||||
# In a trial and errors settings, check whether this is our first attempt to tackle the task. If there is no code execution before, then it is.
|
||||
is_first_trial = any_to_str(ExecuteNbCode) not in [msg.cause_by for msg in self.working_memory.get()]
|
||||
|
||||
if is_first_trial:
|
||||
# For the first trial, write task code from scratch
|
||||
column_info = await self._update_data_columns()
|
||||
|
||||
logger.info("Write code with tools")
|
||||
tool_context, code = await WriteCodeWithToolsML(selected_tools=self.tools).run(
|
||||
context=[], # context assembled inside the Action
|
||||
plan=self.planner.plan,
|
||||
column_info=column_info,
|
||||
)
|
||||
self.debug_context = tool_context
|
||||
cause_by = WriteCodeWithToolsML
|
||||
|
||||
else:
|
||||
# Previous trials resulted in error, debug and rewrite the code
|
||||
logger.warning("We got a bug, now start to debug...")
|
||||
code = await DebugCode().run(
|
||||
code=self.latest_code,
|
||||
runtime_result=self.working_memory.get(),
|
||||
context=self.debug_context,
|
||||
)
|
||||
logger.info(f"new code \n{code}")
|
||||
cause_by = DebugCode
|
||||
|
||||
self.latest_code = code["code"]
|
||||
|
||||
return code, cause_by
|
||||
|
||||
async def _update_data_columns(self):
|
||||
current_task = self.planner.plan.current_task
|
||||
if current_task.task_type not in [
|
||||
ToolType.DATA_PREPROCESS.type_name,
|
||||
ToolType.FEATURE_ENGINEERING.type_name,
|
||||
ToolType.MODEL_TRAIN.type_name,
|
||||
]:
|
||||
return ""
|
||||
logger.info("Check columns in updated data")
|
||||
code = await UpdateDataColumns().run(self.planner.plan)
|
||||
success = False
|
||||
result, success = await self.execute_code.run(**code)
|
||||
print(result)
|
||||
return result if success else ""
|
||||
|
|
@ -283,7 +283,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
self.actions.append(i)
|
||||
self.states.append(f"{len(self.actions)}. {action}")
|
||||
|
||||
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True, use_tools: bool = False):
|
||||
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True):
|
||||
"""Set strategy of the Role reacting to observed Message. Variation lies in how
|
||||
this Role elects action to perform during the _think stage, especially if it is capable of multiple Actions.
|
||||
|
||||
|
|
@ -304,9 +304,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
if react_mode == RoleReactMode.REACT:
|
||||
self.rc.max_react_loop = max_react_loop
|
||||
elif react_mode == RoleReactMode.PLAN_AND_ACT:
|
||||
self.planner = Planner(
|
||||
goal=self.goal, working_memory=self.rc.working_memory, auto_run=auto_run, use_tools=use_tools
|
||||
)
|
||||
self.planner = Planner(goal=self.goal, working_memory=self.rc.working_memory, auto_run=auto_run)
|
||||
|
||||
def _watch(self, actions: Iterable[Type[Action]] | Iterable[Action]):
|
||||
"""Watch Actions of interest. Role will select Messages caused by these Actions from its personal message
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from metagpt.actions.mi.write_plan import (
|
|||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.schema import Message, Plan, Task, TaskResult
|
||||
from metagpt.strategy.task_type import TaskType
|
||||
from metagpt.utils.common import remove_comments
|
||||
|
||||
STRUCTURAL_CONTEXT = """
|
||||
## User Requirement
|
||||
|
|
@ -25,6 +27,24 @@ STRUCTURAL_CONTEXT = """
|
|||
{current_task}
|
||||
"""
|
||||
|
||||
PLAN_STATUS = """
|
||||
## Finished Tasks
|
||||
### code
|
||||
```python
|
||||
{code_written}
|
||||
```
|
||||
|
||||
### execution result
|
||||
{task_results}
|
||||
|
||||
## Current Task
|
||||
{current_task}
|
||||
|
||||
## Task Guidance
|
||||
Write complete code for 'Current Task'. And avoid duplicating code from 'Finished Tasks', such as repeated import of packages, reading data, etc.
|
||||
Specifically, {guidance}
|
||||
"""
|
||||
|
||||
|
||||
class Planner(BaseModel):
|
||||
plan: Plan
|
||||
|
|
@ -32,7 +52,6 @@ class Planner(BaseModel):
|
|||
default_factory=Memory
|
||||
) # memory for working on each task, discarded each time a task is done
|
||||
auto_run: bool = False
|
||||
use_tools: bool = False
|
||||
|
||||
def __init__(self, goal: str = "", plan: Plan = None, **kwargs):
|
||||
plan = plan or Plan(goal=goal)
|
||||
|
|
@ -53,7 +72,7 @@ class Planner(BaseModel):
|
|||
plan_confirmed = False
|
||||
while not plan_confirmed:
|
||||
context = self.get_useful_memories()
|
||||
rsp = await WritePlan().run(context, max_tasks=max_tasks, use_tools=self.use_tools)
|
||||
rsp = await WritePlan().run(context, max_tasks=max_tasks)
|
||||
self.working_memory.add(Message(content=rsp, role="assistant", cause_by=WritePlan))
|
||||
|
||||
# precheck plan before asking reviews
|
||||
|
|
@ -137,3 +156,23 @@ class Planner(BaseModel):
|
|||
context_msg = [Message(content=context, role="user")]
|
||||
|
||||
return context_msg + self.working_memory.get()
|
||||
|
||||
def get_plan_status(self) -> str:
|
||||
# prepare components of a plan status
|
||||
finished_tasks = self.plan.get_finished_tasks()
|
||||
code_written = [remove_comments(task.code) for task in finished_tasks]
|
||||
code_written = "\n\n".join(code_written)
|
||||
task_results = [task.result for task in finished_tasks]
|
||||
task_results = "\n\n".join(task_results)
|
||||
task_type_name = self.current_task.task_type.upper()
|
||||
guidance = TaskType[task_type_name].value.guidance if hasattr(TaskType, task_type_name) else ""
|
||||
|
||||
# combine components in a prompt
|
||||
prompt = PLAN_STATUS.format(
|
||||
code_written=code_written,
|
||||
task_results=task_results,
|
||||
current_task=self.current_task.instruction,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
|
|
|||
57
metagpt/strategy/task_type.py
Normal file
57
metagpt/strategy/task_type.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.prompts.task_type import (
|
||||
DATA_PREPROCESS_PROMPT,
|
||||
EDA_PROMPT,
|
||||
FEATURE_ENGINEERING_PROMPT,
|
||||
IMAGE2WEBPAGE_PROMPT,
|
||||
MODEL_EVALUATE_PROMPT,
|
||||
MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
class TaskTypeDef(BaseModel):
|
||||
name: str
|
||||
desc: str = ""
|
||||
guidance: str = ""
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
EDA = TaskTypeDef(
|
||||
name="eda",
|
||||
desc="For performing exploratory data analysis",
|
||||
guidance=EDA_PROMPT,
|
||||
)
|
||||
DATA_PREPROCESS = TaskTypeDef(
|
||||
name="data_preprocess",
|
||||
desc="For preprocessing dataset in a data analysis or machine learning task ONLY,"
|
||||
"general data operation doesn't fall into this type",
|
||||
guidance=DATA_PREPROCESS_PROMPT,
|
||||
)
|
||||
FEATURE_ENGINEERING = TaskTypeDef(
|
||||
name="feature_engineering",
|
||||
desc="Only for creating new columns for input data.",
|
||||
guidance=FEATURE_ENGINEERING_PROMPT,
|
||||
)
|
||||
MODEL_TRAIN = TaskTypeDef(
|
||||
name="model_train",
|
||||
desc="Only for training model.",
|
||||
guidance=MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
MODEL_EVALUATE = TaskTypeDef(
|
||||
name="model_evaluate",
|
||||
desc="Only for evaluating model.",
|
||||
guidance=MODEL_EVALUATE_PROMPT,
|
||||
)
|
||||
IMAGE2WEBPAGE = TaskTypeDef(
|
||||
name="image2webpage",
|
||||
desc="For converting image into webpage code.",
|
||||
guidance=IMAGE2WEBPAGE_PROMPT,
|
||||
)
|
||||
OTHER = TaskTypeDef(name="other", desc="Any tasks not in the defined categories")
|
||||
|
||||
@property
|
||||
def type_name(self):
|
||||
return self.value.name
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
|
@ -16,9 +17,8 @@ from sklearn.preprocessing import (
|
|||
)
|
||||
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
TOOL_TYPE = ToolType.DATA_PREPROCESS.type_name
|
||||
TAGS = ["data preprocessing", "machine learning"]
|
||||
|
||||
|
||||
class MLProcess:
|
||||
|
|
@ -85,20 +85,22 @@ class DataPreprocessTool(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class FillMissingValue(DataPreprocessTool):
|
||||
"""
|
||||
Completing missing values with simple strategies.
|
||||
"""
|
||||
|
||||
def __init__(self, features: list, strategy: str = "mean", fill_value=None):
|
||||
def __init__(
|
||||
self, features: list, strategy: Literal["mean", "median", "most_frequent", "constant"] = "mean", fill_value=None
|
||||
):
|
||||
"""
|
||||
Initialize self.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only
|
||||
be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
|
||||
strategy (Literal["mean", "median", "most_frequent", "constant"], optional): The imputation strategy, notice 'mean' and 'median' can only
|
||||
be used for numeric features. Defaults to 'mean'.
|
||||
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
|
@ -106,7 +108,7 @@ class FillMissingValue(DataPreprocessTool):
|
|||
self.model = SimpleImputer(strategy=strategy, fill_value=fill_value)
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class MinMaxScale(DataPreprocessTool):
|
||||
"""
|
||||
Transform features by scaling each feature to a range, which is (0, 1).
|
||||
|
|
@ -117,7 +119,7 @@ class MinMaxScale(DataPreprocessTool):
|
|||
self.model = MinMaxScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class StandardScale(DataPreprocessTool):
|
||||
"""
|
||||
Standardize features by removing the mean and scaling to unit variance.
|
||||
|
|
@ -128,7 +130,7 @@ class StandardScale(DataPreprocessTool):
|
|||
self.model = StandardScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class MaxAbsScale(DataPreprocessTool):
|
||||
"""
|
||||
Scale each feature by its maximum absolute value.
|
||||
|
|
@ -139,7 +141,7 @@ class MaxAbsScale(DataPreprocessTool):
|
|||
self.model = MaxAbsScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class RobustScale(DataPreprocessTool):
|
||||
"""
|
||||
Apply the RobustScaler to scale features using statistics that are robust to outliers.
|
||||
|
|
@ -150,7 +152,7 @@ class RobustScale(DataPreprocessTool):
|
|||
self.model = RobustScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class OrdinalEncode(DataPreprocessTool):
|
||||
"""
|
||||
Encode categorical features as ordinal integers.
|
||||
|
|
@ -161,7 +163,7 @@ class OrdinalEncode(DataPreprocessTool):
|
|||
self.model = OrdinalEncoder()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class OneHotEncode(DataPreprocessTool):
|
||||
"""
|
||||
Apply one-hot encoding to specified categorical columns, the original columns will be dropped.
|
||||
|
|
@ -180,7 +182,7 @@ class OneHotEncode(DataPreprocessTool):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class LabelEncode(DataPreprocessTool):
|
||||
"""
|
||||
Apply label encoding to specified categorical columns in-place.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from imap_tools import MailBox
|
||||
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
# Define a dictionary mapping email domains to their IMAP server addresses
|
||||
IMAP_SERVERS = {
|
||||
|
|
@ -24,7 +23,7 @@ IMAP_SERVERS = {
|
|||
}
|
||||
|
||||
|
||||
@register_tool(tool_type=ToolType.EMAIL_LOGIN.type_name)
|
||||
@register_tool()
|
||||
def email_login_imap(email_address, email_password):
|
||||
"""
|
||||
Use imap_tools package to log in to your email (the email that supports IMAP protocol) to verify and return the account object.
|
||||
|
|
|
|||
|
|
@ -19,12 +19,11 @@ from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures
|
|||
|
||||
from metagpt.tools.libs.data_preprocess import MLProcess
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
TOOL_TYPE = ToolType.FEATURE_ENGINEERING.type_name
|
||||
TAGS = ["feature engineering", "machine learning"]
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class PolynomialExpansion(MLProcess):
|
||||
"""
|
||||
Add polynomial and interaction features from selected numeric columns to input DataFrame.
|
||||
|
|
@ -67,7 +66,7 @@ class PolynomialExpansion(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class CatCount(MLProcess):
|
||||
"""
|
||||
Add value counts of a categorical column as new feature.
|
||||
|
|
@ -92,7 +91,7 @@ class CatCount(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class TargetMeanEncoder(MLProcess):
|
||||
"""
|
||||
Encode a categorical column by the mean of the label column, and adds the result as a new feature.
|
||||
|
|
@ -119,7 +118,7 @@ class TargetMeanEncoder(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class KFoldTargetMeanEncoder(MLProcess):
|
||||
"""
|
||||
Add a new feature to the DataFrame by k-fold mean encoding of a categorical column using the label column.
|
||||
|
|
@ -159,7 +158,7 @@ class KFoldTargetMeanEncoder(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class CatCross(MLProcess):
|
||||
"""
|
||||
Add pairwise crossed features and convert them to numerical features.
|
||||
|
|
@ -216,7 +215,7 @@ class CatCross(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class GroupStat(MLProcess):
|
||||
"""
|
||||
Aggregate specified column in a DataFrame grouped by another column, adding new features named '<agg_col>_<agg_func>_by_<group_col>'.
|
||||
|
|
@ -248,7 +247,7 @@ class GroupStat(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class SplitBins(MLProcess):
|
||||
"""
|
||||
Inplace binning of continuous data into intervals, returning integer-encoded bin identifiers directly.
|
||||
|
|
@ -276,7 +275,7 @@ class SplitBins(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
# @register_tool(tool_type=TOOL_TYPE)
|
||||
# @register_tool(tags=TAGS)
|
||||
class ExtractTimeComps(MLProcess):
|
||||
"""
|
||||
Extract time components from a datetime column and add them as new features.
|
||||
|
|
@ -316,7 +315,7 @@ class ExtractTimeComps(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class GeneralSelection(MLProcess):
|
||||
"""
|
||||
Drop all nan feats and feats with only one unique value.
|
||||
|
|
@ -349,7 +348,7 @@ class GeneralSelection(MLProcess):
|
|||
|
||||
|
||||
# skip for now because lgb is needed
|
||||
# @register_tool(tool_type=TOOL_TYPE)
|
||||
# @register_tool(tags=TAGS)
|
||||
class TreeBasedSelection(MLProcess):
|
||||
"""
|
||||
Select features based on tree-based model and remove features with low importance.
|
||||
|
|
@ -403,7 +402,7 @@ class TreeBasedSelection(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class VarianceBasedSelection(MLProcess):
|
||||
"""
|
||||
Select features based on variance and remove features with low variance.
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from PIL import Image, PngImagePlugin
|
|||
from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
payload = {
|
||||
"prompt": "",
|
||||
|
|
@ -55,7 +54,7 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
|||
|
||||
|
||||
@register_tool(
|
||||
tool_type=ToolType.STABLE_DIFFUSION.type_name,
|
||||
tags=["text2image", "multimodal"],
|
||||
include_functions=["__init__", "simple_run_t2i", "run_t2i", "construct_payload", "save"],
|
||||
)
|
||||
class SDEngine:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from metagpt.tools.web_browser_engine_playwright import PlaywrightWrapper
|
||||
|
||||
|
||||
@register_tool(tool_type=ToolType.WEBSCRAPING.type_name)
|
||||
@register_tool(tags=["web scraping", "web"])
|
||||
async def scrape_web_playwright(url):
|
||||
"""
|
||||
Asynchronously Scrape and save the HTML structure and inner text content of a web page using Playwright.
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ import inspect
|
|||
|
||||
from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
|
||||
|
||||
PARSER = GoogleDocstringParser
|
||||
|
||||
def convert_code_to_tool_schema(obj, include: list[str] = []):
|
||||
|
||||
def convert_code_to_tool_schema(obj, include: list[str] = None):
|
||||
docstring = inspect.getdoc(obj)
|
||||
assert docstring, "no docstring found for the objects, skip registering"
|
||||
|
||||
|
|
@ -23,54 +25,31 @@ def convert_code_to_tool_schema(obj, include: list[str] = []):
|
|||
return schema
|
||||
|
||||
|
||||
def function_docstring_to_schema(fn_obj, docstring):
|
||||
def function_docstring_to_schema(fn_obj, docstring) -> dict:
|
||||
"""
|
||||
Converts a function's docstring into a schema dictionary.
|
||||
|
||||
Args:
|
||||
fn_obj: The function object.
|
||||
docstring: The docstring of the function.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the schema of the function's docstring.
|
||||
The dictionary contains the following keys:
|
||||
- 'type': The type of the function ('function' or 'async_function').
|
||||
- 'description': The first section of the docstring describing the function overall. Provided to LLMs for both recommending and using the function.
|
||||
- 'signature': The signature of the function, which helps LLMs understand how to call the function.
|
||||
- 'parameters': Docstring section describing parameters including args and returns, served as extra details for LLM perception.
|
||||
"""
|
||||
signature = inspect.signature(fn_obj)
|
||||
|
||||
docstring = remove_spaces(docstring)
|
||||
|
||||
overall_desc, param_desc = PARSER.parse(docstring)
|
||||
|
||||
function_type = "function" if not inspect.iscoroutinefunction(fn_obj) else "async_function"
|
||||
return {"type": function_type, **docstring_to_schema(docstring)}
|
||||
|
||||
|
||||
def docstring_to_schema(docstring: str):
|
||||
if docstring is None:
|
||||
return {}
|
||||
|
||||
parser = GoogleDocstringParser(docstring=docstring)
|
||||
|
||||
# 匹配简介部分
|
||||
description = parser.parse_desc()
|
||||
|
||||
# 匹配Args部分
|
||||
params = parser.parse_params()
|
||||
parameter_schema = {"properties": {}, "required": []}
|
||||
for param in params:
|
||||
param_name, param_type, param_desc = param
|
||||
# check required or optional
|
||||
is_optional, param_type = parser.check_and_parse_optional(param_type)
|
||||
if not is_optional:
|
||||
parameter_schema["required"].append(param_name)
|
||||
# type and desc
|
||||
param_dict = {"type": param_type, "description": remove_spaces(param_desc)}
|
||||
# match Default for optional args
|
||||
has_default_val, default_val = parser.check_and_parse_default_value(param_desc)
|
||||
if has_default_val:
|
||||
param_dict["default"] = default_val
|
||||
# match Enum
|
||||
has_enum, enum_vals = parser.check_and_parse_enum(param_desc)
|
||||
if has_enum:
|
||||
param_dict["enum"] = enum_vals
|
||||
# add to parameter schema
|
||||
parameter_schema["properties"].update({param_name: param_dict})
|
||||
|
||||
# 匹配Returns部分
|
||||
returns = parser.parse_returns()
|
||||
|
||||
# 构建YAML字典
|
||||
schema = {
|
||||
"description": description,
|
||||
"parameters": parameter_schema,
|
||||
}
|
||||
if returns:
|
||||
schema["returns"] = [{"type": ret[0], "description": remove_spaces(ret[1])} for ret in returns]
|
||||
|
||||
return schema
|
||||
return {"type": function_type, "description": overall_desc, "signature": str(signature), "parameters": param_desc}
|
||||
|
||||
|
||||
def get_class_method_docstring(cls, method_name):
|
||||
|
|
|
|||
|
|
@ -1,12 +1,6 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ToolTypeDef(BaseModel):
|
||||
name: str
|
||||
desc: str = ""
|
||||
usage_prompt: str = ""
|
||||
|
||||
|
||||
class ToolSchema(BaseModel):
|
||||
description: str
|
||||
|
||||
|
|
@ -16,3 +10,4 @@ class Tool(BaseModel):
|
|||
path: str
|
||||
schemas: dict = {}
|
||||
code: str = ""
|
||||
tags: list[str] = []
|
||||
|
|
|
|||
197
metagpt/tools/tool_recommend.py
Normal file
197
metagpt/tools/tool_recommend.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import jieba
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, field_validator
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Plan
|
||||
from metagpt.tools import TOOL_REGISTRY
|
||||
from metagpt.tools.tool_data_type import Tool
|
||||
from metagpt.tools.tool_registry import validate_tool_names
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
TOOL_INFO_PROMPT = """
|
||||
## Capabilities
|
||||
- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python class or function.
|
||||
- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
|
||||
|
||||
## Available Tools:
|
||||
Each tool is described in JSON format. When you call a tool, import the tool from its path first.
|
||||
{tool_schemas}
|
||||
"""
|
||||
|
||||
|
||||
TOOL_RECOMMENDATION_PROMPT = """
|
||||
## User Requirement:
|
||||
{current_task}
|
||||
|
||||
## Task
|
||||
Recommend up to {topk} tools from 'Available Tools' that can help solve the 'User Requirement'.
|
||||
|
||||
## Available Tools:
|
||||
{available_tools}
|
||||
|
||||
## Tool Selection and Instructions:
|
||||
- Select tools most relevant to completing the 'User Requirement'.
|
||||
- If you believe that no tools are suitable, indicate with an empty list.
|
||||
- Only list the names of the tools, not the full schema of each tool.
|
||||
- Ensure selected tools are listed in 'Available Tools'.
|
||||
- Output a json list of tool names:
|
||||
```json
|
||||
["tool_name1", "tool_name2", ...]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class RecommendTool(Action):
|
||||
async def run(self, prompt):
|
||||
return await self._aask(prompt)
|
||||
|
||||
|
||||
class ToolRecommender(BaseModel):
|
||||
"""
|
||||
The default ToolRecommender:
|
||||
1. Recall: If plan exists, use exact match between task type and tool type to recall tools;
|
||||
If plan doesn't exist (e.g. we use ReAct), return all user-specified tools;
|
||||
2. Rank: Use LLM to select final candidates from recalled set.
|
||||
"""
|
||||
|
||||
tools: dict[str, Tool] = {}
|
||||
force: bool = False
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def validate_tools(cls, v: list[str]) -> dict[str, Tool]:
|
||||
if v == ["<all>"]:
|
||||
return TOOL_REGISTRY.get_all_tools()
|
||||
else:
|
||||
return validate_tool_names(v)
|
||||
|
||||
async def recommend_tools(
|
||||
self, context: str = "", plan: Plan = None, recall_topk: int = 20, topk: int = 5
|
||||
) -> list[Tool]:
|
||||
"""
|
||||
Recommends a list of tools based on the given context and plan. The recommendation process includes two stages: recall from a large pool and rank the recalled tools to select the final set.
|
||||
|
||||
Args:
|
||||
context (str): The context for tool recommendation.
|
||||
plan (Plan): The plan for tool recommendation.
|
||||
recall_topk (int): The number of tools to recall in the initial step.
|
||||
topk (int): The number of tools to return after rank as final recommendations.
|
||||
|
||||
Returns:
|
||||
list[Tool]: A list of recommended tools.
|
||||
"""
|
||||
|
||||
if not self.tools:
|
||||
return []
|
||||
|
||||
if self.force or (not context and not plan):
|
||||
# directly use what users have specified as result for forced recommendation;
|
||||
# directly use the whole set if there is no useful information
|
||||
return list(self.tools.values())
|
||||
|
||||
recalled_tools = await self.recall_tools(context=context, plan=plan, topk=recall_topk)
|
||||
if not recalled_tools:
|
||||
return []
|
||||
|
||||
ranked_tools = await self.rank_tools(recalled_tools=recalled_tools, context=context, plan=plan, topk=topk)
|
||||
|
||||
logger.info(f"Recommended tools: \n{[tool.name for tool in ranked_tools]}")
|
||||
|
||||
return ranked_tools
|
||||
|
||||
async def get_recommended_tool_info(self, **kwargs) -> str:
|
||||
"""
|
||||
Wrap recommended tools with their info in a string, which can be used directly in a prompt.
|
||||
"""
|
||||
recommended_tools = await self.recommend_tools(**kwargs)
|
||||
if not recommended_tools:
|
||||
return ""
|
||||
tool_schemas = {tool.name: tool.schemas for tool in recommended_tools}
|
||||
return TOOL_INFO_PROMPT.format(tool_schemas=tool_schemas)
|
||||
|
||||
async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
|
||||
"""
|
||||
Retrieves a list of relevant tools from a large pool, based on the given context and plan.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def rank_tools(
|
||||
self, recalled_tools: list[Tool], context: str = "", plan: Plan = None, topk: int = 5
|
||||
) -> list[Tool]:
|
||||
"""
|
||||
Default rank methods for a ToolRecommender. Use LLM to rank the recalled tools based on the given context, plan, and topk value.
|
||||
"""
|
||||
current_task = plan.current_task.instruction if plan else context
|
||||
|
||||
available_tools = {tool.name: tool.schemas["description"] for tool in recalled_tools}
|
||||
prompt = TOOL_RECOMMENDATION_PROMPT.format(
|
||||
current_task=current_task,
|
||||
available_tools=available_tools,
|
||||
topk=topk,
|
||||
)
|
||||
rsp = await RecommendTool().run(prompt)
|
||||
rsp = CodeParser.parse_code(block=None, text=rsp)
|
||||
ranked_tools = json.loads(rsp)
|
||||
|
||||
valid_tools = validate_tool_names(ranked_tools)
|
||||
|
||||
return list(valid_tools.values())[:topk]
|
||||
|
||||
|
||||
class BM25ToolRecommender(ToolRecommender):
|
||||
"""
|
||||
A ToolRecommender using BM25 at the recall stage:
|
||||
1. Recall: Querying tool descriptions with task instruction if plan exists. Otherwise, return all user-specified tools;
|
||||
2. Rank: LLM rank, the same as the default ToolRecommender.
|
||||
"""
|
||||
|
||||
bm25: Any = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_corpus()
|
||||
|
||||
def _init_corpus(self):
|
||||
corpus = [f"{tool.name} {tool.tags}: {tool.schemas['description']}" for tool in self.tools.values()]
|
||||
tokenized_corpus = [self._tokenize(doc) for doc in corpus]
|
||||
self.bm25 = BM25Okapi(tokenized_corpus)
|
||||
|
||||
def _tokenize(self, text):
|
||||
return jieba.lcut(text) # FIXME: needs more sophisticated tokenization
|
||||
|
||||
async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
|
||||
query = plan.current_task.instruction if plan else context
|
||||
|
||||
query_tokens = self._tokenize(query)
|
||||
doc_scores = self.bm25.get_scores(query_tokens)
|
||||
top_indexes = np.argsort(doc_scores)[::-1][:topk]
|
||||
recalled_tools = [list(self.tools.values())[index] for index in top_indexes]
|
||||
|
||||
logger.info(
|
||||
f"Recalled tools: \n{[tool.name for tool in recalled_tools]}; Scores: {[doc_scores[index] for index in top_indexes]}"
|
||||
)
|
||||
|
||||
return recalled_tools
|
||||
|
||||
|
||||
class EmbeddingToolRecommender(ToolRecommender):
|
||||
"""
|
||||
NOTE: To be implemented.
|
||||
A ToolRecommender using embeddings at the recall stage:
|
||||
1. Recall: Use embeddings to calculate the similarity between query and tool info;
|
||||
2. Rank: LLM rank, the same as the default ToolRecommender.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
|
||||
pass
|
||||
|
|
@ -9,28 +9,21 @@ from __future__ import annotations
|
|||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import TOOL_SCHEMA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolTypeDef
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from metagpt.tools.tool_data_type import Tool, ToolSchema
|
||||
|
||||
|
||||
class ToolRegistry(BaseModel):
|
||||
tools: dict = {}
|
||||
tool_types: dict = {}
|
||||
tools_by_types: dict = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
|
||||
|
||||
@field_validator("tool_types", mode="before")
|
||||
@classmethod
|
||||
def init_tool_types(cls, tool_types: ToolType):
|
||||
return {tool_type.type_name: tool_type.value for tool_type in tool_types}
|
||||
tools_by_tags: dict = defaultdict(dict) # two-layer k-v, {tag: {tool_name: {...}, ...}, ...}
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
|
|
@ -38,25 +31,15 @@ class ToolRegistry(BaseModel):
|
|||
tool_path,
|
||||
schema_path="",
|
||||
tool_code="",
|
||||
tool_type="other",
|
||||
tags=None,
|
||||
tool_source_object=None,
|
||||
include_functions=[],
|
||||
include_functions=None,
|
||||
verbose=False,
|
||||
):
|
||||
if self.has_tool(tool_name):
|
||||
return
|
||||
|
||||
if tool_type not in self.tool_types:
|
||||
# register new tool type on the fly
|
||||
logger.warning(
|
||||
f"{tool_type} not previously defined, will create a temporary tool type with just a name. This tool type is only effective during this runtime. You may consider add this tool type with more configs permanently at metagpt.tools.tool_type"
|
||||
)
|
||||
temp_tool_type_obj = ToolTypeDef(name=tool_type)
|
||||
self.tool_types[tool_type] = temp_tool_type_obj
|
||||
if verbose:
|
||||
logger.info(f"tool type {tool_type} registered")
|
||||
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / f"{tool_name}.yml"
|
||||
|
||||
schemas = make_schema(tool_source_object, include_functions, schema_path)
|
||||
|
||||
|
|
@ -71,10 +54,11 @@ class ToolRegistry(BaseModel):
|
|||
# logger.warning(
|
||||
# f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}"
|
||||
# )
|
||||
|
||||
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
|
||||
tags = tags or []
|
||||
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code, tags=tags)
|
||||
self.tools[tool_name] = tool
|
||||
self.tools_by_types[tool_type][tool_name] = tool
|
||||
for tag in tags:
|
||||
self.tools_by_tags[tag].update({tool_name: tool})
|
||||
if verbose:
|
||||
logger.info(f"{tool_name} registered")
|
||||
logger.info(f"schema made at {str(schema_path)}, can be used for checking")
|
||||
|
|
@ -85,31 +69,32 @@ class ToolRegistry(BaseModel):
|
|||
def get_tool(self, key) -> Tool:
|
||||
return self.tools.get(key)
|
||||
|
||||
def get_tools_by_type(self, key) -> dict[str, Tool]:
|
||||
return self.tools_by_types.get(key, {})
|
||||
def get_tools_by_tag(self, key) -> dict[str, Tool]:
|
||||
return self.tools_by_tags.get(key, {})
|
||||
|
||||
def has_tool_type(self, key) -> bool:
|
||||
return key in self.tool_types
|
||||
def get_all_tools(self) -> dict[str, Tool]:
|
||||
return self.tools
|
||||
|
||||
def get_tool_type(self, key) -> ToolType:
|
||||
return self.tool_types.get(key)
|
||||
def has_tool_tag(self, key) -> bool:
|
||||
return key in self.tools_by_tags
|
||||
|
||||
def get_tool_types(self) -> dict[str, ToolType]:
|
||||
return self.tool_types
|
||||
def get_tool_tags(self) -> list[str]:
|
||||
return list(self.tools_by_tags.keys())
|
||||
|
||||
|
||||
# Registry instance
|
||||
TOOL_REGISTRY = ToolRegistry(tool_types=ToolType)
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
|
||||
def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs):
|
||||
def register_tool(tags: list[str] = None, schema_path: str = "", **kwargs):
|
||||
"""register a tool to registry"""
|
||||
|
||||
def decorator(cls):
|
||||
# Get the file path where the function / class is defined and the source code
|
||||
file_path = inspect.getfile(cls)
|
||||
if "metagpt" in file_path:
|
||||
file_path = re.search("metagpt.+", file_path).group(0)
|
||||
# split to handle ../metagpt/metagpt/tools/... where only metapgt/tools/... is needed
|
||||
file_path = "metagpt" + file_path.split("metagpt")[-1]
|
||||
source_code = inspect.getsource(cls)
|
||||
|
||||
TOOL_REGISTRY.register_tool(
|
||||
|
|
@ -117,7 +102,7 @@ def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs):
|
|||
tool_path=file_path,
|
||||
schema_path=schema_path,
|
||||
tool_code=source_code,
|
||||
tool_type=tool_type,
|
||||
tags=tags,
|
||||
tool_source_object=cls,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
@ -142,14 +127,15 @@ def make_schema(tool_source_object, include, path):
|
|||
return schema
|
||||
|
||||
|
||||
def validate_tool_names(tools: list[str], return_tool_object=False) -> list[str]:
|
||||
valid_tools = []
|
||||
for tool_name in tools:
|
||||
if not TOOL_REGISTRY.has_tool(tool_name):
|
||||
logger.warning(
|
||||
f"Specified tool {tool_name} not found and was skipped. Check if you have registered it properly"
|
||||
)
|
||||
def validate_tool_names(tools: Union[list[str], str]) -> str:
|
||||
assert isinstance(tools, list), "tools must be a list of str"
|
||||
valid_tools = {}
|
||||
for key in tools:
|
||||
# one can define either tool names or tool type names, take union to get the whole set
|
||||
if TOOL_REGISTRY.has_tool(key):
|
||||
valid_tools.update({key: TOOL_REGISTRY.get_tool(key)})
|
||||
elif TOOL_REGISTRY.has_tool_tag(key):
|
||||
valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key))
|
||||
else:
|
||||
valid_tool = TOOL_REGISTRY.get_tool(tool_name) if return_tool_object else tool_name
|
||||
valid_tools.append(valid_tool)
|
||||
logger.warning(f"invalid tool name or tool type name: {key}, skipped")
|
||||
return valid_tools
|
||||
|
|
|
|||
|
|
@ -1,64 +0,0 @@
|
|||
from enum import Enum
|
||||
|
||||
from metagpt.prompts.tool_types import (
|
||||
DATA_PREPROCESS_PROMPT,
|
||||
EDA_PROMPT,
|
||||
FEATURE_ENGINEERING_PROMPT,
|
||||
IMAGE2WEBPAGE_PROMPT,
|
||||
MODEL_EVALUATE_PROMPT,
|
||||
MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
from metagpt.tools.tool_data_type import ToolTypeDef
|
||||
|
||||
|
||||
class ToolType(Enum):
|
||||
EDA = ToolTypeDef(
|
||||
name="eda",
|
||||
desc="For performing exploratory data analysis",
|
||||
usage_prompt=EDA_PROMPT,
|
||||
)
|
||||
DATA_PREPROCESS = ToolTypeDef(
|
||||
name="data_preprocess",
|
||||
desc="Only for changing value inplace.",
|
||||
usage_prompt=DATA_PREPROCESS_PROMPT,
|
||||
)
|
||||
EMAIL_LOGIN = ToolTypeDef(
|
||||
name="email_login",
|
||||
desc="For logging to an email.",
|
||||
)
|
||||
FEATURE_ENGINEERING = ToolTypeDef(
|
||||
name="feature_engineering",
|
||||
desc="Only for creating new columns for input data.",
|
||||
usage_prompt=FEATURE_ENGINEERING_PROMPT,
|
||||
)
|
||||
MODEL_TRAIN = ToolTypeDef(
|
||||
name="model_train",
|
||||
desc="Only for training model.",
|
||||
usage_prompt=MODEL_TRAIN_PROMPT,
|
||||
)
|
||||
MODEL_EVALUATE = ToolTypeDef(
|
||||
name="model_evaluate",
|
||||
desc="Only for evaluating model.",
|
||||
usage_prompt=MODEL_EVALUATE_PROMPT,
|
||||
)
|
||||
STABLE_DIFFUSION = ToolTypeDef(
|
||||
name="stable_diffusion",
|
||||
desc="Related to text2image, image2image using stable diffusion model.",
|
||||
)
|
||||
IMAGE2WEBPAGE = ToolTypeDef(
|
||||
name="image2webpage",
|
||||
desc="For converting image into webpage code.",
|
||||
usage_prompt=IMAGE2WEBPAGE_PROMPT,
|
||||
)
|
||||
WEBSCRAPING = ToolTypeDef(
|
||||
name="web_scraping",
|
||||
desc="For scraping data from web pages.",
|
||||
)
|
||||
OTHER = ToolTypeDef(name="other", desc="Any tools not in the defined categories")
|
||||
|
||||
def __missing__(self, key):
|
||||
return self.OTHER
|
||||
|
||||
@property
|
||||
def type_name(self):
|
||||
return self.value.name
|
||||
|
|
@ -361,16 +361,6 @@ def parse_recipient(text):
|
|||
return ""
|
||||
|
||||
|
||||
def create_func_call_config(func_schema: dict) -> dict:
|
||||
"""Create new function call config"""
|
||||
tools = [{"type": "function", "function": func_schema}]
|
||||
tool_choice = {"type": "function", "function": {"name": func_schema["name"]}}
|
||||
return {
|
||||
"tools": tools,
|
||||
"tool_choice": tool_choice,
|
||||
}
|
||||
|
||||
|
||||
def remove_comments(code_str: str) -> str:
|
||||
"""Remove comments from code."""
|
||||
pattern = r"(\".*?\"|\'.*?\')|(\#.*?$)"
|
||||
|
|
@ -676,3 +666,26 @@ def decode_image(img_url_or_b64: str) -> Image:
|
|||
img_data = BytesIO(base64.b64decode(b64_data))
|
||||
img = Image.open(img_data)
|
||||
return img
|
||||
|
||||
|
||||
def process_message(messages: Union[str, Message, list[dict], list[Message], list[str]]) -> list[dict]:
|
||||
"""convert messages to list[dict]."""
|
||||
from metagpt.schema import Message
|
||||
|
||||
# 全部转成list
|
||||
if not isinstance(messages, list):
|
||||
messages = [messages]
|
||||
|
||||
# 转成list[dict]
|
||||
processed_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, str):
|
||||
processed_messages.append({"role": "user", "content": msg})
|
||||
elif isinstance(msg, dict):
|
||||
assert set(msg.keys()) == set(["role", "content"])
|
||||
processed_messages.append(msg)
|
||||
elif isinstance(msg, Message):
|
||||
processed_messages.append(msg.to_dict())
|
||||
else:
|
||||
raise ValueError(f"Only support message type are: str, Message, dict, but got {type(messages).__name__}!")
|
||||
return processed_messages
|
||||
|
|
|
|||
|
|
@ -1,45 +1,23 @@
|
|||
import re
|
||||
from typing import Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def remove_spaces(text):
|
||||
return re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
|
||||
class DocstringParser(BaseModel):
|
||||
docstring: str
|
||||
class DocstringParser:
|
||||
@staticmethod
|
||||
def parse(docstring: str) -> Tuple[str, str]:
|
||||
"""Parse the docstring and return the overall description and the parameter description.
|
||||
|
||||
def parse_desc(self) -> str:
|
||||
"""Parse and return the description from the docstring."""
|
||||
|
||||
def parse_params(self) -> list[Tuple[str, str, str]]:
|
||||
"""Parse and return the parameters from the docstring.
|
||||
Args:
|
||||
docstring (str): The docstring to be parsed.
|
||||
|
||||
Returns:
|
||||
list[Tuple[str, str, str]]: A list of input paramter info. Each info is a triple of (param name, param type, param description)
|
||||
Tuple[str, str]: A tuple of (overall description, parameter description)
|
||||
"""
|
||||
|
||||
def parse_returns(self) -> list[Tuple[str, str]]:
|
||||
"""Parse and return the output information from the docstring.
|
||||
|
||||
Returns:
|
||||
list[Tuple[str, str]]: A list of output info. Each info is a tuple of (return type, return description)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_optional(param_type: str) -> Tuple[bool, str]:
|
||||
"""Check if a parameter is optional and return a processed param_type rid of the optionality info if so"""
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_default_value(param_desc: str) -> Tuple[bool, str]:
|
||||
"""Check if a parameter has a default value and return the default value if so"""
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_enum(param_desc: str) -> Tuple[bool, str]:
|
||||
"""Check if a parameter description includes an enum and return enum values if so"""
|
||||
|
||||
|
||||
class reSTDocstringParser(DocstringParser):
|
||||
"""A parser for reStructuredText (reST) docstring"""
|
||||
|
|
@ -48,40 +26,18 @@ class reSTDocstringParser(DocstringParser):
|
|||
class GoogleDocstringParser(DocstringParser):
|
||||
"""A parser for Google-stype docstring"""
|
||||
|
||||
docstring: str
|
||||
|
||||
def parse_desc(self) -> str:
|
||||
description_match = re.search(r"^(.*?)(?:Args:|Returns:|Raises:|$)", self.docstring, re.DOTALL)
|
||||
description = remove_spaces(description_match.group(1)) if description_match else ""
|
||||
return description
|
||||
|
||||
def parse_params(self) -> list[Tuple[str, str, str]]:
|
||||
args_match = re.search(r"Args:\s*(.*?)(?:Returns:|Raises:|$)", self.docstring, re.DOTALL)
|
||||
_args = args_match.group(1).strip() if args_match else ""
|
||||
# variable_pattern = re.compile(r"(\w+)\s*\((.*?)\):\s*(.*)")
|
||||
variable_pattern = re.compile(
|
||||
r"(\w+)\s*\((.*?)\):\s*(.*?)(?=\n\s*\w+\s*\(|\Z)", re.DOTALL
|
||||
) # (?=\n\w+\s*\(|\Z) is to assert that what follows is either the start of the next parameter (indicated by a newline, some word characters, and an opening parenthesis) or the end of the string (\Z).
|
||||
params = variable_pattern.findall(_args)
|
||||
return params
|
||||
|
||||
def parse_returns(self) -> list[Tuple[str, str]]:
|
||||
returns_match = re.search(r"Returns:\s*(.*?)(?:Raises:|$)", self.docstring, re.DOTALL)
|
||||
returns = returns_match.group(1).strip() if returns_match else ""
|
||||
return_pattern = re.compile(r"^(.*)\s*:\s*(.*)$")
|
||||
returns = return_pattern.findall(returns)
|
||||
return returns
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_optional(param_type: str) -> Tuple[bool, str]:
|
||||
return "optional" in param_type, param_type.replace(", optional", "")
|
||||
def parse(docstring: str) -> Tuple[str, str]:
|
||||
if not docstring:
|
||||
return "", ""
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_default_value(param_desc: str) -> Tuple[bool, str]:
|
||||
default_val = re.search(r"Defaults to (.+?)\.", param_desc)
|
||||
return (True, default_val.group(1)) if default_val else (False, "")
|
||||
docstring = remove_spaces(docstring)
|
||||
|
||||
@staticmethod
|
||||
def check_and_parse_enum(param_desc: str) -> Tuple[bool, str]:
|
||||
enum_val = re.search(r"Enum: \[(.+?)\]", param_desc)
|
||||
return (True, [e.strip() for e in enum_val.group(1).split(",")]) if enum_val else (False, [])
|
||||
if "Args:" in docstring:
|
||||
overall_desc, param_desc = docstring.split("Args:")
|
||||
param_desc = "Args:" + param_desc
|
||||
else:
|
||||
overall_desc = docstring
|
||||
param_desc = ""
|
||||
|
||||
return overall_desc, param_desc
|
||||
|
|
|
|||
|
|
@ -1,44 +1,8 @@
|
|||
from typing import Literal, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema
|
||||
|
||||
|
||||
def test_docstring_to_schema():
|
||||
docstring = """
|
||||
Some test desc.
|
||||
|
||||
Args:
|
||||
features (list): Columns to be processed.
|
||||
strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be
|
||||
used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
|
||||
fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
|
||||
Defaults to None.
|
||||
Returns:
|
||||
pd.DataFrame: The transformed DataFrame.
|
||||
"""
|
||||
expected = {
|
||||
"description": "Some test desc.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
}
|
||||
schema = docstring_to_schema(docstring)
|
||||
assert schema == expected
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
|
||||
|
||||
class DummyClass:
|
||||
|
|
@ -81,12 +45,26 @@ class DummyClass:
|
|||
pass
|
||||
|
||||
|
||||
def dummy_fn(df: pd.DataFrame) -> dict:
|
||||
# def dummy_fn(df: pd.DataFrame, s: str, k: int = 5, type: Literal["a", "b", "c"] = "a") -> dict:
|
||||
def dummy_fn(
|
||||
df: pd.DataFrame,
|
||||
s: str,
|
||||
k: int = 5,
|
||||
type: Literal["a", "b", "c"] = "a",
|
||||
test_dict: dict[str, int] = None,
|
||||
test_union: Union[str, list[str]] = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Analyzes a DataFrame and categorizes its columns based on data types.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): The DataFrame to be analyzed.
|
||||
df: The DataFrame to be analyzed.
|
||||
Another line for df.
|
||||
s (str): Some test string param.
|
||||
Another line for s.
|
||||
k (int, optional): Some test integer param. Defaults to 5.
|
||||
type (Literal["a", "b", "c"], optional): Some test type. Defaults to 'a'.
|
||||
more_args: will be omitted here for testing
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
|
||||
|
|
@ -115,41 +93,21 @@ def test_convert_code_to_tool_schema_class():
|
|||
"methods": {
|
||||
"__init__": {
|
||||
"type": "function",
|
||||
"description": "Initialize self.",
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"features": {"type": "list", "description": "Columns to be processed."},
|
||||
"strategy": {
|
||||
"type": "str",
|
||||
"description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
|
||||
"default": "'mean'",
|
||||
"enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
|
||||
},
|
||||
"fill_value": {
|
||||
"type": "int",
|
||||
"description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
"default": "None",
|
||||
},
|
||||
},
|
||||
"required": ["features"],
|
||||
},
|
||||
"description": "Initialize self. ",
|
||||
"signature": "(self, features: list, strategy: str = 'mean', fill_value=None)",
|
||||
"parameters": "Args: features (list): Columns to be processed. strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'. fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
|
||||
},
|
||||
"fit": {
|
||||
"type": "function",
|
||||
"description": "Fit the FillMissingValue model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"description": "Fit the FillMissingValue model. ",
|
||||
"signature": "(self, df: pandas.core.frame.DataFrame)",
|
||||
"parameters": "Args: df (pd.DataFrame): The input DataFrame.",
|
||||
},
|
||||
"transform": {
|
||||
"type": "function",
|
||||
"description": "Transform the input DataFrame with the fitted model.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
|
||||
"description": "Transform the input DataFrame with the fitted model. ",
|
||||
"signature": "(self, df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame",
|
||||
"parameters": "Args: df (pd.DataFrame): The input DataFrame. Returns: pd.DataFrame: The transformed DataFrame.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -160,11 +118,9 @@ def test_convert_code_to_tool_schema_class():
|
|||
def test_convert_code_to_tool_schema_function():
|
||||
expected = {
|
||||
"type": "function",
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types.",
|
||||
"parameters": {
|
||||
"properties": {"df": {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}},
|
||||
"required": ["df"],
|
||||
},
|
||||
"description": "Analyzes a DataFrame and categorizes its columns based on data types. ",
|
||||
"signature": "(df: pandas.core.frame.DataFrame, s: str, k: int = 5, type: Literal['a', 'b', 'c'] = 'a', test_dict: dict[str, int] = None, test_union: Union[str, list[str]] = '') -> dict",
|
||||
"parameters": "Args: df: The DataFrame to be analyzed. Another line for df. s (str): Some test string param. Another line for s. k (int, optional): Some test integer param. Defaults to 5. type (Literal[\"a\", \"b\", \"c\"], optional): Some test type. Defaults to 'a'. more_args: will be omitted here for testing Returns: dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others'). Each key corresponds to a list of column names belonging to that category.",
|
||||
}
|
||||
schema = convert_code_to_tool_schema(dummy_fn)
|
||||
assert schema == expected
|
||||
|
|
|
|||
78
tests/metagpt/tools/test_tool_recommend.py
Normal file
78
tests/metagpt/tools/test_tool_recommend.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.schema import Plan, Task
|
||||
from metagpt.tools import TOOL_REGISTRY
|
||||
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_plan(mocker):
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="conduct feature engineering, add new features on the dataset",
|
||||
task_type="feature_engineering",
|
||||
)
|
||||
}
|
||||
plan = Plan(
|
||||
goal="test requirement",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="1",
|
||||
)
|
||||
return plan
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bm25_tr(mocker):
|
||||
tr = BM25ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping"])
|
||||
return tr
|
||||
|
||||
|
||||
def test_tr_init():
|
||||
tr = ToolRecommender(tools=["FillMissingValue", "PolynomialExpansion", "web scraping", "non-existing tool"])
|
||||
# web_scraping is a tool tag, it has one tool scrape_web_playwright
|
||||
assert list(tr.tools.keys()) == [
|
||||
"FillMissingValue",
|
||||
"PolynomialExpansion",
|
||||
"scrape_web_playwright",
|
||||
]
|
||||
|
||||
|
||||
def test_tr_init_default_tools_value():
|
||||
tr = ToolRecommender()
|
||||
assert tr.tools == {}
|
||||
|
||||
|
||||
def test_tr_init_tools_all():
|
||||
tr = ToolRecommender(tools=["<all>"])
|
||||
assert list(tr.tools.keys()) == list(TOOL_REGISTRY.get_all_tools().keys())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_tr_recall_with_plan(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recall_tools(plan=mock_plan)
|
||||
assert len(result) == 3
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_tr_recall_no_plan(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recall_tools(
|
||||
context="conduct feature engineering, add new features on the dataset", plan=None
|
||||
)
|
||||
assert len(result) == 3
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bm25_recommend_tools(mock_bm25_tr):
|
||||
result = await mock_bm25_tr.recommend_tools(context="conduct feature engineering, add new features on the dataset")
|
||||
assert len(result) == 2 # web scraping tool should be filtered out at rank stage
|
||||
assert result[0].name == "PolynomialExpansion"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_recommended_tool_info(mock_plan, mock_bm25_tr):
|
||||
result = await mock_bm25_tr.get_recommended_tool_info(plan=mock_plan)
|
||||
assert isinstance(result, str)
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.tools.tool_registry import ToolRegistry
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -9,25 +8,11 @@ def tool_registry():
|
|||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_registry_full():
|
||||
return ToolRegistry(tool_types=ToolType)
|
||||
|
||||
|
||||
# Test Initialization
|
||||
def test_initialization(tool_registry):
|
||||
assert isinstance(tool_registry, ToolRegistry)
|
||||
assert tool_registry.tools == {}
|
||||
assert tool_registry.tool_types == {}
|
||||
assert tool_registry.tools_by_types == {}
|
||||
|
||||
|
||||
# Test Initialization with tool types
|
||||
def test_initialize_with_tool_types(tool_registry_full):
|
||||
assert isinstance(tool_registry_full, ToolRegistry)
|
||||
assert tool_registry_full.tools == {}
|
||||
assert tool_registry_full.tools_by_types == {}
|
||||
assert "data_preprocess" in tool_registry_full.tool_types
|
||||
assert tool_registry.tools_by_tags == {}
|
||||
|
||||
|
||||
class TestClassTool:
|
||||
|
|
@ -72,31 +57,24 @@ def test_get_tool(tool_registry):
|
|||
assert "description" in tool.schemas
|
||||
|
||||
|
||||
# Similar tests for has_tool_type, get_tool_type, get_tools_by_type
|
||||
def test_has_tool_type(tool_registry_full):
|
||||
assert tool_registry_full.has_tool_type("data_preprocess")
|
||||
assert not tool_registry_full.has_tool_type("NonexistentType")
|
||||
def test_has_tool_tag(tool_registry):
|
||||
tool_registry.register_tool(
|
||||
"TestClassTool", "/path/to/tool", tool_source_object=TestClassTool, tags=["machine learning", "test"]
|
||||
)
|
||||
assert tool_registry.has_tool_tag("test")
|
||||
assert not tool_registry.has_tool_tag("Non-existent tag")
|
||||
|
||||
|
||||
def test_get_tool_type(tool_registry_full):
|
||||
retrieved_type = tool_registry_full.get_tool_type("data_preprocess")
|
||||
assert retrieved_type is not None
|
||||
assert retrieved_type.name == "data_preprocess"
|
||||
|
||||
|
||||
def test_get_tools_by_type(tool_registry):
|
||||
tool_type_name = "TestType"
|
||||
def test_get_tools_by_tag(tool_registry):
|
||||
tool_tag_name = "Test Tag"
|
||||
tool_name = "TestTool"
|
||||
tool_path = "/path/to/tool"
|
||||
|
||||
tool_registry.register_tool(tool_name, tool_path, tool_type=tool_type_name, tool_source_object=TestClassTool)
|
||||
tool_registry.register_tool(tool_name, tool_path, tags=[tool_tag_name], tool_source_object=TestClassTool)
|
||||
|
||||
tools_by_type = tool_registry.get_tools_by_type(tool_type_name)
|
||||
assert tools_by_type is not None
|
||||
assert tool_name in tools_by_type
|
||||
tools_by_tag = tool_registry.get_tools_by_tag(tool_tag_name)
|
||||
assert tools_by_tag is not None
|
||||
assert tool_name in tools_by_tag
|
||||
|
||||
|
||||
# Test case for when the tool type does not exist
|
||||
def test_get_tools_by_nonexistent_type(tool_registry):
|
||||
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
|
||||
assert not tools_by_type
|
||||
tools_by_tag_non_existent = tool_registry.get_tools_by_tag("Non-existent Tag")
|
||||
assert not tools_by_tag_non_existent
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue