Update Operator & Benchmark

This commit is contained in:
didi 2024-10-21 23:08:51 +08:00
parent fe3fca514a
commit 2d1d7ca219
8 changed files with 87 additions and 51 deletions

View file

@ -40,7 +40,7 @@ class GSM8KBenchmark(BaseBenchmark):
async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, float, float, float]:
max_retries = 5
retries = 0
while retries < max_retries:
try:
prediction, cost = await graph(problem["question"])

View file

@ -1,9 +1,16 @@
# -*- coding: utf-8 -*-
# @Date : 2024-10-20
# @Author : MoshiQAQ & didi
# @Desc : Download and extract dataset files
import os
import requests
import tarfile
from tqdm import tqdm
from typing import List, Dict
def download_file(url, filename):
def download_file(url: str, filename: str) -> None:
"""Download a file from the given URL and show progress."""
response = requests.get(url, stream=True)
total_size = int(response.headers.get('content-length', 0))
block_size = 1024
@ -15,22 +22,49 @@ def download_file(url, filename):
progress_bar.update(size)
progress_bar.close()
def extract_tar_gz(filename, extract_path):
def extract_tar_gz(filename: str, extract_path: str) -> None:
    """Extract a tar.gz file to the specified path.

    The archive is treated as untrusted input (it is downloaded from a URL),
    so members that would be written outside ``extract_path`` are rejected
    instead of being extracted.

    Args:
        filename: Path of the gzip-compressed tar archive to read.
        extract_path: Directory the archive members are extracted into.

    Raises:
        tarfile.TarError: If the archive is unreadable or contains a member
            that would escape ``extract_path``.
    """
    with tarfile.open(filename, 'r:gz') as tar:
        if hasattr(tarfile, "data_filter"):
            # Interpreters with extraction filters: let the stdlib refuse
            # absolute paths, ".." traversal and links leaving the target.
            tar.extractall(path=extract_path, filter="data")
            return

        # Older interpreters: validate member paths by hand before extracting.
        # NOTE: this covers path traversal only, not link members.
        root = os.path.realpath(extract_path)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(root, member.name))
            if os.path.commonpath([root, target]) != root:
                raise tarfile.TarError(
                    f"Refusing to extract {member.name!r} outside {extract_path!r}"
                )
        tar.extractall(path=extract_path)
url = "https://drive.google.com/uc?export=download&id=1tXp5cLw89egeKRwDuood2TPqoEWd8_C0"
filename = "aflow_data.tar.gz"
extract_path = "./"
def process_dataset(url: str, filename: str, extract_path: str) -> None:
    """Download, extract, and clean up a dataset.

    Args:
        url: Location the archive is downloaded from.
        filename: Local path the archive is saved to; removed afterwards.
        extract_path: Directory the archive is extracted into.

    Any exception raised by the download or the extraction propagates to the
    caller; the downloaded archive is removed either way so a failed run does
    not leave a partial or corrupt file behind.
    """
    print(f"Downloading {filename}...")
    try:
        download_file(url, filename)
        print(f"Extracting {filename}...")
        extract_tar_gz(filename, extract_path)
        print(f"{filename} download and extraction completed.")
    finally:
        # The download may have failed before the file was created, so only
        # remove (and report removing) something that actually exists.
        if os.path.exists(filename):
            os.remove(filename)
            print(f"Removed {filename}")
print("Downloading data file...")
download_file(url, filename)
# Catalogue of downloadable archives.
# Edit this list to choose which datasets are available for download.
# Each entry gives the archive's logical name, its source URL, the local
# filename it is saved under, and the directory it is extracted into.
datasets_to_download: List[Dict[str, str]] = [
    {
        "name": "datasets",
        "url": "https://drive.google.com/uc?export=download&id=1tXp5cLw89egeKRwDuood2TPqoEWd8_C0",
        "filename": "aflow_data.tar.gz",
        "extract_path": "examples/aflow/data",
    },
    {
        "name": "results",
        "url": "",  # Placeholder: fill in the correct URL before use
        "filename": "result.tar.gz",
        "extract_path": "examples/aflow/data/results",
    },
    {
        "name": "initial_rounds",
        "url": "",  # Placeholder: fill in the correct URL before use
        "filename": "first_round.tar.gz",
        "extract_path": "examples/aflow/scripts/optimized",
    },
]
print("Extracting data file...")
extract_tar_gz(filename, extract_path)
print("Download and extraction completed.")
# Clean up the compressed file
os.remove(filename)
print(f"Removed {filename}")
def download(datasets):
    """Main function to process all selected datasets.

    Args:
        datasets: Iterable of selectors. Each selector is either the
            ``"name"`` of an entry in ``datasets_to_download`` or an integer
            index into that list.

    Raises:
        KeyError: If a name does not match any configured dataset.
        ValueError: If the selected entry has no URL configured yet.
    """
    # datasets_to_download is a list, so a name cannot be used to index it
    # directly; build a name -> entry mapping once for the lookups below.
    by_name = {entry["name"]: entry for entry in datasets_to_download}

    for selector in datasets:
        if isinstance(selector, int):
            dataset = datasets_to_download[selector]
        else:
            try:
                dataset = by_name[selector]
            except KeyError:
                raise KeyError(
                    f"Unknown dataset {selector!r}; available: {sorted(by_name)}"
                ) from None

        # Entries shipped with a placeholder URL cannot be downloaded; fail
        # with a clear message before any network request is attempted.
        if not dataset["url"]:
            raise ValueError(
                f"No download URL configured for dataset {dataset['name']!r}"
            )

        process_dataset(dataset['url'], dataset['filename'], dataset['extract_path'])

