diff --git a/docs/changes/33.feature.rst b/docs/changes/33.feature.rst new file mode 100644 index 0000000..cc83a98 --- /dev/null +++ b/docs/changes/33.feature.rst @@ -0,0 +1,3 @@ +Changes to `vis_loop` function in `visibility.py`: +- add a an optional tqdm progress bar to get a visual confirmation the calculation is still running +- add optional `batch_size` parameter to control memory consumption diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 0e720eb..492c9ae 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -4,6 +4,8 @@ import pyvisgen.simulation.scan as scan +from tqdm import tqdm + @dataclass class Visibilities: @@ -37,7 +39,7 @@ def add(self, visibilities): ] -def vis_loop(obs, SI, num_threads=10, noisy=True, mode="full"): +def vis_loop(obs, SI, num_threads=10, noisy=True, mode="full", batch_size=100, show_progress=False): torch.set_num_threads(num_threads) torch._dynamo.config.suppress_errors = True @@ -93,7 +95,12 @@ def vis_loop(obs, SI, num_threads=10, noisy=True, mode="full"): else: raise ValueError("Unsupported mode!") - for p in torch.arange(bas[:].shape[1]).split(1000): + batches = torch.arange(bas[:].shape[1]).split(batch_size) + + if show_progress: + batches = tqdm(batches) + + for p in batches: bas_p = bas[:][:, p] int_values = torch.cat(