Update drop.py

Translate the comments into English and fix the input/output parameter type annotations.
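Fix the "too many values to unpack" error at line 140: when no graph is given, the fallback is now the 2-tuple (None, None) instead of the string "None".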
Unify the quotes.
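Remove the redundant `if "|" in output` check at line 148; `output.split("|")` already returns a one-element list when no "|" is present.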
This commit is contained in:
Zhaoyang Yu 2024-10-19 11:40:27 +08:00
parent 17f3cd4955
commit 6ebf3c47c2

View file

@ -19,7 +19,7 @@ def is_number(text: str) -> bool:
except ValueError:
return False
def normalize_answer(s):
def normalize_answer(s: str) -> List[str]:
"""
Normalize answers for evaluation.
"""
@ -39,7 +39,7 @@ def normalize_answer(s):
return white_space_fix(remove_articles(remove_punc(lower(s))))
def calculate_score(ground_truth: str, prediction: str):
def calculate_score(ground_truth: str, prediction: str) -> Tuple[float, str]:
"""
Compute the F1 score between prediction and ground truth answers.
"""
@ -71,33 +71,33 @@ def log_mismatch(problem: str, expected_output, prediction: str, predicted_numbe
log_file = os.path.join(path, 'log.json')
# 检查log文件是否已经存在
# Check if the log file already exists
if os.path.exists(log_file):
# 如果存在,加载现有的日志数据
# If it exists, load the existing log data
with open(log_file, 'r', encoding='utf-8') as f:
try:
data = json.load(f)
except json.JSONDecodeError:
data = []
else:
# 如果不存在,创建一个新的日志列表
# If it doesn't exist, create an empty list
data = []
# 添加新的日志记录
# Add the new log data to the existing list
data.append(log_data)
# 将数据写回到log.json文件
# Write the updated list back to the log file
with open(log_file, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=4, ensure_ascii=False)
async def load_data(file_path: str, specific_indices: List[int] = None) -> List[dict]:
data = []
# 异步读取文件内容
# Read the data from the file
async with aiofiles.open(file_path, mode="r", encoding='utf-8') as file:
async for line in file:
data.append(json.loads(line))
# 然后在随机选择的样本中基于特定索引列表进行进一步筛选
# Then further filter based on a specific index list in randomly selected samples
if specific_indices is not None:
filtered_data = [data[i] for i in specific_indices if i < len(data)]
return filtered_data
@ -105,22 +105,22 @@ async def load_data(file_path: str, specific_indices: List[int] = None) -> List[
return data
def save_results_to_csv(results: List[Tuple[str, str, str, int]], path):
# 创建 DataFrame
# Create a DataFrame from the results
df = pd.DataFrame(results, columns=["inputs", "prediction", "expected_output", "score", "cost"])
# 计算统计数据
# Calculate the average score and cost
avg_score = df["score"].mean()
t_cost = df["cost"].max()
a_cost = t_cost / len(df) if len(df) > 0 else 0
# 获取当前时间,格式为 YYYYMMDD_HHMMSS
# Get the current time in the format YYYYMMDD_HHMMSS
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
# 生成文件名,包含平均分和当前时间,保留五位小数
# Generate a filename with the average score and the current time, rounded to five decimal places
filename = f"{avg_score:.5f}_{current_time}.csv"
output_file = os.path.join(path, filename)
# 保存到 CSV
# Save the DataFrame to a CSV file
df.to_csv(output_file, index=False)
print(f"Results saved to {output_file}")
@ -137,26 +137,19 @@ async def evaluate_problem(annotation: dict, graph: Callable, log_path) -> Tuple
while retries < max_retries:
try:
output, cost = await graph(inputs) if graph else "None"
output, cost = await graph(inputs) if graph else (None, None)
f1_scores = []
# if '|' in the output, split it and calculate the score for each part
# if "|" in the output, split it and calculate the score for each part
for answer in answers:
if answer.strip() != "":
if '|' in output:
output_parts = output.split('|')
for output_part in output_parts:
f1_score, _ = calculate_score(answer, output_part)
f1_scores.append(f1_score)
else:
f1_score, _ = calculate_score(answer, output)
output_parts = output.split("|")
for output_part in output_parts:
f1_score, _ = calculate_score(answer, output_part)
f1_scores.append(f1_score)
max_score = max(f1_scores)
uni_score = max_score
uni_score = max(f1_scores)
print("uni_score", uni_score)