diff --git a/src_py/apiServer/stats.py b/src_py/apiServer/stats.py index e4bb17dc..3207939b 100644 --- a/src_py/apiServer/stats.py +++ b/src_py/apiServer/stats.py @@ -306,7 +306,10 @@ def recieved_batches_key(phase_name, source_name, worker_name): workers_model_db_list = self.nerl_model_db.get_workers_model_db_list() for source_piece_inst in sources_pieces_list: source_name = source_piece_inst.get_source_name() - source_epoch = int(globe.components.sourceEpochs[source_name]) + if self.phase == PHASE_PREDICTION_STR: + source_epoch = 1 + else: + source_epoch = int(globe.components.sourceEpochs[source_name]) target_workers_string = source_piece_inst.get_target_workers() target_workers_names = target_workers_string.split(',') for worker_db in workers_model_db_list: @@ -336,7 +339,10 @@ def missed_batches_key(phase_name, source_name, worker_name): for source_piece_inst in sources_pieces_list: source_name = source_piece_inst.get_source_name() source_policy = globe.components.sources_policy_dict[source_name] # 0 -> casting , 1 -> round robin, 2 -> random - source_epoch = int(globe.components.sourceEpochs[source_name]) + if self.phase == PHASE_PREDICTION_STR: + source_epoch = 1 + else: + source_epoch = int(globe.components.sourceEpochs[source_name]) target_workers_string = source_piece_inst.get_target_workers() target_workers_names = target_workers_string.split(',') if source_policy == '0': # casting policy