1. dynamically add insight

2. insight from scratch in real time
This commit is contained in:
Yizhou Chi 2024-10-11 18:37:35 +08:00
parent f7374c03af
commit eda9322361
6 changed files with 115 additions and 18 deletions

View file

@ -235,7 +235,8 @@ class Node:
score_dict = {k: normalize_score(v) for k, v in score_dict.items()}
self.normalized_reward = score_dict
return score_dict
result_dict = role.get_solution()
return score_dict, result_dict
class MCTS:
@ -281,7 +282,7 @@ class MCTS:
mcts_logger.log("MCTS", f"Start simulating node {node.id}:")
while node.children:
node = random.choice(node.children)
reward = await node.run_node(role)
reward, result_dict = await node.run_node(role)
mcts_logger.log("MCTS", f"Simulated node's reward: {reward}")
return reward
@ -341,12 +342,17 @@ class MCTS:
scores["test_raw"].append(node.raw_reward["test_score"])
return scores
async def search(self, state, rollouts, load_tree=False, reflection=False):
async def search(self, state, args):
reflection = args.reflection
load_tree = args.load_tree
rollouts = args.rollouts
from_scratch = args.from_scratch
role, root = initialize_di_root_node(state, reflection=reflection)
self.root_node = root
self.instruction_generator = InstructionGenerator(
file_path=state["exp_pool_path"], use_fixed_insights=self.use_fixed_insights
state=state, use_fixed_insights=self.use_fixed_insights, from_scratch=from_scratch
)
await self.instruction_generator.initialize()
tree_loaded = False
if load_tree:

View file

@ -24,12 +24,7 @@ class MCTSExperimenter(Experimenter):
mcts = Random(root_node=None, max_depth=5, use_fixed_insights=self.args.use_fixed_insights)
else:
mcts = MCTS(root_node=None, max_depth=5, use_fixed_insights=self.args.use_fixed_insights)
best_nodes = await mcts.search(
state=self.state,
reflection=self.args.reflection,
rollouts=self.args.rollouts,
load_tree=self.args.load_tree,
)
best_nodes = await mcts.search(state=self.state, args=self.args)
best_node = best_nodes["global_best"]
dev_best_node = best_nodes["dev_best"]
score_dict = best_nodes["scores"]

View file

@ -1,6 +1,7 @@
import json
import os
import random
from difflib import SequenceMatcher
from expo.insights.solution_designer import SolutionDesigner
from expo.utils import clean_json_from_rsp, load_data_config, mcts_logger
@ -33,11 +34,21 @@ DATA_CONFIG = load_data_config()
class InstructionGenerator:
data_config = DATA_CONFIG
def __init__(self, file_path, use_fixed_insights=False):
self.file_path = file_path
def __init__(self, state, use_fixed_insights, from_scratch):
self.state = state
self.file_path = state["exp_pool_path"]
self.dataset_info_path = f"{self.data_config['datasets_dir']}/{state['task']}/dataset_info.json"
with open(self.dataset_info_path, "r") as file:
self.dataset_info = json.load(file)
self.use_fixed_insights = use_fixed_insights
self.analysis_pool = self.load_insight_pool(file_path, use_fixed_insights)
self.proposer = SolutionDesigner()
self.from_scratch = from_scratch
async def initialize(self):
if self.from_scratch:
self.insight_pool = await self.generate_solutions_from_scratch(self.dataset_info, self.state["task"])
else:
self.insight_pool = self.load_insight_pool(self.file_path, self.use_fixed_insights)
@staticmethod
def load_json_data(json_dir):
@ -84,14 +95,14 @@ class InstructionGenerator:
data.extend(fixed_insights)
for item in data:
if "task_id" not in item:
raise ValueError("task_id is not found in the analysis pool")
raise ValueError("task_id is not found in the insight_pool")
if task_id:
data = [item for item in data if int(item["task_id"]) == int(task_id)]
return data
async def generate_new_instructions(self, task_id, original_instruction, max_num, ext_info=None):
data = self.analysis_pool
data = self.insight_pool
new_instructions = []
if len(data) == 0:
mcts_logger.log("MCTS", f"No insights available for task {task_id}")
@ -108,6 +119,34 @@ class InstructionGenerator:
new_instructions.append(new_instruction)
return new_instructions
async def propose_new_insights(self, solution, score):
new_insights = await self.proposer.propose_insights(solution, score)
added_insights = self.add_insight(new_insights)
return added_insights
async def generate_solutions_from_scratch(self, dataset_info, dataset_name):
insight_pool = await self.proposer.generate_solutions(dataset_info, dataset_name, save_analysis_pool=False)
return insight_pool
def add_insight(self, new_insights):
added_insights = []
for new_insight in new_insights:
if not self.is_similar_to_existing(new_insight):
added_insights.append(new_insight)
self.insight_pool.append(new_insight)
return added_insights
def is_similar_to_existing(self, new_insight, similarity_threshold=0.8):
for existing_insight in self.insight_pool:
similarity = self.calculate_similarity(new_insight["Analysis"], existing_insight["Analysis"])
if similarity > similarity_threshold:
return True
return False
@staticmethod
def calculate_similarity(text1, text2):
return SequenceMatcher(None, text1, text2).ratio()
@staticmethod
async def generate_new_instruction(original_instruction, insights, ext_info):
prompt = CHANGE_INSTRUCTION.format(instruction=original_instruction, insights=insights)

