diff --git a/presample_noise.py b/presample_noise.py index 425fb63..cceebab 100644 --- a/presample_noise.py +++ b/presample_noise.py @@ -62,8 +62,10 @@ def parse_args(): torch.set_default_dtype(torch.float64) - alpha = torch.ones(args.num_cat - 1) - beta = torch.arange(args.num_cat - 1, 0, -1) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + alpha = torch.ones(args.num_cat - 1).to(device) + beta = torch.arange(args.num_cat - 1, 0, -1).to(device) v_one, v_zero, v_one_loggrad, v_zero_loggrad, timepoints = noise_factory(args.num_samples, args.num_time_steps, @@ -74,7 +76,8 @@ def parse_args(): time_steps=args.steps_per_tick, logspace=args.logspace, speed_balanced=args.speed_balance, - mode=args.mode) + mode=args.mode, + device=device) v_one = v_one.cpu() v_zero = v_zero.cpu()