From 6a5d169150249a8b5c311fe18b83aa83d0db2b86 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 6 Jan 2025 11:18:54 +0800 Subject: [PATCH] check test.py --- source/tests/pt/test_make_stat_input.py | 44 +++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 2ccefb3f9e..3a095ee2b1 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -69,6 +69,7 @@ def count_non_zero_elements(self, tensor, threshold=1e-8): return torch.sum(torch.abs(tensor) > threshold).item() def test_make_stat_input(self): + #3 frames would be count lst = make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, @@ -78,7 +79,7 @@ def test_make_stat_input(self): ) bias, _ = compute_output_stats(lst, ntypes=57) energy = bias.get("energy") - self.assertIsNotNone(energy, "'energy' key not found in bias dictionary.") + print(energy) non_zero_count = self.count_non_zero_elements(energy) self.assertEqual( non_zero_count, @@ -87,6 +88,9 @@ def test_make_stat_input(self): ) def test_make_stat_input_nocomplete(self): + #missing element:13,31,37 + #only one frame would be count + lst = make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, @@ -96,7 +100,7 @@ def test_make_stat_input_nocomplete(self): ) bias, _ = compute_output_stats(lst, ntypes=57) energy = bias.get("energy") - self.assertIsNotNone(energy, "'energy' key not found in bias dictionary.") + print(energy) non_zero_count = self.count_non_zero_elements(energy) self.assertLess( non_zero_count, @@ -104,6 +108,42 @@ def test_make_stat_input_nocomplete(self): f"Expected fewer than {self.real_ntypes} non-zero elements, but got {non_zero_count}.", ) + def test_bias(self): + lst_ori = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=1, + min_frames_per_element_forstat=1, + enable_element_completion=False, + ) + lst_all = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=1, + min_frames_per_element_forstat=1, + enable_element_completion=True, + ) + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) + bias_all, _ = compute_output_stats(lst_all, ntypes=57) + energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() + energy_all = np.array(bias_all.get("energy").cpu()).flatten() + + for i, (e_ori, e_all) in enumerate(zip(energy_ori, energy_all)): + if e_all == 0: + self.assertEqual( + e_ori, + 0, + f"Index {i}: energy_all=0, but energy_ori={e_ori}" + ) + else: + if e_ori != 0: + diff = abs(e_ori - e_all) + rel_diff = diff / abs(e_ori) + self.assertTrue( + rel_diff < 0.4, + f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " + f"relative difference {rel_diff:.2%} is too large" + ) if __name__ == "__main__": unittest.main()