View file

@ -70,6 +70,45 @@ Your model choices should be advanced enough to be helpful.
```
"""
INSIGHT_PROPOSAL_PROMPT = """
You are an AI assistant tasked with analyzing a machine learning solution and proposing new insights to improve its performance. Given the current solution code and development score, suggest innovative approaches to enhance the model.
Current Solution Code:
{solution_code}
Development Score: {dev_score}
Based on this information, propose 3-5 new insights across different aspects of the machine learning pipeline (Data Preprocessing, Feature Engineering, and Model Training). Your insights should be specific, actionable, and have the potential to improve the model's performance.
Please format your response as a JSON array with the following structure:
[
{{
"task_type": "Data Preprocessing",
"insights": [
"insight1",
"insight2"
]
}},
{{
"task_type": "Feature Engineering",
"insights": [
"insight1",
"insight2"
]
}},
{{
"task_type": "Model Training",
"insights": [
"insight1",
"insight2"
]
}}
]
"""
KEY_DATASET_FEATURES = [
"NumberOfClasses",
"NumberOfFeatures",
@ -86,7 +125,7 @@ TASK_TO_ID = {"EDA": 1, "Data Preprocessing": 2, "Feature Engineering": 3, "Mode
class SolutionDesigner:
data_dir: str = DATA_CONFIG["datasets_dir"]
async def generate_solutions(self, dataset_info, dataset_name):
async def generate_solutions(self, dataset_info, dataset_name, save_analysis_pool=True):
llm = LLM()
context = DATASET_INSIGHT_PROMPT.format(
dataset=dataset_info["description"],
@ -96,8 +135,18 @@ class SolutionDesigner:
rsp = await llm.aask(context)
rsp = clean_json_from_rsp(rsp)
analysis_pool = self.process_analysis_pool(json.loads(rsp))
dataset_path = f"{self.data_dir}/{dataset_name}"
self.save_analysis_pool(dataset_path, analysis_pool)
if save_analysis_pool:
dataset_path = f"{self.data_dir}/{dataset_name}"
self.save_analysis_pool(dataset_path, analysis_pool)
return analysis_pool
async def propose_new_insights(self, solution, score):
llm = LLM()
context = INSIGHT_PROPOSAL_PROMPT.format(solution_code=solution, dev_score=score)
rsp = await llm.aask(context)
rsp = clean_json_from_rsp(rsp)
new_insights = self.process_analysis_pool(json.loads(rsp))
return new_insights
def process_analysis_pool(self, insights_rsp):
analysis_pool = []

View file

@ -139,6 +139,11 @@ class ResearchAssistant(DataInterpreter):
save_notebook(role=self, save_dir=self.role_dir, name=self.get_node_name())
return task_result
def get_solution(self):
codes = [task.code for task in self.planner.plan.tasks]
results = [task.result for task in self.planner.plan.tasks]
return {"codes": codes, "results": results}
def save_state(self, static_save=False):
"""
attribute:

View file

@ -32,6 +32,9 @@ def get_mcts_args(parser):
parser.add_argument("--rollouts", type=int, default=5)
parser.add_argument("--use_fixed_insights", dest="use_fixed_insights", action="store_true")
parser.add_argument("--start_task_id", type=int, default=2)
parser.add_argument(
"--from_scratch", dest="from_scratch", action="store_true", help="Generate solutions from scratch"
)
def get_aug_exp_args(parser):