From a0b5ca7c97deafe3e142a37881db0279f0e97a03 Mon Sep 17 00:00:00 2001 From: Nicholas Lubbers <56895592+lubbersnick@users.noreply.github.com> Date: Mon, 28 Oct 2024 18:08:38 -0600 Subject: [PATCH] Fix some issues with the speed tests in custom kernels (#117) * fixes for custom kernel speed tests * update printing for speed tests --- hippynn/custom_kernels/registry.py | 3 ++- hippynn/custom_kernels/test_speed_env.py | 22 ++++++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/hippynn/custom_kernels/registry.py b/hippynn/custom_kernels/registry.py index 3ba46641..493bc09e 100644 --- a/hippynn/custom_kernels/registry.py +++ b/hippynn/custom_kernels/registry.py @@ -60,4 +60,5 @@ def get_available_implementations(self, hidden=False): :param hidden: Show all implementations, even those which have no improved performance characteristics. :return: """ - return [k for k in self._registered_implementations.keys() if not k.startswith("_")] + + return [k for k in self._registered_implementations.keys() if hidden or not k.startswith("_")] diff --git a/hippynn/custom_kernels/test_speed_env.py b/hippynn/custom_kernels/test_speed_env.py index b3a3b674..35f1bda4 100644 --- a/hippynn/custom_kernels/test_speed_env.py +++ b/hippynn/custom_kernels/test_speed_env.py @@ -18,7 +18,7 @@ def parse_args(): parser.add_argument("--all-hidden", action="store_true", default=False, help="Use all implementations, even with _ beginning.") parser.add_argument("--all-impl", action="store_true", default=False, help="Use all non-hidden implementations.") parser.add_argument("--all-gpu", action="store_true", default=False, help="Use low-mem implementations suitable for GPU.") - parser.add_argument("--all-gpu", action="store_true", default=False, help="CPU-capable implementaitons.") + parser.add_argument("--all-cpu", action="store_true", default=False, help="CPU-capable implementaitons.") for param_type in TEST_PARAMS.keys(): parser.add_argument(f"--{param_type}", type=int, default=0, help=f"Count for param type {param_type}") @@ -43,7 +43,8 @@ def main(args=None): setattr(args, k, default) test_spec = {k: count for k in TEST_PARAMS if (count := getattr(args, k, 0)) > 0} - print(TEST_PARAMS.keys()) + + print("Testing specification:") print(test_spec) results = {} @@ -60,6 +61,11 @@ def main(args=None): implementations = MessagePassingKernels.get_available_implementations() if args.all_hidden: implementations = MessagePassingKernels.get_available_implementations(hidden=True) + + + print("Testing implementations:") + print(implementations) + # Error if implementation does not exist. for impl in implementations: @@ -82,10 +88,14 @@ def main(args=None): for k, count in test_spec.items(): print(f"Testing {k} {count} times:") np.random.seed(args.seed) - out0, out1 = tester.check_speed( - n_repetitions=count, device=torch.device(args.accelerator), data_size=TEST_PARAMS[k], compare_against=impl - ) - impl_results[k] = dict(tested=out0, comparison=out1) + try: + out0, out1 = tester.check_speed( + n_repetitions=count, device=torch.device(args.accelerator), data_size=TEST_PARAMS[k], compare_against=impl) + impl_results[k] = dict(tested=out0, comparison=out1) + except (torch.OutOfMemoryError, RuntimeError) as toom: + print(toom) + print("Got out of memory for this test! Attempting to continue.") + impl_results[k] = "OUT OF MEMORY" with open(path, "wt") as f: json.dump(results, f)