View file

@ -39,25 +39,21 @@ class Evaluator:
data_path = self._get_data_path(dataset, is_test)
benchmark_class = self.dataset_configs[dataset]
benchmark = benchmark_class(dataset, data_path, path)
benchmark = benchmark_class(name=dataset, file_path=data_path, log_path=path)
# Use params to configure the graph and benchmark
configured_graph = await self._configure_graph(graph, params)
configured_graph = await self._configure_graph(dataset, graph, params)
va_list = [1,2,3] # Use va_list from params, or use default value if not provided
return await benchmark.run_evaluation(configured_graph, va_list)
async def _configure_graph(self, graph, params: dict):
async def _configure_graph(self, dataset, graph, params: dict):
    """Instantiate ``graph`` for ``dataset`` using the settings in ``params``.

    Args:
        dataset: Dataset identifier; passed to ``graph`` as ``name``.
        graph: Callable invoked with ``name``, ``llm_config`` and ``dataset``
            keyword arguments.
        params: Configuration mapping. The ``"dataset"`` and ``"llm_config"``
            keys are read; each defaults to an empty dict when absent.

    Returns:
        Whatever ``graph(...)`` returns.
    """
    # Here you can configure the graph based on params
    # For example: set LLM configuration, dataset configuration, etc.
    dataset_config = params.get("dataset", {})
    llm_config = params.get("llm_config", {})
    # The dataset identifier itself is the graph name. The entries of
    # self.dataset_configs are used as benchmark classes elsewhere in this
    # class, so they are not subscripted for a "name" here.
    return graph(name=dataset, llm_config=llm_config, dataset=dataset_config)
def _get_data_path(self, dataset: DatasetType, test: bool) -> str:
    """Return the JSONL data file path for ``dataset``.

    The lower-cased dataset name is placed under ``examples/aflow/data/``
    and suffixed with ``_test.jsonl`` when ``test`` is truthy, otherwise
    with ``_validate.jsonl``.
    """
    split = "test" if test else "validate"
    return f"examples/aflow/data/{dataset.lower()}_{split}.jsonl"
# Backward-compatibility aliases: attach one `_<dataset>_eval` method per
# dataset to Evaluator, each forwarding to `graph_evaluate` with the
# upper-cased dataset name.
for dataset in ("gsm8k", "math", "humaneval", "mbpp", "hotpotqa", "drop"):

    # The keyword default captures this iteration's name at definition time,
    # so every alias keeps its own dataset rather than the loop's last value.
    def _eval_alias(self, *args, dataset=dataset.upper(), **kwargs):
        return self.graph_evaluate(dataset, *args, **kwargs)

    setattr(Evaluator, f"_{dataset}_eval", _eval_alias)

View file

@ -82,23 +82,27 @@ class Optimizer:
retry_count = 0
max_retries = 1
while retry_count < max_retries:
try:
score = loop.run_until_complete(self._optimize_graph())
break
except Exception as e:
retry_count += 1
logger.info(f"Error occurred: {e}. Retrying... (Attempt {retry_count}/{max_retries})")
if retry_count == max_retries:
logger.info("Max retries reached. Moving to next round.")
score = None
score = loop.run_until_complete(self._optimize_graph())
wait_time = 5 * retry_count
time.sleep(wait_time)
# while retry_count < max_retries:
# try:
# score = loop.run_until_complete(self._optimize_graph())
# break
# except Exception as e:
# retry_count += 1
# logger.info(f"Error occurred: {e}. Retrying... (Attempt {retry_count}/{max_retries})")
# if retry_count == max_retries:
# logger.info("Max retries reached. Moving to next round.")
# score = None
if retry_count < max_retries:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# wait_time = 5 * retry_count
# time.sleep(wait_time)
# if retry_count < max_retries:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
self.round += 1
logger.info(f"Score for round {self.round}: {score}")
@ -114,7 +118,7 @@ class Optimizer:
time.sleep(5)
async def _optimize_graph(self):
validation_n = 5 # validation datasets's execution number
validation_n = 2 # validation datasets's execution number
graph_path = f"{self.root_path}/workflows"
data = self.data_utils.load_results(graph_path)