mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 11:26:23 +02:00
format code
This commit is contained in:
parent
fcd1ba66a6
commit
ab8a1d6824
17 changed files with 433 additions and 396 deletions
|
|
@ -1,50 +1,58 @@
|
|||
import yaml
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
|
||||
# from nbclient import NotebookClient
|
||||
from nbformat.notebooknode import NotebookNode
|
||||
import nbformat
|
||||
from pathlib import Path
|
||||
from loguru import logger as _logger
|
||||
from datetime import datetime
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import nbformat
|
||||
import yaml
|
||||
from loguru import logger as _logger
|
||||
|
||||
# from nbclient import NotebookClient
|
||||
from nbformat.notebooknode import NotebookNode
|
||||
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
def load_data_config(file_path="data.yaml"):
|
||||
with open(file_path, 'r') as stream:
|
||||
with open(file_path, "r") as stream:
|
||||
data_config = yaml.safe_load(stream)
|
||||
return data_config
|
||||
|
||||
|
||||
DATA_CONFIG = load_data_config()
|
||||
|
||||
|
||||
def get_mcts_logger():
|
||||
print_level = "INFO"
|
||||
print_level2 = "MCTS"
|
||||
logfile_level="MCTS"
|
||||
logfile_level = "MCTS"
|
||||
name: str = None
|
||||
current_date = datetime.now()
|
||||
formatted_date = current_date.strftime("%Y%m%d")
|
||||
log_name = f"{name}_{formatted_date}" if name else formatted_date # name a log with prefix name
|
||||
|
||||
_logger.remove()
|
||||
new_level = _logger.level(logfile_level, color="<green>", no=25)
|
||||
_logger.level(logfile_level, color="<green>", no=25)
|
||||
_logger.add(sys.stderr, level=print_level)
|
||||
_logger.add(sys.stderr, level=print_level2)
|
||||
_logger.add(Path(DATA_CONFIG["work_dir"]) / DATA_CONFIG["role_dir"] / f"{log_name}.txt", level=logfile_level)
|
||||
_logger.propagate = False
|
||||
return _logger
|
||||
|
||||
|
||||
mcts_logger = get_mcts_logger()
|
||||
|
||||
|
||||
def get_exp_pool_path(task_name, data_config, pool_name="analysis_pool"):
|
||||
datasets_dir = data_config['datasets_dir']
|
||||
if task_name in data_config['datasets']:
|
||||
dataset = data_config['datasets'][task_name]
|
||||
data_path = os.path.join(datasets_dir, dataset['dataset'])
|
||||
datasets_dir = data_config["datasets_dir"]
|
||||
if task_name in data_config["datasets"]:
|
||||
dataset = data_config["datasets"][task_name]
|
||||
data_path = os.path.join(datasets_dir, dataset["dataset"])
|
||||
else:
|
||||
raise ValueError(f"Dataset {task_name} not found in config file. Available datasets: {data_config['datasets'].keys()}")
|
||||
raise ValueError(
|
||||
f"Dataset {task_name} not found in config file. Available datasets: {data_config['datasets'].keys()}"
|
||||
)
|
||||
exp_pool_path = os.path.join(data_path, f"{pool_name}.json")
|
||||
return exp_pool_path
|
||||
|
||||
|
|
@ -60,7 +68,6 @@ def change_plan(role, plan):
|
|||
if not finished:
|
||||
tasks[i].plan = plan
|
||||
return finished
|
||||
|
||||
|
||||
|
||||
def is_cell_to_delete(cell: NotebookNode) -> bool:
|
||||
|
|
@ -82,12 +89,14 @@ def process_cells(nb: NotebookNode) -> NotebookNode:
|
|||
nb["cells"] = new_cells
|
||||
return nb
|
||||
|
||||
|
||||
def save_notebook(role: Role, save_dir: str = "", name: str = ""):
|
||||
save_dir = Path(save_dir)
|
||||
nb = process_cells(role.execute_code.nb)
|
||||
file_path = save_dir / f"{name}.ipynb"
|
||||
nbformat.write(nb, file_path)
|
||||
|
||||
|
||||
async def load_execute_notebook(role):
|
||||
tasks = role.planner.plan.tasks
|
||||
codes = [task.code for task in tasks if task.code]
|
||||
|
|
@ -99,6 +108,7 @@ async def load_execute_notebook(role):
|
|||
print("Finish executing the loaded notebook")
|
||||
return executor
|
||||
|
||||
|
||||
def clean_json_from_rsp(text):
|
||||
pattern = r"```json(.*?)```"
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
|
|
@ -106,4 +116,4 @@ def clean_json_from_rsp(text):
|
|||
json_str = "\n".join(matches)
|
||||
return json_str
|
||||
else:
|
||||
return ""
|
||||
return ""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue