mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
Merge branch 'dev_tool_selection' of https://gitlab.deepwisdomai.com/agents/data_agents_opt into dev_tool_selection
This commit is contained in:
commit
9d39a058aa
36 changed files with 3953 additions and 916 deletions
1004
metagpt/roles/catboost_info/catboost_training.json
Normal file
1004
metagpt/roles/catboost_info/catboost_training.json
Normal file
File diff suppressed because it is too large
Load diff
BIN
metagpt/roles/catboost_info/learn/events.out.tfevents
Normal file
BIN
metagpt/roles/catboost_info/learn/events.out.tfevents
Normal file
Binary file not shown.
1001
metagpt/roles/catboost_info/learn_error.tsv
Normal file
1001
metagpt/roles/catboost_info/learn_error.tsv
Normal file
File diff suppressed because it is too large
Load diff
1001
metagpt/roles/catboost_info/time_left.tsv
Normal file
1001
metagpt/roles/catboost_info/time_left.tsv
Normal file
File diff suppressed because it is too large
Load diff
153
metagpt/roles/kaggle_manager.py
Normal file
153
metagpt/roles/kaggle_manager.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
from typing import Dict, List, Union, Tuple
|
||||
import json
|
||||
import subprocess
|
||||
import os
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.roles import Role
|
||||
from metagpt.actions import Action, BossRequirement
|
||||
from metagpt.actions.ml_da_action import AskReview, SummarizeAnalysis
|
||||
from metagpt.schema import Message, Task, Plan
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
||||
os.environ["KAGGLE_USERNAME"] = CONFIG.kaggle_username
|
||||
os.environ["KAGGLE_KEY"] = CONFIG.kaggle_key
|
||||
|
||||
def run_command(cmd):
|
||||
print(cmd)
|
||||
output = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
if output.returncode != 0:
|
||||
print("Error output:", output.stderr)
|
||||
exit()
|
||||
else:
|
||||
print(output.stdout)
|
||||
return output.stdout
|
||||
|
||||
class DownloadData(Action):
|
||||
|
||||
async def run(self, competition, data_desc="") -> str:
|
||||
data_path = WORKSPACE_ROOT / competition
|
||||
|
||||
output = run_command(f"kaggle competitions list --search {competition}")
|
||||
assert output != "No competitions found", "You must provide the correct competition name"
|
||||
|
||||
run_command(f"kaggle competitions download {competition} --path {WORKSPACE_ROOT}")
|
||||
|
||||
if not os.path.exists(data_path):
|
||||
# if True:
|
||||
# run_command(f"rm -r {data_path / '*'}")
|
||||
run_command(f"unzip -o {WORKSPACE_ROOT / '*.zip'} -d {data_path}") # FIXME: not safe
|
||||
|
||||
file_list = run_command(f"ls {data_path}")
|
||||
|
||||
rsp = f"""
|
||||
Location:
|
||||
Data downloaded at {data_path} folder, including {file_list}
|
||||
Data Description:
|
||||
{data_desc}
|
||||
"""
|
||||
return rsp
|
||||
|
||||
class SubmitResult(Action):
|
||||
PROMPT_TEMPLATE = """
|
||||
# Summary
|
||||
__summary__
|
||||
# Your task
|
||||
Extract the file path for test set prediction from the summary above, output a json following the format:
|
||||
```json
|
||||
{"file_path": str = "the file path, for example, /path/to/the/prediction/file/xxx.csv, /path/to/the/prediction/file/xxx.xlsx"}
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = "", context=None, llm=None) -> str:
|
||||
super().__init__(name, context, llm)
|
||||
|
||||
async def _parse_submit_file_path(self, context) -> str:
|
||||
prompt = self.PROMPT_TEMPLATE.replace("__summary__", context)
|
||||
rsp = await self._aask(prompt)
|
||||
rsp = CodeParser.parse_code(block=None, text=rsp)
|
||||
file_path = json.loads(rsp)["file_path"]
|
||||
return file_path
|
||||
|
||||
async def run(self, competition, submit_message="") -> str:
|
||||
submit_file_path = await self._parse_submit_file_path(submit_message)
|
||||
|
||||
data_path = WORKSPACE_ROOT / competition
|
||||
submit_message = submit_message.replace("'", "")
|
||||
|
||||
run_command(f"kaggle competitions submit {competition} -f {submit_file_path} -m '{submit_message}'")
|
||||
run_command(f"kaggle competitions leaderboard --show --csv {competition} > {data_path / 'leaderboard.csv'}")
|
||||
run_command(f"kaggle competitions submissions --csv {competition} > {data_path / 'submission.csv'}")
|
||||
|
||||
leaderboard = pd.read_csv(data_path / 'leaderboard.csv')
|
||||
submission = pd.read_csv(data_path / 'submission.csv')
|
||||
print(submission) # submission.to_json(orient="records")
|
||||
|
||||
submission_score = submission.loc[0, "publicScore"]
|
||||
best_score = max(submission["publicScore"]) # might be min
|
||||
rank = leaderboard.loc[leaderboard["score"] == best_score].index[0]
|
||||
rank_pct = round(rank / len(leaderboard), 4) * 100
|
||||
|
||||
submission_summary = f"""
|
||||
# All histories:
|
||||
{submission.head(5).to_string()}
|
||||
# Current
|
||||
Current submission score: {submission_score}, best score: {best_score}, best rank: {rank} (top {rank_pct}%)
|
||||
"""
|
||||
logger.info(submission_summary)
|
||||
return submission_summary
|
||||
|
||||
|
||||
class KaggleManager(Role):
|
||||
def __init__(
|
||||
self, name="ABC", profile="KaggleManager", goal="", competition="titanic", data_desc=""
|
||||
):
|
||||
super().__init__(name=name, profile=profile, goal=goal)
|
||||
self._init_actions([DownloadData, SubmitResult])
|
||||
self._watch([BossRequirement, SummarizeAnalysis])
|
||||
self.competition = competition
|
||||
self.data_desc = data_desc # currently passed in, later can be scrapped down from web by another Role
|
||||
|
||||
async def _think(self):
|
||||
observed = self.get_memories()[-1].cause_by
|
||||
if observed == BossRequirement:
|
||||
self._set_state(0) # DownloadData, get competition of interest from human, download datasets
|
||||
elif observed == SummarizeAnalysis:
|
||||
self._set_state(1) # SubmitResult, get prediction from MLEngineer and submit it to Kaggle
|
||||
|
||||
async def _act(self):
|
||||
todo = self._rc.todo
|
||||
logger.info(f"{self._setting}: ready to {self._rc.todo}")
|
||||
|
||||
if isinstance(todo, DownloadData):
|
||||
rsp = await todo.run(self.competition, self.data_desc)
|
||||
|
||||
elif isinstance(todo, SubmitResult):
|
||||
submit_message = self.get_memories()[-1].content # use analysis summary from MLEngineer as submission message
|
||||
rsp = await todo.run(competition=self.competition, submit_message=submit_message)
|
||||
|
||||
msg = Message(content=rsp, role="user", cause_by=type(todo))
|
||||
|
||||
return msg
|
||||
|
||||
if __name__ == "__main__":
|
||||
competition, data_desc, requirement = (
|
||||
"titanic",
|
||||
"Training set is train.csv.\nTest set is test.csv. We also include gender_submission.csv, a set of predictions that assume all and only female passengers survive, as an example of what a submission file should look like.",
|
||||
"Run EDA on the train dataset, train a model to predict survival (20% as validation) and save it, predict the test set using saved model, save the test result according to format",
|
||||
)
|
||||
|
||||
summary = "I used Python with pandas for data preprocessing, sklearn's RandomForestClassifier for modeling, and achieved 82.12% accuracy on validation. Predictions saved at '/Users/gary/Desktop/data_agents_opt/workspace/titanic/gender_submission.csv'."
|
||||
|
||||
async def main(requirement: str = requirement):
|
||||
role = KaggleManager(competition=competition, data_desc=data_desc)
|
||||
# await role.run(Message(content="", cause_by=BossRequirement))
|
||||
await role.run(Message(content=summary, cause_by=SummarizeAnalysis))
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
@ -1,90 +1,32 @@
|
|||
from typing import List
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
import fire
|
||||
import nbformat
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.debug_code import DebugCode
|
||||
from metagpt.actions.execute_code import ExecutePyCode
|
||||
from metagpt.actions.ml_da_action import AskReview, SummarizeAnalysis, Reflect, ReviewConst
|
||||
from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools
|
||||
from metagpt.actions.write_code_steps import WriteCodeSteps
|
||||
from metagpt.actions.write_plan import WritePlan
|
||||
from metagpt.actions.write_plan import update_plan_from_rsp, precheck_update_plan_from_rsp
|
||||
from metagpt.const import DATA_PATH, PROJECT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.prompts.ml_engineer import STRUCTURAL_CONTEXT
|
||||
from metagpt.prompts.ml_engineer import (
|
||||
GEN_DATA_DESC_PROMPT,
|
||||
UPDATE_DATA_COLUMNS,
|
||||
PRINT_DATA_COLUMNS
|
||||
)
|
||||
from metagpt.roles import Role
|
||||
from metagpt.roles.kaggle_manager import DownloadData, SubmitResult
|
||||
from metagpt.schema import Message, Plan
|
||||
from metagpt.utils.common import CodeParser, remove_comments, create_func_config
|
||||
from metagpt.actions.debug_code import DebugCode
|
||||
|
||||
STRUCTURAL_CONTEXT = """
|
||||
## User Requirement
|
||||
{user_requirement}
|
||||
## Dataset Description
|
||||
{data_desc}
|
||||
## Current Plan
|
||||
{tasks}
|
||||
## Current Task
|
||||
{current_task}
|
||||
## Packages Installed
|
||||
pandas
|
||||
numpy
|
||||
"""
|
||||
|
||||
|
||||
# scikit-learn
|
||||
# lightgbm
|
||||
# xgboost
|
||||
# catboost
|
||||
|
||||
def truncate(result: str, keep_len: int = 1000) -> str:
|
||||
desc = "Truncated to show only the last 1000 characters\n"
|
||||
if result.startswith(desc):
|
||||
result = result[-len(desc):]
|
||||
|
||||
if len(result) > keep_len:
|
||||
result = result[-keep_len:]
|
||||
|
||||
if not result.startswith(desc):
|
||||
return desc + result
|
||||
return desc
|
||||
|
||||
|
||||
def remove_escape_and_color_codes(input_str):
|
||||
# 使用正则表达式去除转义字符和颜色代码
|
||||
pattern = re.compile(r'\x1b\[[0-9;]*[mK]')
|
||||
result = pattern.sub('', input_str)
|
||||
return result
|
||||
|
||||
|
||||
class AskReview(Action):
|
||||
async def run(self, context: List[Message], plan: Plan = None):
|
||||
logger.info("Current overall plan:")
|
||||
logger.info(
|
||||
"\n".join([f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks])
|
||||
)
|
||||
|
||||
logger.info("most recent context:")
|
||||
latest_action = context[-1].cause_by.__name__ if context[-1].cause_by else ""
|
||||
prompt = f"\nPlease review output from {latest_action}:\n" \
|
||||
"If you want to change a task in the plan, say 'change task task_id, ... (things to change)'\n" \
|
||||
"If you confirm the output and wish to continue with the current process, type CONFIRM\n" \
|
||||
"If you want to terminate the process, type exit:\n"
|
||||
rsp = input(prompt)
|
||||
|
||||
if rsp.lower() in ("exit"):
|
||||
exit()
|
||||
|
||||
confirmed = rsp.lower() in ("confirm", "yes", "y")
|
||||
|
||||
return rsp, confirmed
|
||||
from metagpt.utils.common import remove_comments, create_func_config
|
||||
from metagpt.utils.save_code import save_code_file
|
||||
|
||||
|
||||
class UpdateDataColumns(Action):
|
||||
|
|
@ -100,50 +42,95 @@ class UpdateDataColumns(Action):
|
|||
|
||||
class MLEngineer(Role):
|
||||
def __init__(
|
||||
self, name="ABC", profile="MLEngineer", goal="", auto_run: bool = False,
|
||||
self, name="ABC", profile="MLEngineer", goal="", auto_run: bool = False
|
||||
):
|
||||
super().__init__(name=name, profile=profile, goal=goal)
|
||||
self._set_react_mode(react_mode="plan_and_act")
|
||||
self._watch([DownloadData, SubmitResult])
|
||||
|
||||
self.plan = Plan(goal=goal)
|
||||
self.use_tools = True
|
||||
self.use_code_steps = True
|
||||
self.use_tools = False
|
||||
self.use_code_steps = False
|
||||
self.execute_code = ExecutePyCode()
|
||||
self.auto_run = auto_run
|
||||
self.data_desc = {}
|
||||
|
||||
|
||||
# memory for working on each task, discarded each time a task is done
|
||||
self.working_memory = Memory()
|
||||
|
||||
async def _plan_and_act(self):
|
||||
|
||||
### Actions in a multi-agent multi-turn setting ###
|
||||
memories = self.get_memories()
|
||||
if memories:
|
||||
latest_event = memories[-1].cause_by
|
||||
if latest_event == DownloadData:
|
||||
self.plan.context = memories[-1].content
|
||||
elif latest_event == SubmitResult:
|
||||
# self reflect on previous plan outcomes and think about how to improve the plan, add to working memory
|
||||
await self._reflect()
|
||||
|
||||
# get feedback for improvement from human, add to working memory
|
||||
await self._ask_review(trigger=ReviewConst.TASK_REVIEW_TRIGGER)
|
||||
|
||||
### Common Procedure in both single- and multi-agent setting ###
|
||||
# create initial plan and update until confirmation
|
||||
await self._update_plan()
|
||||
|
||||
|
||||
while self.plan.current_task:
|
||||
task = self.plan.current_task
|
||||
logger.info(f"ready to take on task {task}")
|
||||
|
||||
|
||||
# take on current task
|
||||
code, result, success, code_steps = await self._write_and_exec_code()
|
||||
|
||||
code, result, success = await self._write_and_exec_code()
|
||||
|
||||
# ask for acceptance, users can other refuse and change tasks in the plan
|
||||
task_result_confirmed = await self._ask_review()
|
||||
|
||||
if success and task_result_confirmed:
|
||||
review, task_result_confirmed = await self._ask_review(trigger=ReviewConst.TASK_REVIEW_TRIGGER)
|
||||
|
||||
if self.auto_run:
|
||||
# if human confirms the task result, then we deem the task completed, regardless of whether the code run succeeds;
|
||||
# if auto mode, then the code run has to succeed for the task to be considered completed
|
||||
task_result_confirmed = success
|
||||
|
||||
if task_result_confirmed:
|
||||
# tick off this task and record progress
|
||||
task.code = code
|
||||
task.result = result
|
||||
task.code_steps = code_steps
|
||||
self.plan.finish_current_task()
|
||||
self.working_memory.clear()
|
||||
|
||||
if self.use_tools:
|
||||
success, new_code = await self._update_data_columns()
|
||||
if success:
|
||||
task.code = task.code + "\n\n" + new_code
|
||||
|
||||
confirmed_and_more = (ReviewConst.CONTINUE_WORD[0] in review.lower()
|
||||
and review.lower() not in ReviewConst.CONTINUE_WORD[0]) # "confirm, ... (more content, such as changing downstream tasks)"
|
||||
if confirmed_and_more:
|
||||
self.working_memory.add(Message(content=review, role="user", cause_by=AskReview))
|
||||
await self._update_plan(review)
|
||||
|
||||
elif "redo" in review:
|
||||
# Ask the Role to redo this task with help of review feedback,
|
||||
# useful when the code run is successful but the procedure or result is not what we want
|
||||
continue
|
||||
|
||||
else:
|
||||
# update plan according to user's feedback and to take on changed tasks
|
||||
await self._update_plan()
|
||||
await self._update_plan(review)
|
||||
|
||||
time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
|
||||
self.execute_code.save_notebook(f"{DATA_PATH}/notebooks/ml_{time}.ipynb")
|
||||
completed_plan_memory = self.get_useful_memories() # completed plan as a outcome
|
||||
self._rc.memory.add(completed_plan_memory[0]) # add to persistent memory
|
||||
|
||||
summary = await SummarizeAnalysis().run(self.plan)
|
||||
rsp = Message(content=summary, cause_by=SummarizeAnalysis)
|
||||
self._rc.memory.add(rsp)
|
||||
|
||||
# save code using datetime.now or keywords related to the goal of your project (plan.goal).
|
||||
project_record = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
save_code_file(name=project_record, code_context=self.execute_code.nb, file_format="ipynb")
|
||||
return rsp
|
||||
|
||||
async def _update_data_columns(self):
|
||||
rsp = await UpdateDataColumns().run(self.plan)
|
||||
is_update, code = rsp["is_update"], rsp["code"]
|
||||
|
|
@ -155,34 +142,36 @@ class MLEngineer(Role):
|
|||
return success, code
|
||||
|
||||
async def _write_and_exec_code(self, max_retry: int = 3):
|
||||
code_steps = (
|
||||
self.plan.current_task.code_steps = (
|
||||
await WriteCodeSteps().run(self.plan)
|
||||
if self.use_code_steps
|
||||
else ""
|
||||
)
|
||||
|
||||
|
||||
counter = 0
|
||||
improve_code = ""
|
||||
success = False
|
||||
debug_context = []
|
||||
|
||||
|
||||
while not success and counter < max_retry:
|
||||
context = self.get_useful_memories()
|
||||
|
||||
if counter > 0:
|
||||
improve_code = await DebugCode().run(plan=self.plan.current_task.instruction,
|
||||
# print("*" * 10)
|
||||
# print(context)
|
||||
# print("*" * 10)
|
||||
# breakpoint()
|
||||
if counter > 0 and self.use_tools:
|
||||
code = await DebugCode().run(
|
||||
plan=self.plan.current_task.instruction,
|
||||
code=code,
|
||||
runtime_result=self.working_memory.get(),
|
||||
context=debug_context)
|
||||
|
||||
if improve_code != "":
|
||||
code = improve_code
|
||||
logger.info(f"new code \n{improve_code}")
|
||||
context=debug_context
|
||||
)
|
||||
logger.info(f"new code \n{code}")
|
||||
cause_by = DebugCode
|
||||
elif not self.use_tools or self.plan.current_task.task_type == "other":
|
||||
logger.info("Write code with pure generation")
|
||||
code = await WriteCodeByGenerate().run(
|
||||
context=context, plan=self.plan, code_steps=code_steps, temperature=0.0
|
||||
context=context, plan=self.plan, temperature=0.0
|
||||
)
|
||||
debug_context = [self.get_useful_memories(task_exclude_field={'result', 'code_steps'})[0]]
|
||||
cause_by = WriteCodeByGenerate
|
||||
|
|
@ -192,47 +181,46 @@ class MLEngineer(Role):
|
|||
tool_context, code = await WriteCodeWithTools(schema_path=schema_path).run(
|
||||
context=context,
|
||||
plan=self.plan,
|
||||
code_steps=code_steps,
|
||||
column_info=self.data_desc.get("column_info", ""),
|
||||
)
|
||||
debug_context = tool_context
|
||||
cause_by = WriteCodeWithTools
|
||||
|
||||
|
||||
self.working_memory.add(
|
||||
Message(content=code, role="assistant", cause_by=cause_by)
|
||||
)
|
||||
|
||||
# debug on code, run on runcode with finished code and new_df
|
||||
# runcode = code_context + "\n\n" + code
|
||||
|
||||
result, success = await self.execute_code.run(code)
|
||||
# truncated the result
|
||||
print(truncate(result))
|
||||
|
||||
print(result)
|
||||
self.working_memory.add(
|
||||
Message(content=truncate(remove_escape_and_color_codes(result)), role="user", cause_by=ExecutePyCode)
|
||||
Message(content=result, role="user", cause_by=ExecutePyCode)
|
||||
)
|
||||
|
||||
|
||||
if "!pip" in code:
|
||||
success = False
|
||||
# if not success:
|
||||
# await self._ask_review()
|
||||
|
||||
|
||||
counter += 1
|
||||
|
||||
return code, result, success, code_steps
|
||||
|
||||
async def _ask_review(self):
|
||||
if not self.auto_run:
|
||||
|
||||
if not success and counter >= max_retry:
|
||||
logger.info("coding failed!")
|
||||
review, _ = await self._ask_review(auto_run=False, trigger=ReviewConst.CODE_REVIEW_TRIGGER)
|
||||
if ReviewConst.CHANGE_WORD[0] in review:
|
||||
counter = 0 # redo the task again with help of human suggestions
|
||||
|
||||
return code, result, success
|
||||
|
||||
async def _ask_review(self, auto_run: bool = None, trigger: str = ReviewConst.TASK_REVIEW_TRIGGER):
|
||||
auto_run = auto_run or self.auto_run
|
||||
if not auto_run:
|
||||
context = self.get_useful_memories()
|
||||
review, confirmed = await AskReview().run(context=context[-5:], plan=self.plan)
|
||||
review, confirmed = await AskReview().run(context=context[-5:], plan=self.plan, trigger=trigger)
|
||||
if not confirmed:
|
||||
self.working_memory.add(Message(content=review, role="user", cause_by=AskReview))
|
||||
return confirmed
|
||||
return True
|
||||
|
||||
async def _update_plan(self, max_tasks: int = 3):
|
||||
return review, confirmed
|
||||
return "", True
|
||||
|
||||
async def _update_plan(self, review: str = "", max_tasks: int = 3, max_retries: int = 3):
|
||||
plan_confirmed = False
|
||||
|
||||
while not plan_confirmed:
|
||||
context = self.get_useful_memories()
|
||||
rsp = await WritePlan().run(
|
||||
|
|
@ -241,43 +229,57 @@ class MLEngineer(Role):
|
|||
self.working_memory.add(
|
||||
Message(content=rsp, role="assistant", cause_by=WritePlan)
|
||||
)
|
||||
plan_confirmed = await self._ask_review()
|
||||
|
||||
new_tasks = WritePlan.rsp_to_tasks(rsp)
|
||||
logger.debug(len(self.plan.tasks))
|
||||
logger.debug(len(new_tasks))
|
||||
## fixme: 能重复执行多轮重新plan,但应该有更优处理逻辑
|
||||
## fixme: do not overwrite original tasks
|
||||
tasks = self.plan.tasks + new_tasks
|
||||
|
||||
self.plan.add_tasks(tasks)
|
||||
|
||||
# precheck plan before asking reviews
|
||||
is_plan_valid, error = precheck_update_plan_from_rsp(rsp, self.plan)
|
||||
if not is_plan_valid and max_retries > 0:
|
||||
error_msg = f"The generated plan is not valid with error: {error}, try regenerating, remember to generate either the whole plan or the single changed task only"
|
||||
logger.warning(error_msg)
|
||||
self.working_memory.add(Message(content=error_msg, role="assistant", cause_by=WritePlan))
|
||||
max_retries -= 1
|
||||
continue
|
||||
|
||||
_, plan_confirmed = await self._ask_review(trigger=ReviewConst.TASK_REVIEW_TRIGGER)
|
||||
|
||||
update_plan_from_rsp(rsp, self.plan)
|
||||
|
||||
self.working_memory.clear()
|
||||
|
||||
def get_useful_memories(self, task_exclude_field: set = None) -> List[Message]:
|
||||
async def _reflect(self):
|
||||
context = self.get_memories()
|
||||
context = "\n".join([str(msg) for msg in context])
|
||||
# print("*" * 10)
|
||||
# print(context)
|
||||
# print("*" * 10)
|
||||
reflection = await Reflect().run(context=context)
|
||||
self.working_memory.add(Message(content=reflection, role="assistant"))
|
||||
self.working_memory.add(Message(content=Reflect.REWRITE_PLAN_INSTRUCTION, role="user"))
|
||||
|
||||
def get_useful_memories(self, task_exclude_field=None) -> List[Message]:
|
||||
"""find useful memories only to reduce context length and improve performance"""
|
||||
# TODO dataset description , code steps
|
||||
if task_exclude_field is None:
|
||||
# Shorten the context as we don't need code steps after we get the codes.
|
||||
# This doesn't affect current_task below, which should hold the code steps
|
||||
task_exclude_field = {'code_steps'}
|
||||
user_requirement = self.plan.goal
|
||||
tasks = json.dumps(
|
||||
[task.dict(exclude=task_exclude_field) for task in self.plan.tasks], indent=4, ensure_ascii=False
|
||||
)
|
||||
data_desc = self.plan.context
|
||||
tasks = [task.dict(exclude=task_exclude_field) for task in self.plan.tasks]
|
||||
tasks = json.dumps(tasks, indent=4, ensure_ascii=False)
|
||||
current_task = self.plan.current_task.json() if self.plan.current_task else {}
|
||||
context = STRUCTURAL_CONTEXT.format(
|
||||
user_requirement=user_requirement,
|
||||
data_desc=self.data_desc,
|
||||
tasks=tasks,
|
||||
current_task=current_task
|
||||
user_requirement=user_requirement, data_desc=data_desc, tasks=tasks, current_task=current_task
|
||||
)
|
||||
context_msg = [Message(content=context, role="user")]
|
||||
|
||||
return context_msg + self.get_working_memories()
|
||||
|
||||
return context_msg + self.working_memory.get()
|
||||
|
||||
@property
|
||||
def working_memory(self):
|
||||
return self._rc.memory
|
||||
def get_working_memories(self) -> List[Message]:
|
||||
return self.working_memory.get()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
# requirement = "Run data analysis on sklearn Diabetes dataset, include a plot"
|
||||
# requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy"
|
||||
# requirement = "Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue