mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-10 16:22:37 +02:00
more plan operation, review update, add kaggle team
This commit is contained in:
parent
8b3d640dd6
commit
d3d08fe5f3
10 changed files with 330 additions and 88 deletions
|
|
@ -94,4 +94,7 @@ MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k
|
|||
### browser path for pyppeteer engine, support Chrome, Chromium,MS Edge
|
||||
#PYPPETEER_EXECUTABLE_PATH: "/usr/bin/google-chrome-stable"
|
||||
|
||||
PROMPT_FORMAT: json #json or markdown
|
||||
PROMPT_FORMAT: json #json or markdown
|
||||
|
||||
KAGGLE_USERNAME: ""
|
||||
KAGGLE_KEY: ""
|
||||
|
|
@ -12,13 +12,14 @@ async def main(
|
|||
# competition: str,
|
||||
# data_desc: str,
|
||||
# requirement: str,
|
||||
investment: float = 3.0,
|
||||
investment: float = 5.0,
|
||||
n_round: int = 5,
|
||||
):
|
||||
competition, data_desc, requirement = (
|
||||
"titanic",
|
||||
"Training set is train.csv.\nTest set is test.csv. We also include gender_submission.csv, a set of predictions that assume all and only female passengers survive, as an example of what a submission file should look like.",
|
||||
"Run EDA on the train dataset, train a model to predict survival (20% as validation) and save it, predict the test set using saved model, save the test result according to format",
|
||||
# "generate a random prediction of the same shape as gender_submission.csv and save",
|
||||
)
|
||||
|
||||
team = Team()
|
||||
|
|
|
|||
119
metagpt/actions/ml_da_action.py
Normal file
119
metagpt/actions/ml_da_action.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import json
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.schema import Message, Plan
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def truncate(result: str, keep_len: int = 1000) -> str:
|
||||
desc = "Truncated to show only the last 1000 characters\n"
|
||||
if result.startswith(desc):
|
||||
result = result[-len(desc) :]
|
||||
|
||||
if len(result) > keep_len:
|
||||
result = result[-keep_len:]
|
||||
|
||||
if not result.startswith(desc):
|
||||
return desc + result
|
||||
return desc
|
||||
|
||||
|
||||
class ReviewConst:
|
||||
TASK_REVIEW_TRIGGER = "task"
|
||||
CODE_REVIEW_TRIGGER = "code"
|
||||
CONTINUE_WORD = ["confirm", "continue", "c", "yes", "y"]
|
||||
CHANGE_WORD = ["change"]
|
||||
EXIT_WORD = ["exit"]
|
||||
TASK_REVIEW_INSTRUCTION = (
|
||||
f"If you want to change, add, delete a task or merge tasks in the plan, say '{CHANGE_WORD[0]} task task_id or current task, ... (things to change)' "
|
||||
f"If you confirm the output from the current task and wish to continue, type: {CONTINUE_WORD[0]}"
|
||||
)
|
||||
CODE_REVIEW_INSTRUCTION = (
|
||||
f"If you want the codes to be rewritten, say '{CHANGE_WORD[0]} ... (your change advice)' "
|
||||
f"If you want to leave it as is, type: {CONTINUE_WORD[0]} or {CONTINUE_WORD[1]}"
|
||||
)
|
||||
EXIT_INSTRUCTION = f"If you want to terminate the process, type: {EXIT_WORD[0]}"
|
||||
|
||||
|
||||
class AskReview(Action):
|
||||
async def run(
|
||||
self, context: List[Message], plan: Plan = None, trigger: str = "task"
|
||||
):
|
||||
logger.info("Current overall plan:")
|
||||
logger.info(
|
||||
"\n".join(
|
||||
[
|
||||
f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}"
|
||||
for task in plan.tasks
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("most recent context:")
|
||||
latest_action = context[-1].cause_by.__name__ if context[-1].cause_by else ""
|
||||
review_instruction = (
|
||||
ReviewConst.TASK_REVIEW_INSTRUCTION
|
||||
if trigger == ReviewConst.TASK_REVIEW_TRIGGER
|
||||
else ReviewConst.CODE_REVIEW_INSTRUCTION
|
||||
)
|
||||
prompt = (
|
||||
f"This is a <{trigger}> review. Please review output from {latest_action}\n"
|
||||
f"{review_instruction}\n"
|
||||
f"{ReviewConst.EXIT_INSTRUCTION}\n"
|
||||
"Please type your review below:\n"
|
||||
)
|
||||
|
||||
rsp = input(prompt)
|
||||
|
||||
if rsp.lower() in ReviewConst.EXIT_WORD:
|
||||
exit()
|
||||
|
||||
confirmed = rsp.lower() in ReviewConst.CONTINUE_WORD
|
||||
|
||||
return rsp, confirmed
|
||||
|
||||
|
||||
class SummarizeAnalysis(Action):
|
||||
PROMPT_TEMPLATE = """
|
||||
# Context
|
||||
{context}
|
||||
# Summary
|
||||
Output a 30-word summary on analysis tool and modeling algorithms you have used, and the corresponding result. Make sure to announce the complete path to your test prediction file. Your summary:
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = "", context=None, llm=None) -> str:
|
||||
super().__init__(name, context, llm)
|
||||
|
||||
async def run(self, conmpleted_plan: Plan) -> str:
|
||||
tasks = json.dumps(
|
||||
[task.dict() for task in conmpleted_plan.tasks],
|
||||
indent=4,
|
||||
ensure_ascii=False,
|
||||
) # all tasks finished, return all task outputs
|
||||
prompt = self.PROMPT_TEMPLATE.format(context=tasks)
|
||||
summary = await self._aask(prompt)
|
||||
return summary
|
||||
|
||||
|
||||
class Reflect(Action):
|
||||
PROMPT_TEMPLATE = """
|
||||
# User Requirement
|
||||
{user_requirement}
|
||||
# Context
|
||||
{context}
|
||||
# Summary
|
||||
Above is all your attempts to tackle the user requirement. You plan, act, submit your output, and get the result and feedback.
|
||||
First, summarize each of your previous trial in a triple of (your methods, the corresponding result, potential improvement), list them out.
|
||||
# Takeaways
|
||||
Second, carefully find key takeaways from your summarization in a step-by-step thinking process
|
||||
# Guidance
|
||||
Finally, make a concise one-sentence guidance for improving your future plan.
|
||||
Your response:
|
||||
"""
|
||||
|
||||
async def run(self, context: str) -> str:
|
||||
user_requirement = "Score as high as possible in a data modeling competition"
|
||||
prompt = self.PROMPT_TEMPLATE.format(context=context, user_requirement=user_requirement)
|
||||
rsp = await self._aask(prompt)
|
||||
return rsp
|
||||
|
|
@ -17,7 +17,7 @@ class WritePlan(Action):
|
|||
__context__
|
||||
# Task:
|
||||
Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to __max_tasks__ tasks.
|
||||
If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes.
|
||||
If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. Give the whole plan unless instructed to modify only one task of the plan.
|
||||
Output a list of jsons following the format:
|
||||
```json
|
||||
[
|
||||
|
|
|
|||
|
|
@ -95,6 +95,9 @@ class Config(metaclass=Singleton):
|
|||
|
||||
self.prompt_format = self._get("PROMPT_FORMAT", "markdown")
|
||||
|
||||
self.kaggle_username = self._get("KAGGLE_USERNAME", "")
|
||||
self.kaggle_key = self._get("KAGGLE_KEY", "")
|
||||
|
||||
def _init_with_config_files_and_env(self, configs: dict, yaml_file):
|
||||
"""Load from config/key.yaml, config/config.yaml, and env in decreasing order of priority"""
|
||||
configs.update(os.environ)
|
||||
|
|
|
|||
|
|
@ -168,3 +168,14 @@ ML_MODULE_MAP = {
|
|||
"classification_model": "metagpt.tools.functions.libs.machine_learning.ml_model",
|
||||
"regression_model": "metagpt.tools.functions.libs.machine_learning.ml_model",
|
||||
}
|
||||
|
||||
STRUCTURAL_CONTEXT = """
|
||||
## User Requirement
|
||||
{user_requirement}
|
||||
## Data Description
|
||||
{data_desc}
|
||||
## Current Plan
|
||||
{tasks}
|
||||
## Current Task
|
||||
{current_task}
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -5,16 +5,18 @@ import subprocess
|
|||
import fire
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.roles import Role
|
||||
from metagpt.actions import Action, BossRequirement
|
||||
from metagpt.actions.write_analysis_code import AskReview, SummarizeAnalysis
|
||||
from metagpt.actions.ml_da_action import AskReview, SummarizeAnalysis
|
||||
from metagpt.schema import Message, Task, Plan
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
import os
|
||||
os.environ["KAGGLE_USERNAME"] = "xxx"
|
||||
os.environ["KAGGLE_KEY"] = "xxx"
|
||||
os.environ["KAGGLE_USERNAME"] = CONFIG.kaggle_username
|
||||
os.environ["KAGGLE_KEY"] = CONFIG.kaggle_key
|
||||
|
||||
def run_command(cmd):
|
||||
print(cmd)
|
||||
|
|
@ -38,6 +40,7 @@ class DownloadData(Action):
|
|||
|
||||
# if not os.path.exists(data_path):
|
||||
if True:
|
||||
# run_command(f"rm -r {data_path / '*'}")
|
||||
run_command(f"unzip -o {WORKSPACE_ROOT / '*.zip'} -d {data_path}") # FIXME: not safe
|
||||
|
||||
file_list = run_command(f"ls {data_path}")
|
||||
|
|
@ -52,24 +55,30 @@ class DownloadData(Action):
|
|||
|
||||
class SubmitResult(Action):
|
||||
PROMPT_TEMPLATE = """
|
||||
# Context
|
||||
{context}
|
||||
# Summary
|
||||
__summary__
|
||||
# Your task
|
||||
Extract the prediction file for test set, return only the path string, e.g., xxx.csv, xxx.xlsx
|
||||
Extract the file path for test set prediction from the summary above, output a json following the format:
|
||||
```json
|
||||
{"file_path": str = "the file path, for example, /path/to/the/prediction/file/xxx.csv, /path/to/the/prediction/file/xxx.xlsx"}
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = "", context=None, llm=None) -> str:
|
||||
super().__init__(name, context, llm)
|
||||
|
||||
async def _parse_submit_file_path(self, context) -> str:
|
||||
prompt = self.PROMPT_TEMPLATE.format(context=context)
|
||||
prompt = self.PROMPT_TEMPLATE.replace("__summary__", context)
|
||||
rsp = await self._aask(prompt)
|
||||
return rsp
|
||||
rsp = CodeParser.parse_code(block=None, text=rsp)
|
||||
file_path = json.loads(rsp)["file_path"]
|
||||
return file_path
|
||||
|
||||
async def run(self, competition, submit_message="") -> str:
|
||||
submit_file_path = self._parse_submit_file_path(submit_message)
|
||||
submit_file_path = await self._parse_submit_file_path(submit_message)
|
||||
|
||||
data_path = WORKSPACE_ROOT / competition
|
||||
submit_message = submit_message.replace("'", "")
|
||||
|
||||
run_command(f"kaggle competitions submit {competition} -f {submit_file_path} -m '{submit_message}'")
|
||||
run_command(f"kaggle competitions leaderboard --show --csv {competition} > {data_path / 'leaderboard.csv'}")
|
||||
|
|
@ -77,20 +86,20 @@ class SubmitResult(Action):
|
|||
|
||||
leaderboard = pd.read_csv(data_path / 'leaderboard.csv')
|
||||
submission = pd.read_csv(data_path / 'submission.csv')
|
||||
submission_score = submission.loc[0, "publicScore"]
|
||||
submission_rank = leaderboard.loc[leaderboard["score"] == submission_score].index[0]
|
||||
submission_rank_pct = round(submission_rank / len(leaderboard), 4) * 100
|
||||
print(submission) # submission.to_json(orient="records")
|
||||
|
||||
# best_score = max(submission["publicScore"])
|
||||
# best_rank = leaderboard.loc[leaderboard["score"] == best_score].index[0]
|
||||
submission_score = submission.loc[0, "publicScore"]
|
||||
best_score = max(submission["publicScore"]) # might be min
|
||||
rank = leaderboard.loc[leaderboard["score"] == best_score].index[0]
|
||||
rank_pct = round(rank / len(leaderboard), 4) * 100
|
||||
|
||||
submission_summary = f"""
|
||||
## All History
|
||||
{submission.to_json(orient="records")}
|
||||
## Current
|
||||
Current submission score: {submission_score}, rank: {submission_rank} (top {submission_rank_pct}%);
|
||||
# All histories:
|
||||
{submission.head(5).to_string()}
|
||||
# Current
|
||||
Current submission score: {submission_score}, best score: {best_score}, best rank: {rank} (top {rank_pct}%)
|
||||
"""
|
||||
print(submission_summary)
|
||||
logger.info(submission_summary)
|
||||
return submission_summary
|
||||
|
||||
|
||||
|
|
@ -110,8 +119,6 @@ class KaggleManager(Role):
|
|||
self._set_state(0) # DownloadData, get competition of interest from human, download datasets
|
||||
elif observed == SummarizeAnalysis:
|
||||
self._set_state(1) # SubmitResult, get prediction from MLEngineer and submit it to Kaggle
|
||||
elif observed == SubmitResult:
|
||||
self._set_state(2) # AskReview, ask human for improvement
|
||||
|
||||
async def _act(self):
|
||||
todo = self._rc.todo
|
||||
|
|
@ -127,3 +134,19 @@ class KaggleManager(Role):
|
|||
msg = Message(content=rsp, role="user", cause_by=type(todo))
|
||||
|
||||
return msg
|
||||
|
||||
if __name__ == "__main__":
|
||||
competition, data_desc, requirement = (
|
||||
"titanic",
|
||||
"Training set is train.csv.\nTest set is test.csv. We also include gender_submission.csv, a set of predictions that assume all and only female passengers survive, as an example of what a submission file should look like.",
|
||||
"Run EDA on the train dataset, train a model to predict survival (20% as validation) and save it, predict the test set using saved model, save the test result according to format",
|
||||
)
|
||||
|
||||
summary = "I used Python with pandas for data preprocessing, sklearn's RandomForestClassifier for modeling, and achieved 82.12% accuracy on validation. Predictions saved at '/Users/gary/Desktop/data_agents_opt/workspace/titanic/gender_submission.csv'."
|
||||
|
||||
async def main(requirement: str = requirement):
|
||||
role = KaggleManager(competition=competition, data_desc=data_desc)
|
||||
# await role.run(Message(content="", cause_by=BossRequirement))
|
||||
await role.run(Message(content=summary, cause_by=SummarizeAnalysis))
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
@ -7,55 +7,14 @@ import fire
|
|||
from metagpt.roles import Role
|
||||
from metagpt.actions import Action
|
||||
from metagpt.schema import Message, Task, Plan
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.logs import logger
|
||||
from metagpt.actions.write_plan import WritePlan
|
||||
from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools
|
||||
from metagpt.actions.ml_da_action import AskReview, SummarizeAnalysis, Reflect, ReviewConst, truncate
|
||||
from metagpt.actions.execute_code import ExecutePyCode
|
||||
|
||||
STRUCTURAL_CONTEXT = """
|
||||
## User Requirement
|
||||
{user_requirement}
|
||||
## Current Plan
|
||||
{tasks}
|
||||
## Current Task
|
||||
{current_task}
|
||||
"""
|
||||
|
||||
|
||||
def truncate(result: str, keep_len: int = 1000) -> str:
|
||||
desc = "Truncated to show only the last 1000 characters\n"
|
||||
if result.startswith(desc):
|
||||
result = result[-len(desc) :]
|
||||
|
||||
if len(result) > keep_len:
|
||||
result = result[-keep_len:]
|
||||
|
||||
if not result.startswith(desc):
|
||||
return desc + result
|
||||
return desc
|
||||
|
||||
|
||||
class AskReview(Action):
|
||||
async def run(self, context: List[Message], plan: Plan = None):
|
||||
logger.info("Current overall plan:")
|
||||
logger.info(
|
||||
"\n".join([f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks])
|
||||
)
|
||||
|
||||
logger.info("most recent context:")
|
||||
latest_action = context[-1].cause_by.__name__ if context[-1].cause_by else ""
|
||||
prompt = f"\nPlease review output from {latest_action}:\n" \
|
||||
"If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n" \
|
||||
"If you confirm the output and wish to continue with the current process, type CONFIRM\n" \
|
||||
"If you want to terminate the process, type exit:\n"
|
||||
rsp = input(prompt)
|
||||
|
||||
if rsp.lower() in ("exit"):
|
||||
exit()
|
||||
|
||||
confirmed = rsp.lower() in ("confirm", "yes", "y")
|
||||
|
||||
return rsp, confirmed
|
||||
from metagpt.roles.kaggle_manager import DownloadData, SubmitResult
|
||||
from metagpt.prompts.ml_engineer import STRUCTURAL_CONTEXT
|
||||
|
||||
|
||||
class WriteTaskGuide(Action):
|
||||
|
|
@ -69,13 +28,35 @@ class MLEngineer(Role):
|
|||
):
|
||||
super().__init__(name=name, profile=profile, goal=goal)
|
||||
self._set_react_mode(react_mode="plan_and_act")
|
||||
self._watch([DownloadData, SubmitResult])
|
||||
|
||||
self.plan = Plan(goal=goal)
|
||||
self.use_tools = False
|
||||
self.use_task_guide = False
|
||||
self.execute_code = ExecutePyCode()
|
||||
self.auto_run = auto_run
|
||||
|
||||
# memory for working on each task, discarded each time a task is done
|
||||
self.working_memory = Memory()
|
||||
|
||||
async def _plan_and_act(self):
|
||||
|
||||
### Actions in a multi-agent multi-turn setting ###
|
||||
memories = self.get_memories()
|
||||
if memories:
|
||||
latest_event = memories[-1].cause_by
|
||||
if latest_event == DownloadData:
|
||||
self.plan.context = memories[-1].content
|
||||
elif latest_event == SubmitResult:
|
||||
# get feedback for improvement from human, add to working memory
|
||||
await self._ask_review(trigger=ReviewConst.TASK_REVIEW_TRIGGER)
|
||||
# self reflect on previous plan outcomes and think about how to improve the plan, add to working memory
|
||||
prev_plan_outcomes = memories[-1].content
|
||||
reflection = await Reflect().run(context=prev_plan_outcomes)
|
||||
self.working_memory.add(Message(content=reflection, role="assistant"))
|
||||
|
||||
|
||||
### Common Procedure in both single- and multi-agent setting ###
|
||||
# create initial plan and update until confirmation
|
||||
await self._update_plan()
|
||||
|
||||
|
|
@ -87,7 +68,7 @@ class MLEngineer(Role):
|
|||
code, result, success = await self._write_and_exec_code()
|
||||
|
||||
# ask for acceptance, users can other refuse and change tasks in the plan
|
||||
task_result_confirmed = await self._ask_review()
|
||||
review, task_result_confirmed = await self._ask_review(trigger=ReviewConst.TASK_REVIEW_TRIGGER)
|
||||
|
||||
if success and task_result_confirmed:
|
||||
# tick off this task and record progress
|
||||
|
|
@ -98,7 +79,16 @@ class MLEngineer(Role):
|
|||
|
||||
else:
|
||||
# update plan according to user's feedback and to take on changed tasks
|
||||
await self._update_plan()
|
||||
await self._update_plan(review)
|
||||
|
||||
completed_plan_memory = self.get_useful_memories() # completed plan as a outcome
|
||||
self._rc.memory.add(completed_plan_memory[0]) # add to persistent memory
|
||||
|
||||
summary = await SummarizeAnalysis().run(self.plan)
|
||||
rsp = Message(content=summary, cause_by=SummarizeAnalysis)
|
||||
self._rc.memory.add(rsp)
|
||||
|
||||
return rsp
|
||||
|
||||
async def _write_and_exec_code(self, max_retry: int = 3):
|
||||
task_guide = (
|
||||
|
|
@ -143,23 +133,28 @@ class MLEngineer(Role):
|
|||
|
||||
if "!pip" in code:
|
||||
success = False
|
||||
# if not success:
|
||||
# await self._ask_review()
|
||||
|
||||
counter += 1
|
||||
|
||||
if not success and counter >= max_retry:
|
||||
logger.info("coding failed!")
|
||||
review, _ = await self._ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER)
|
||||
if ReviewConst.CHANGE_WORD in review:
|
||||
counter = 0 # redo the task again with help of human suggestions
|
||||
|
||||
return code, result, success
|
||||
|
||||
async def _ask_review(self):
|
||||
if not self.auto_run:
|
||||
async def _ask_review(self, auto_run: bool = None, trigger: str = ReviewConst.TASK_REVIEW_TRIGGER):
|
||||
auto_run = auto_run or self.auto_run
|
||||
if not auto_run:
|
||||
context = self.get_useful_memories()
|
||||
review, confirmed = await AskReview().run(context=context[-5:], plan=self.plan)
|
||||
review, confirmed = await AskReview().run(context=context[-5:], plan=self.plan, trigger=trigger)
|
||||
if not confirmed:
|
||||
self.working_memory.add(Message(content=review, role="user", cause_by=AskReview))
|
||||
return confirmed
|
||||
return True
|
||||
return review, confirmed
|
||||
return "", True
|
||||
|
||||
async def _update_plan(self, max_tasks: int = 3):
|
||||
async def _update_plan(self, review: str = "", max_tasks: int = 3):
|
||||
plan_confirmed = False
|
||||
while not plan_confirmed:
|
||||
context = self.get_useful_memories()
|
||||
|
|
@ -167,30 +162,36 @@ class MLEngineer(Role):
|
|||
self.working_memory.add(
|
||||
Message(content=rsp, role="assistant", cause_by=WritePlan)
|
||||
)
|
||||
plan_confirmed = await self._ask_review()
|
||||
|
||||
# TODO: precheck plan before asking reviews
|
||||
|
||||
_, plan_confirmed = await self._ask_review(trigger=ReviewConst.TASK_REVIEW_TRIGGER)
|
||||
|
||||
tasks = WritePlan.rsp_to_tasks(rsp)
|
||||
self.plan.add_tasks(tasks)
|
||||
self.working_memory.clear()
|
||||
if len(tasks) == 1 and self.plan.has_task_id(tasks[0].task_id):
|
||||
self.plan.replace_task(tasks[0])
|
||||
else:
|
||||
self.plan.add_tasks(tasks)
|
||||
self.working_memory.clear()
|
||||
|
||||
def get_useful_memories(self) -> List[Message]:
|
||||
"""find useful memories only to reduce context length and improve performance"""
|
||||
|
||||
user_requirement = self.plan.goal
|
||||
data_desc = self.plan.context
|
||||
tasks = json.dumps(
|
||||
[task.dict() for task in self.plan.tasks], indent=4, ensure_ascii=False
|
||||
)
|
||||
current_task = self.plan.current_task.json() if self.plan.current_task else {}
|
||||
context = STRUCTURAL_CONTEXT.format(
|
||||
user_requirement=user_requirement, tasks=tasks, current_task=current_task
|
||||
user_requirement=user_requirement, data_desc=data_desc, tasks=tasks, current_task=current_task
|
||||
)
|
||||
context_msg = [Message(content=context, role="user")]
|
||||
|
||||
return context_msg + self.working_memory.get()
|
||||
|
||||
@property
|
||||
def working_memory(self):
|
||||
return self._rc.memory
|
||||
return context_msg + self.get_working_memories()
|
||||
|
||||
def get_working_memories(self) -> List[Message]:
|
||||
return self.working_memory.get()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -156,7 +156,49 @@ class Plan(BaseModel):
|
|||
|
||||
# Update the task map for quick access to tasks by ID
|
||||
self.task_map = {task.task_id: task for task in self.tasks}
|
||||
|
||||
def reset_task(self, task_id: str):
|
||||
"""
|
||||
Clear code and result of the task based on task_id, and set the task as unfinished.
|
||||
|
||||
Args:
|
||||
task_id (str): The ID of the task to be reset.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if task_id in self.task_map:
|
||||
task = self.task_map[task_id]
|
||||
task.code = ""
|
||||
task.result = ""
|
||||
task.is_finished = False
|
||||
|
||||
def replace_task(self, new_task: Task):
|
||||
"""
|
||||
Replace an existing task with the new input task based on task_id, and reset all tasks depending on it.
|
||||
|
||||
Args:
|
||||
new_task (Task): The new task that will replace an existing one.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if new_task.task_id in self.task_map:
|
||||
# Replace the task in the task map and the task list
|
||||
self.task_map[new_task.task_id] = new_task
|
||||
for i, task in enumerate(self.tasks):
|
||||
if task.task_id == new_task.task_id:
|
||||
self.tasks[i] = new_task
|
||||
break
|
||||
|
||||
# Reset dependent tasks
|
||||
for task in self.tasks:
|
||||
if new_task.task_id in task.dependent_task_ids:
|
||||
self.reset_task(task.task_id)
|
||||
|
||||
def has_task_id(self, task_id: str) -> bool:
|
||||
return task_id in self.task_map
|
||||
|
||||
@property
|
||||
def current_task(self) -> Task:
|
||||
"""Find current task to execute
|
||||
|
|
|
|||
|
|
@ -104,3 +104,42 @@ class TestPlan:
|
|||
finished_tasks = plan.get_finished_tasks()
|
||||
assert len(finished_tasks) == 1
|
||||
assert finished_tasks[0].task_id == "1"
|
||||
|
||||
def test_reset_task_existing(self):
|
||||
plan = Plan(goal="")
|
||||
task = Task(task_id="1", instruction="Do something", code="print('Hello')", result="Hello", finished=True)
|
||||
plan.add_tasks([task])
|
||||
plan.reset_task("1")
|
||||
reset_task = plan.task_map["1"]
|
||||
assert reset_task.code == ""
|
||||
assert reset_task.result == ""
|
||||
assert not reset_task.is_finished
|
||||
|
||||
def test_reset_task_non_existing(self):
|
||||
plan = Plan(goal="")
|
||||
task = Task(task_id="1", instruction="Do something", code="print('Hello')", result="Hello", finished=True)
|
||||
plan.add_tasks([task])
|
||||
plan.reset_task("2") # Task with ID 2 does not exist
|
||||
assert "1" in plan.task_map
|
||||
assert "2" not in plan.task_map
|
||||
|
||||
def test_replace_task_with_dependents(self):
|
||||
plan = Plan(goal="")
|
||||
tasks = [Task(task_id="1", instruction="First Task", finished=True),
|
||||
Task(task_id="2", instruction="Second Task", dependent_task_ids=["1"], finished=True)]
|
||||
plan.add_tasks(tasks)
|
||||
new_task = Task(task_id="1", instruction="Updated First Task")
|
||||
plan.replace_task(new_task)
|
||||
assert plan.task_map["1"].instruction == "Updated First Task"
|
||||
assert not plan.task_map["2"].is_finished # Dependent task should be reset
|
||||
assert plan.task_map["2"].code == ""
|
||||
assert plan.task_map["2"].result == ""
|
||||
|
||||
def test_replace_task_non_existing(self):
|
||||
plan = Plan(goal="")
|
||||
task = Task(task_id="1", instruction="First Task")
|
||||
plan.add_tasks([task])
|
||||
new_task = Task(task_id="2", instruction="New Task")
|
||||
plan.replace_task(new_task) # Task with ID 2 does not exist in plan
|
||||
assert "1" in plan.task_map
|
||||
assert "2" not in plan.task_map
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue