Insight pool can now grow dynamically

This commit is contained in:
Yizhou Chi 2024-10-11 14:48:56 +08:00
parent ae12d73747
commit 56e7a08a1c
3 changed files with 20 additions and 13 deletions

View file

@ -155,18 +155,15 @@ class Node:
role = role.model_copy()
role.save_state(static_save=True)
async def expand(self, max_children, use_fixed_insights):
async def expand(self, max_children: int, instruction_generator: InstructionGenerator):
if self.is_fully_expanded():
return
insight_geneartor = InstructionGenerator()
role = self.load_role()
original_instruction = role.get_next_instruction()
insights = await insight_geneartor.generate_new_instructions(
insights = await instruction_generator.generate_new_instructions(
task_id=role.start_task_id + 1,
original_instruction=original_instruction,
max_num=max_children,
file_path=self.state["exp_pool_path"],
use_fixed_insights=use_fixed_insights,
)
new_state = self.state.copy()
new_state["start_task_id"] += 1
@ -249,6 +246,8 @@ class MCTS:
c_explore: float = 1.4
c_unvisited: float = 0.8
node_order: list = []
# insight generator
instruction_generator: InstructionGenerator = None
def __init__(self, root_node, max_depth, use_fixed_insights):
self.root_node = root_node
@ -272,7 +271,7 @@ class MCTS:
return max(all_children, key=uct)
async def expand(self, node: Node, max_children=5):
await node.expand(max_children, self.use_fixed_insights)
await node.expand(max_children, self.instruction_generator)
if node not in self.children or not self.children[node]:
self.children[node] = node.children
return node.children
@ -284,6 +283,7 @@ class MCTS:
node = random.choice(node.children)
reward = await node.run_node(role)
mcts_logger.log("MCTS", f"Simulated node's reward: {reward}")
return reward
def backpropagate(self, node: Node, reward):
@ -344,6 +344,10 @@ class MCTS:
async def search(self, state, rollouts, load_tree=False, reflection=False):
role, root = initialize_di_root_node(state, reflection=reflection)
self.root_node = root
self.instruction_generator = InstructionGenerator(
file_path=state["exp_pool_path"], use_fixed_insights=self.use_fixed_insights
)
tree_loaded = False
if load_tree:
tree_loaded = self.load_tree()

View file

@ -2,6 +2,7 @@ import json
import os
import random
from expo.insights.solution_designer import SolutionDesigner
from expo.utils import clean_json_from_rsp, load_data_config, mcts_logger
from metagpt.llm import LLM
from metagpt.schema import Message
@ -32,6 +33,12 @@ DATA_CONFIG = load_data_config()
class InstructionGenerator:
data_config = DATA_CONFIG
def __init__(self, file_path, use_fixed_insights=False):
    """Set up the generator and load its analysis pool once.

    Args:
        file_path: Path forwarded to ``load_analysis_pool``; also kept
            on the instance as ``file_path``.
        use_fixed_insights: Flag forwarded to ``load_analysis_pool``;
            also kept on the instance. Defaults to ``False``.
    """
    # Remember the construction arguments on the instance.
    self.file_path, self.use_fixed_insights = file_path, use_fixed_insights
    # Load the pool up front and hold it on the instance.
    loaded_pool = self.load_analysis_pool(file_path, use_fixed_insights)
    self.analysis_pool = loaded_pool
    self.proposer = SolutionDesigner()
@staticmethod
def load_json_data(json_dir):
with open(json_dir, "r") as file:
@ -83,13 +90,8 @@ class InstructionGenerator:
data = [item for item in data if int(item["task_id"]) == int(task_id)]
return data
@staticmethod
async def generate_new_instructions(
task_id, original_instruction, max_num, file_path, ext_info=None, use_fixed_insights=False
):
data = InstructionGenerator.load_analysis_pool(
file_path, task_id=task_id, use_fixed_insights=use_fixed_insights
)
async def generate_new_instructions(self, task_id, original_instruction, max_num, ext_info=None):
data = self.analysis_pool
new_instructions = []
if len(data) == 0:
mcts_logger.log("MCTS", f"No insights available for task {task_id}")

View file

@ -21,6 +21,7 @@ The insights should be proposed based on the dataset description with different
Each task type should have at least 5 insights.
Make sure each method is diverse enough and can be implemented separately.
Be specific about models' choices, ensemble and tuning techniques, and preprocessing & feature engineering techniques.
Your model choices should be advanced enough to be helpful.
# Format
```json