mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge pull request #1045 from garylin2099/di_fixes
fix bug for task type prompt, add tool recommender arg, support register tools by path
This commit is contained in:
commit
46bba83c1d
9 changed files with 113 additions and 14 deletions
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Literal, Union
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
|
|
@ -39,7 +39,7 @@ class DataInterpreter(Role):
|
|||
use_plan: bool = True
|
||||
use_reflection: bool = False
|
||||
execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True)
|
||||
tools: Union[str, list[str]] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
|
||||
tools: list[str] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
|
||||
tool_recommender: ToolRecommender = None
|
||||
react_mode: Literal["plan_and_act", "react"] = "plan_and_act"
|
||||
max_react_loop: int = 10 # used for react mode
|
||||
|
|
@ -50,7 +50,7 @@ class DataInterpreter(Role):
|
|||
self.use_plan = (
|
||||
self.react_mode == "plan_and_act"
|
||||
) # create a flag for convenience, overwrite any passed-in value
|
||||
if self.tools:
|
||||
if self.tools and not self.tool_recommender:
|
||||
self.tool_recommender = BM25ToolRecommender(tools=self.tools)
|
||||
self.set_actions([WriteAnalysisCode])
|
||||
self._set_state(0)
|
||||
|
|
@ -104,7 +104,7 @@ class DataInterpreter(Role):
|
|||
plan_status = self.planner.get_plan_status() if self.use_plan else ""
|
||||
|
||||
# tool info
|
||||
if self.tools:
|
||||
if self.tool_recommender:
|
||||
context = (
|
||||
self.working_memory.get()[-1].content if self.working_memory.get() else ""
|
||||
) # thoughts from _think stage in 'react' mode
|
||||
|
|
|
|||
|
|
@ -164,8 +164,9 @@ class Planner(BaseModel):
|
|||
code_written = "\n\n".join(code_written)
|
||||
task_results = [task.result for task in finished_tasks]
|
||||
task_results = "\n\n".join(task_results)
|
||||
task_type_name = self.current_task.task_type.upper()
|
||||
guidance = TaskType[task_type_name].value.guidance if hasattr(TaskType, task_type_name) else ""
|
||||
task_type_name = self.current_task.task_type
|
||||
task_type = TaskType.get_type(task_type_name)
|
||||
guidance = task_type.guidance if task_type else ""
|
||||
|
||||
# combine components in a prompt
|
||||
prompt = PLAN_STATUS.format(
|
||||
|
|
|
|||
|
|
@ -71,3 +71,10 @@ class TaskType(Enum):
|
|||
@property
|
||||
def type_name(self):
|
||||
return self.value.name
|
||||
|
||||
@classmethod
|
||||
def get_type(cls, type_name):
|
||||
for member in cls:
|
||||
if member.type_name == type_name:
|
||||
return member.value
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ PARSER = GoogleDocstringParser
|
|||
|
||||
def convert_code_to_tool_schema(obj, include: list[str] = None):
|
||||
docstring = inspect.getdoc(obj)
|
||||
assert docstring, "no docstring found for the objects, skip registering"
|
||||
# assert docstring, "no docstring found for the objects, skip registering"
|
||||
|
||||
if inspect.isclass(obj):
|
||||
schema = {"type": "class", "description": remove_spaces(docstring), "methods": {}}
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -127,15 +127,63 @@ def make_schema(tool_source_object, include, path):
|
|||
return schema
|
||||
|
||||
|
||||
def validate_tool_names(tools: Union[list[str], str]) -> str:
|
||||
def validate_tool_names(tools: list[str]) -> dict[str, Tool]:
|
||||
assert isinstance(tools, list), "tools must be a list of str"
|
||||
valid_tools = {}
|
||||
for key in tools:
|
||||
# one can define either tool names or tool type names, take union to get the whole set
|
||||
if TOOL_REGISTRY.has_tool(key):
|
||||
# one can define either tool names OR tool tags OR tool path, take union to get the whole set
|
||||
# if tool paths are provided, they will be registered on the fly
|
||||
if os.path.isdir(key) or os.path.isfile(key):
|
||||
valid_tools.update(register_tools_from_path(key))
|
||||
elif TOOL_REGISTRY.has_tool(key):
|
||||
valid_tools.update({key: TOOL_REGISTRY.get_tool(key)})
|
||||
elif TOOL_REGISTRY.has_tool_tag(key):
|
||||
valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key))
|
||||
else:
|
||||
logger.warning(f"invalid tool name or tool type name: {key}, skipped")
|
||||
return valid_tools
|
||||
|
||||
|
||||
def load_module_from_file(filepath):
|
||||
module_name = os.path.splitext(os.path.basename(filepath))[0]
|
||||
spec = importlib.util.spec_from_file_location(module_name, filepath)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def register_tools_from_file(file_path) -> dict[str, Tool]:
|
||||
registered_tools = {}
|
||||
module = load_module_from_file(file_path)
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) or inspect.isfunction(obj):
|
||||
if obj.__module__ == module.__name__:
|
||||
# excluding imported classes and functions, register only those defined in the file
|
||||
if "metagpt" in file_path:
|
||||
# split to handle ../metagpt/metagpt/tools/... where only metapgt/tools/... is needed
|
||||
file_path = "metagpt" + file_path.split("metagpt")[-1]
|
||||
|
||||
TOOL_REGISTRY.register_tool(
|
||||
tool_name=name,
|
||||
tool_path=file_path,
|
||||
tool_code="", # inspect.getsource(obj) will resulted in TypeError, skip it for now
|
||||
tool_source_object=obj,
|
||||
)
|
||||
registered_tools.update({name: TOOL_REGISTRY.get_tool(name)})
|
||||
|
||||
return registered_tools
|
||||
|
||||
|
||||
def register_tools_from_path(path) -> dict[str, Tool]:
|
||||
tools_registered = {}
|
||||
if os.path.isfile(path) and path.endswith(".py"):
|
||||
# Path is a Python file
|
||||
tools_registered.update(register_tools_from_file(path))
|
||||
elif os.path.isdir(path):
|
||||
# Path is a directory
|
||||
for root, _, files in os.walk(path):
|
||||
for file in files:
|
||||
if file.endswith(".py"):
|
||||
file_path = os.path.join(root, file)
|
||||
tools_registered.update(register_tools_from_file(file_path))
|
||||
return tools_registered
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Tuple
|
|||
|
||||
|
||||
def remove_spaces(text):
|
||||
return re.sub(r"\s+", " ", text).strip()
|
||||
return re.sub(r"\s+", " ", text).strip() if text else ""
|
||||
|
||||
|
||||
class DocstringParser:
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -25,7 +25,6 @@ async def test_interpreter(mocker, auto_run):
|
|||
@pytest.mark.asyncio
|
||||
async def test_interpreter_react_mode(mocker):
|
||||
mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
|
||||
mocker.patch("builtins.input", return_value="confirm")
|
||||
|
||||
requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
|
||||
|
||||
|
|
|
|||
37
tests/metagpt/strategy/test_planner.py
Normal file
37
tests/metagpt/strategy/test_planner.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
from metagpt.schema import Plan, Task
|
||||
from metagpt.strategy.planner import Planner
|
||||
from metagpt.strategy.task_type import TaskType
|
||||
|
||||
MOCK_TASK_MAP = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="test instruction for finished task",
|
||||
task_type=TaskType.EDA.type_name,
|
||||
dependent_task_ids=[],
|
||||
code="some finished test code",
|
||||
result="some finished test result",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="test instruction for current task",
|
||||
task_type=TaskType.DATA_PREPROCESS.type_name,
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
MOCK_PLAN = Plan(
|
||||
goal="test goal",
|
||||
tasks=list(MOCK_TASK_MAP.values()),
|
||||
task_map=MOCK_TASK_MAP,
|
||||
current_task_id="2",
|
||||
)
|
||||
|
||||
|
||||
def test_planner_get_plan_status():
|
||||
planner = Planner(plan=MOCK_PLAN)
|
||||
status = planner.get_plan_status()
|
||||
|
||||
assert "some finished test code" in status
|
||||
assert "some finished test result" in status
|
||||
assert "test instruction for current task" in status
|
||||
assert TaskType.DATA_PREPROCESS.value.guidance in status # current task guidance
|
||||
Loading…
Add table
Add a link
Reference in a new issue