Skip to content

Commit

Permalink
updated analysis script
Browse files Browse the repository at this point in the history
  • Loading branch information
kywch committed Mar 27, 2024
1 parent d6d7188 commit e4c7628
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions analysis/proc_task_cond_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ def event_key_to_str(event_key):
return f'{CODE_TO_EVENT[event_key[0]]}_{ITEM_ID_TO_NAME[event_key[1]]}'

elif event_key[0] == EventCode.GO_FARTHEST:
return '2_PROGRESS_TO_CENTER'
return '3_PROGRESS_TO_CENTER'

elif event_key[0] == EventCode.AGENT_CULLED:
return '1_AGENT_LIFESPAN'
return '2_AGENT_LIFESPAN'

else:
return CODE_TO_EVENT[event_key[0]]
Expand Down Expand Up @@ -67,7 +67,6 @@ def gather_agent_events_by_task(data_dir):
return data_by_task

def get_event_stats(task_name, task_data):
results = {'0_NAME': task_name}
num_agents = len(task_data)
assert num_agents > 0, 'There should be at least one agent'

Expand All @@ -78,6 +77,7 @@ def get_event_stats(task_name, task_data):
cnt_harvest = 0
cnt_list = 0

results = {'0_NAME': task_name, '1_COUNT': num_agents}
event_data = defaultdict(list)
for data in task_data:
for event, val in data.items():
Expand All @@ -90,7 +90,11 @@ def get_event_stats(task_name, task_data):
results[event_key_to_str(event)] = np.mean(vals) # AVG skill level
elif event[0] == EventCode.AGENT_CULLED:
life_span = np.mean(vals)
results[event_key_to_str(event)] = life_span
results['2_AGENT_LIFESPAN_AVG'] = life_span
results['2_AGENT_LIFESPAN_SD'] = np.std(vals)
elif event[0] == EventCode.GO_FARTHEST:
results['3_PROGRESS_TO_CENTER_AVG'] = np.mean(vals)
results['3_PROGRESS_TO_CENTER_SD'] = np.std(vals)
else:
results[event_key_to_str(event)] = sum(vals) / num_agents

Expand All @@ -107,19 +111,21 @@ def get_event_stats(task_name, task_data):
if event[0] == EventCode.LIST_ITEM:
cnt_list += sum(vals)

results['3_NORM_ATTACK'] = cnt_attack / life_span
results['3_NORM_BUY'] = cnt_buy / life_span
results['3_NORM_CONSUME'] = cnt_consume / life_span
results['3_NORM_EQUIP'] = cnt_equip / life_span
results['3_NORM_HARVEST'] = cnt_harvest / life_span
results['3_NORM_LIST'] = cnt_list / life_span
results['4_NORM_ATTACK'] = cnt_attack / life_span
results['4_NORM_BUY'] = cnt_buy / life_span
results['4_NORM_CONSUME'] = cnt_consume / life_span
results['4_NORM_EQUIP'] = cnt_equip / life_span
results['4_NORM_HARVEST'] = cnt_harvest / life_span
results['4_NORM_LIST'] = cnt_list / life_span

return results

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Process replay data')
parser.add_argument('policy_store_dir', type=str, help='Path to the policy directory')
args = parser.parse_args()
# parser = argparse.ArgumentParser(description='Process replay data')
# parser.add_argument('policy_store_dir', type=str, help='Path to the policy directory')
# args = parser.parse_args()

args = argparse.Namespace(policy_store_dir='pol_task_cond2')

# Gather the event data by tasks, across multiple replays
data_by_task = gather_agent_events_by_task(args.policy_store_dir)
Expand Down

0 comments on commit e4c7628

Please sign in to comment.