Update test_curve.py

This commit is contained in:
didi 2024-09-26 21:56:54 +08:00
parent e8f6186a56
commit 040a7324eb

View file

@ -69,7 +69,11 @@ test_curve_ci_data = {
}
# 创建一个正方形图表
plt.figure(figsize=(10, 10))
plt.figure(figsize=(10, 8))
# 设置全局字体为 Times New Roman 并加粗
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.weight'] = 'bold'
# 绘制每个数据集
for label, data in test_curve_avg_data.items():
@ -80,7 +84,7 @@ for label, data in test_curve_avg_data.items():
rounds = rounds + [30]
scores = scores + [scores[-1]]
plt.step(rounds, scores, label=label, where='post')
plt.step(rounds, scores, label=label, where='post', linewidth=2)
# 添加置信区间
ci_data = test_curve_ci_data[label]
@ -96,25 +100,28 @@ for label, data in test_curve_avg_data.items():
# 绘制置信区间区域
plt.fill_between(ci_rounds, ci_lower, ci_upper, alpha=0.2, step='post')
# 设置y轴的范围为40到100,使变化更加剧烈
plt.ylim(40, 100)
# 设置y轴的范围为70到98,使变化更加剧烈
plt.ylim(70, 98)
# 添加标题和轴标签
plt.title("SOPTimizer's iteraton performance across tasks (%)", fontsize=16)
plt.xlabel('Iteration', fontsize=14)
plt.ylabel('Performance (%)', fontsize=14)
plt.title("SOPTimizer's iteraton performance across tasks (%)", fontsize=16, fontweight='bold')
plt.xlabel('Iteration', fontsize=14, fontweight='bold')
plt.ylabel('Performance (%)', fontsize=14, fontweight='bold')
# 显示网格
plt.grid(True, linestyle='--', alpha=0.7)
# 将图例放在图外面
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=12)
# 将图例放在图内左上角
plt.legend(loc='upper left', fontsize=12)
# 调整布局以确保图例完全显示
plt.tight_layout()
# 设置y轴刻度增加刻度数量
plt.yticks(range(40, 101, 5))
plt.yticks(range(70, 99, 2))
# 加粗刻度标签
plt.tick_params(axis='both', which='major', labelsize=10, width=2)
# 保存图表为PDF
plt.savefig('test_curve.pdf', format='pdf', bbox_inches='tight')