Skip to content

Commit

Permalink
set device auto
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumehu committed Nov 16, 2023
1 parent 3ca39fa commit 9bd1dd6
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions presample_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ def parse_args():

torch.set_default_dtype(torch.float64)

alpha = torch.ones(args.num_cat - 1)
beta = torch.arange(args.num_cat - 1, 0, -1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

alpha = torch.ones(args.num_cat - 1).to(device)
beta = torch.arange(args.num_cat - 1, 0, -1).to(device)

v_one, v_zero, v_one_loggrad, v_zero_loggrad, timepoints = noise_factory(args.num_samples,
args.num_time_steps,
Expand All @@ -74,7 +76,8 @@ def parse_args():
time_steps=args.steps_per_tick,
logspace=args.logspace,
speed_balanced=args.speed_balance,
mode=args.mode)
mode=args.mode,
device=device)

v_one = v_one.cpu()
v_zero = v_zero.cpu()
Expand Down

0 comments on commit 9bd1dd6

Please sign in to comment.