mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
1. avoid circular reference
2. add greedy
This commit is contained in:
parent
e07ed0df8b
commit
a6f71f4498
2 changed files with 23 additions and 23 deletions
9
expo/Greedy.py
Normal file
9
expo/Greedy.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from expo.MCTS import MCTS
|
||||
|
||||
|
||||
class Greedy(MCTS):
|
||||
def best_child(self):
|
||||
if len(self.children) == 0:
|
||||
return self.root_node
|
||||
all_children = [child for children in self.children.values() for child in children]
|
||||
return max(all_children, key=lambda x: x.normalized_reward.get("dev_score", 0))
|
||||
37
expo/MCTS.py
37
expo/MCTS.py
|
|
@ -143,6 +143,7 @@ class Node:
|
|||
role.state_saved = False
|
||||
role.change_next_instruction(self.action)
|
||||
mcts_logger.log("MCTS", f"Saving new role: {role.node_id}")
|
||||
role = role.model_copy()
|
||||
role.save_state(static_save=True)
|
||||
|
||||
async def expand(self, max_children):
|
||||
|
|
@ -165,18 +166,6 @@ class Node:
|
|||
node.save_new_role(new_role)
|
||||
self.add_child(node)
|
||||
|
||||
# def evaluate_test(self):
|
||||
# prediction_fpath = os.path.join(self.state["work_dir"], self.state["task"], "predictions.csv")
|
||||
# predictions = pd.read_csv(prediction_fpath)["target"]
|
||||
# # copy predictions.csv to the node_dir
|
||||
# predictions_node_fpath = os.path.join(self.state["node_dir"], "Node-{self.id}-predictions.csv")
|
||||
# predictions.to_csv(predictions_node_fpath, index=False)
|
||||
# # load test_target.csv
|
||||
# split_datasets_dir = self.state["datasets_dir"]
|
||||
# gt = pd.read_csv(os.path.join(split_datasets_dir["test_target"]))["target"]
|
||||
# metric = self.state["dataset_config"]["metric"]
|
||||
# return evaluate_score(predictions, gt, metric)
|
||||
|
||||
def evaluate_prediction(self, split):
|
||||
pred_path = os.path.join(self.state["work_dir"], self.state["task"], f"{split}_predictions.csv")
|
||||
pred_node_path = os.path.join(self.state["node_dir"], f"Node-{self.id}-{split}_predictions.csv")
|
||||
|
|
@ -331,14 +320,10 @@ class MCTS:
|
|||
self.children[root] = []
|
||||
reward = await self.simulate(root, role)
|
||||
self.backpropagate(root, reward)
|
||||
children = await self.expand(root)
|
||||
# 目前是随机选择1个,后续可以改成多个
|
||||
first_leaf = random.choice(children)
|
||||
reward = await self.simulate(first_leaf)
|
||||
self.backpropagate(first_leaf, reward)
|
||||
node, reward = await self.expand_and_simulate(root)
|
||||
# self.backpropagate(node, reward)
|
||||
else:
|
||||
root = self.root_node
|
||||
# 后续迭代:使用UCT进行选择,expand并模拟和反向传播
|
||||
for _ in range(rollouts): # number of rollouts
|
||||
mcts_logger.log("MCTS", f"Start the next rollout {_+1}")
|
||||
node = self.select(root)
|
||||
|
|
@ -350,13 +335,19 @@ class MCTS:
|
|||
mcts_logger.log("MCTS", f"Terminal node's reward: {reward}")
|
||||
self.backpropagate(node, reward)
|
||||
else:
|
||||
if node.visited > 0:
|
||||
children = await self.expand(node)
|
||||
node = random.choice(children)
|
||||
reward = await self.simulate(node)
|
||||
self.backpropagate(node, reward)
|
||||
node, reward = await self.expand_and_simulate(node)
|
||||
# self.backpropagate(node, reward)
|
||||
return self.best_path(root)
|
||||
|
||||
async def expand_and_simulate(self, node):
|
||||
# Expand and randomly select a child node, then simulate it
|
||||
if node.visited > 0:
|
||||
children = await self.expand(node)
|
||||
node = random.choice(children)
|
||||
reward = await self.simulate(node)
|
||||
self.backpropagate(node, reward)
|
||||
return node, reward
|
||||
|
||||
def load_tree(self):
|
||||
def load_children_node(node):
|
||||
mcts_logger.log("MCTS", f"Load node {node.id}'s child: {node.children}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue