mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Update drop.py
Change comments into English, fix the in/out params' type. fix too many values to unpack in line 140. Unify the quotes. Remove "if" at line 148
This commit is contained in:
parent
17f3cd4955
commit
6ebf3c47c2
1 changed files with 20 additions and 27 deletions
|
|
@ -19,7 +19,7 @@ def is_number(text: str) -> bool:
|
|||
except ValueError:
|
||||
return False
|
||||
|
||||
def normalize_answer(s):
|
||||
def normalize_answer(s: str) -> List[str]:
|
||||
"""
|
||||
Normalize answers for evaluation.
|
||||
"""
|
||||
|
|
@ -39,7 +39,7 @@ def normalize_answer(s):
|
|||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
def calculate_score(ground_truth: str, prediction: str):
|
||||
def calculate_score(ground_truth: str, prediction: str) -> Tuple[float, str]:
|
||||
"""
|
||||
Compute the F1 score between prediction and ground truth answers.
|
||||
"""
|
||||
|
|
@ -71,33 +71,33 @@ def log_mismatch(problem: str, expected_output, prediction: str, predicted_numbe
|
|||
|
||||
log_file = os.path.join(path, 'log.json')
|
||||
|
||||
# 检查log文件是否已经存在
|
||||
# Check if the log file already exists
|
||||
if os.path.exists(log_file):
|
||||
# 如果存在,加载现有的日志数据
|
||||
# If it exists, load the existing log data
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
try:
|
||||
data = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
data = []
|
||||
else:
|
||||
# 如果不存在,创建一个新的日志列表
|
||||
# If it doesn't exist, create an empty list
|
||||
data = []
|
||||
|
||||
# 添加新的日志记录
|
||||
# Add the new log data to the existing list
|
||||
data.append(log_data)
|
||||
|
||||
# 将数据写回到log.json文件
|
||||
# Write the updated list back to the log file
|
||||
with open(log_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
async def load_data(file_path: str, specific_indices: List[int] = None) -> List[dict]:
|
||||
data = []
|
||||
# 异步读取文件内容
|
||||
# Read the data from the file
|
||||
async with aiofiles.open(file_path, mode="r", encoding='utf-8') as file:
|
||||
async for line in file:
|
||||
data.append(json.loads(line))
|
||||
|
||||
# 然后在随机选择的样本中基于特定索引列表进行进一步筛选
|
||||
# Then further filter based on a specific index list in randomly selected samples
|
||||
if specific_indices is not None:
|
||||
filtered_data = [data[i] for i in specific_indices if i < len(data)]
|
||||
return filtered_data
|
||||
|
|
@ -105,22 +105,22 @@ async def load_data(file_path: str, specific_indices: List[int] = None) -> List[
|
|||
return data
|
||||
|
||||
def save_results_to_csv(results: List[Tuple[str, str, str, int]], path):
|
||||
# 创建 DataFrame
|
||||
# Create a DataFrame from the results
|
||||
df = pd.DataFrame(results, columns=["inputs", "prediction", "expected_output", "score", "cost"])
|
||||
|
||||
# 计算统计数据
|
||||
# Calculate the average score and cost
|
||||
avg_score = df["score"].mean()
|
||||
t_cost = df["cost"].max()
|
||||
a_cost = t_cost / len(df) if len(df) > 0 else 0
|
||||
|
||||
# 获取当前时间,格式为 YYYYMMDD_HHMMSS
|
||||
# Get the current time in the format YYYYMMDD_HHMMSS
|
||||
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 生成文件名,包含平均分和当前时间,保留五位小数
|
||||
# Generate a filename with the average score and the current time, rounded to five decimal places
|
||||
filename = f"{avg_score:.5f}_{current_time}.csv"
|
||||
output_file = os.path.join(path, filename)
|
||||
|
||||
# 保存到 CSV
|
||||
# Save the DataFrame to a CSV file
|
||||
df.to_csv(output_file, index=False)
|
||||
print(f"Results saved to {output_file}")
|
||||
|
||||
|
|
@ -137,26 +137,19 @@ async def evaluate_problem(annotation: dict, graph: Callable, log_path) -> Tuple
|
|||
|
||||
while retries < max_retries:
|
||||
try:
|
||||
output, cost = await graph(inputs) if graph else "None"
|
||||
output, cost = await graph(inputs) if graph else (None, None)
|
||||
|
||||
f1_scores = []
|
||||
|
||||
# if '|' in the output, split it and calculate the score for each part
|
||||
|
||||
# if "|" in the output, split it and calculate the score for each part
|
||||
for answer in answers:
|
||||
if answer.strip() != "":
|
||||
if '|' in output:
|
||||
output_parts = output.split('|')
|
||||
for output_part in output_parts:
|
||||
f1_score, _ = calculate_score(answer, output_part)
|
||||
f1_scores.append(f1_score)
|
||||
else:
|
||||
f1_score, _ = calculate_score(answer, output)
|
||||
output_parts = output.split("|")
|
||||
for output_part in output_parts:
|
||||
f1_score, _ = calculate_score(answer, output_part)
|
||||
f1_scores.append(f1_score)
|
||||
|
||||
max_score = max(f1_scores)
|
||||
|
||||
uni_score = max_score
|
||||
uni_score = max(f1_scores)
|
||||
|
||||
print("uni_score", uni_score)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue