fix critical bug: human prior not injected

This commit is contained in:
yzlin 2024-03-19 17:53:26 +08:00
parent a221d1c418
commit 51ba2b393b
5 changed files with 55 additions and 4 deletions

View file

@ -164,8 +164,9 @@ class Planner(BaseModel):
code_written = "\n\n".join(code_written)
task_results = [task.result for task in finished_tasks]
task_results = "\n\n".join(task_results)
task_type_name = self.current_task.task_type.upper()
guidance = TaskType[task_type_name].value.guidance if hasattr(TaskType, task_type_name) else ""
task_type_name = self.current_task.task_type
task_type = TaskType.get_type(task_type_name)
guidance = task_type.guidance if task_type else ""
# combine components in a prompt
prompt = PLAN_STATUS.format(

View file

@ -71,3 +71,10 @@ class TaskType(Enum):
@property
def type_name(self):
return self.value.name
@classmethod
def get_type(cls, type_name):
for member in cls:
if member.type_name == type_name:
return member.value
return None

File diff suppressed because one or more lines are too long

View file

@ -25,7 +25,6 @@ async def test_interpreter(mocker, auto_run):
@pytest.mark.asyncio
async def test_interpreter_react_mode(mocker):
mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
mocker.patch("builtins.input", return_value="confirm")
requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."

View file

@ -0,0 +1,37 @@
from metagpt.schema import Plan, Task
from metagpt.strategy.planner import Planner
from metagpt.strategy.task_type import TaskType
MOCK_TASK_MAP = {
"1": Task(
task_id="1",
instruction="test instruction for finished task",
task_type=TaskType.EDA.type_name,
dependent_task_ids=[],
code="some finished test code",
result="some finished test result",
is_finished=True,
),
"2": Task(
task_id="2",
instruction="test instruction for current task",
task_type=TaskType.DATA_PREPROCESS.type_name,
dependent_task_ids=["1"],
),
}
MOCK_PLAN = Plan(
goal="test goal",
tasks=list(MOCK_TASK_MAP.values()),
task_map=MOCK_TASK_MAP,
current_task_id="2",
)
def test_planner_get_plan_status():
planner = Planner(plan=MOCK_PLAN)
status = planner.get_plan_status()
assert "some finished test code" in status
assert "some finished test result" in status
assert "test instruction for current task" in status
assert TaskType.DATA_PREPROCESS.value.guidance in status # current task guidance