fix round count bug

This commit is contained in:
liangliang 2023-10-11 16:27:42 +08:00
parent 465f362cf8
commit 985aa7b7c8

View file

@ -4,7 +4,6 @@ import os
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import math
def extract_logs(filename, start_time=None, end_time=None):
with open(filename, 'r') as f:
lines = f.readlines()
@ -26,7 +25,6 @@ def extract_logs(filename, start_time=None, end_time=None):
break
return logs_block
def extract_time_from_last_round_zero(log_filename):
with open(log_filename, 'r') as f:
lines = f.readlines()
@ -39,7 +37,6 @@ def extract_time_from_last_round_zero(log_filename):
if match:
return match.group(1)
return None
def analyze_log_block(logs_block):
rounds: list[int] = []
items_collected :list[int] = []
@ -59,55 +56,40 @@ def analyze_log_block(logs_block):
n = int(re.search(r'round_id:(\d+)', line).group(1))
if n not in rounds:
rounds.append(n)
items_collected.append(total_items) # add previous total before updating
round_start = True
check_for_task = True
check_for_info = True
if round_start:
if "Curriculum Agent human message" in line:
line_after_task = 0
continue
if check_for_task:
if check_for_info:
line_after_task += 1
if line_after_task <= 15:
if "Completed tasks so far:" in line:
tasks = line.replace("Completed tasks so far:", "").strip().split(", ")
completed_tasks.append(len(tasks))
if "Failed tasks that are too hard:" in line:
tasks = line.replace("Failed tasks that are too hard:", "").strip().split(", ")
failed_tasks.append(len(tasks))
check_for_task = False
if "Critic Agent human message" in line:
line_after_message = 0
continue
if check_for_info:
line_after_message += 1
if line_after_message <= 20:
if line_after_task <= 20:
if "Position: x=" in line:
match = re.search(r'Position: x=([\d.-]+), y=([\d.-]+), z=([\d.-]+)', line)
if match:
x, y, z = float(match.group(1)), float(match.group(2)), float(match.group(3))
positions.append((x, y, z))
x, y, z = float(match.group(1)), float(match.group(2)), float(match.group(3))
positions.append((x, y, z))
if "Inventory (" in line:
if ": Empty" in line:
check_for_info = False
continue
items = re.search(r'Inventory \(\d+/36\): ({.*?})', line).group(1)
items_dict = eval(items)
total_items = sum(items_dict.values())
items_collected.append(0)
else:
items = re.search(r'Inventory \(\d+/36\): ({.*?})', line).group(1)
items_dict = eval(items)
total_items = sum(items_dict.values())
items_collected.append(total_items) # add previous total before updating
if "Completed tasks so far:" in line:
tasks = line.replace("Completed tasks so far:", "").strip().split(", ")
completed_tasks.append(0 if tasks == ['None'] else len(tasks))
if "Failed tasks that are too hard:" in line:
tasks = line.replace("Failed tasks that are too hard:", "").strip().split(", ")
failed_tasks.append(0 if tasks == ['None'] else len(tasks))
check_for_info = False
round_start = False
return rounds, items_collected, positions, completed_tasks, failed_tasks
@ -119,9 +101,9 @@ def save_item_results_png(rounds, items_collected, start_time, path_prefix):
plt.ylabel("Total Items Collected")
plt.title("Items Collected Over Rounds")
plt.grid(True)
plt.savefig(f'{path_prefix}/{start_time}_items_collected_over_rounds.png', dpi=300)
plt.close()
def save_path_results_png(positions, start_time, path_prefix):
x_coords = [pos[0] for pos in positions]
@ -148,6 +130,7 @@ def save_path_results_png(positions, start_time, path_prefix):
ax.text(min(x_coords), max(y_coords), max(z_coords), f"Total Distance: {total_distance:.2f} units", fontsize=15, color='red')
plt.savefig(f'{path_prefix}/{start_time}_bot_movement_3D_path.png', dpi=300)
plt.close()
def save_task_results_png(rounds , completed, failed, start_time, path_prefix):
plt.plot(rounds, completed, label='Completed Tasks', marker='o')
plt.plot(rounds, failed, label='Failed Tasks', marker='x')