From be6d2edec55b5b0cc989b504f0a046d959da8254 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 13 Aug 2024 15:47:40 +0200 Subject: [PATCH 01/52] Add first implementation of polarisation simulation in the vis loop --- pyvisgen/simulation/visibility.py | 383 ++++++++++++++++++++++++++++-- 1 file changed, 358 insertions(+), 25 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 09f6896..e817857 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -1,7 +1,8 @@ from dataclasses import dataclass, fields +from tqdm import tqdm import torch -from tqdm import tqdm +import scipy.ndimage import pyvisgen.simulation.scan as scan @@ -18,6 +19,8 @@ class Visibilities: v: torch.tensor w: torch.tensor date: torch.tensor + linear_dop: torch.tensor + circular_dop: torch.tensor def __getitem__(self, i): return Visibilities(*[getattr(self, f.name)[i] for f in fields(self)]) @@ -38,35 +41,359 @@ def add(self, visibilities): ] +class Polarisation: + """Simulation of polarisation.""" + + def __init__( + self, + SI: torch.tensor, + sensitivity_cut: float, + amp_ratio: float, + delta: float, + polarisation: str, + random_state: int, + device: str, + ) -> None: + """Creates the 2 x 2 stokes matrix and simulates + polarisation if `polarisation` is either 'linear' + or 'circular'. Also computes the degree of polarisation. + + Parameters + ---------- + SI : torch.tensor + Stokes I component, i.e. intensity distribution + of the sky. + sensitivity_cut : float + Sensitivity cut, where only pixels above the value + are kept. + amp_ratio : float + Sets the ratio of $A_{x/r}$. The ratio of $A_{y/l}$ is calculated + as `1 - amp_ratio`. If set to `None`, a random value is drawn + from a uniform distribution. See also: `random_state`. + delta : float + Sets the phase difference of the amplitudes $A_x$ and $A_y$ + of the sky distribution. Defines the measure of ellipticity. + polarisation : str + Choose between `'linear'` or `'circular'` or `None` to + simulate different types of polarisations or disable + the simulation of polarisation. + random_state : int + Random state used when drawing `amp_ratio` and during the generation + of the random polarisation field. + device : str + Torch device to select for computation. + """ + self.sensitivity_cut = sensitivity_cut + self.polarisation = polarisation + self.device = device + + self.SI = SI.permute(dims=(1, 2, 0)) + + if random_state: + torch.manual_seed(random_state) + else: + torch.seed() + + if self.polarisation: + self.polarisation_field = self.rand_polarisation_field( + [SI.shape[0], SI.shape[1]], + random_state=random_state, + order=[1, 1], + scale=[0, 1], + threshold=None, + ) + + self.delta = delta + + if amp_ratio and amp_ratio >= 0: + ax2 = amp_ratio + else: + ax2 = torch.rand(1) + + ay2 = 1 - ax2 + + self.ax2 = self.SI[..., 0] * ax2 + self.ay2 = self.SI[..., 0] * ay2 + + self.I = torch.zeros( + (self.SI.shape[0], self.SI.shape[1], 4), dtype=torch.cdouble + ) # noqa: E741 + + def linear(self) -> None: + """Computes the stokes parameters I, Q, U, and V + for linear polarisation. + + .. math:: + I = A_x^2 + A_y^2 + Q = A_r^2 - A_l^2 + U = 2A_x A_y \cos\delta_{xy} + V = -2A_x A_y \sin\delta_{xy} + """ + self.I[..., 0] = self.ax2 + self.ay2 + self.I[..., 1] = self.ax2 - self.ay2 + self.I[..., 2] = ( + 2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.cos(torch.deg2rad(torch.tensor(self.delta))) + ) + self.I[..., 3] = ( + -2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.sin(torch.deg2rad(torch.tensor(self.delta))) + ) + + def circular(self) -> None: + """Computes the stokes parameters I, Q, U, and V + for circular polarisation. + + .. math:: + I = A_r^2 + A_l^2 + Q = 2A_r A_l \cos\delta_{rl} + U = -2A_r A_l \sin\delta_{rl} + V = A_r^2 - A_l^2 + """ + self.I[..., 0] = self.ax2 + self.ay2 + self.I[..., 1] = ( + 2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.cos(torch.deg2rad(torch.tensor(self.delta))) + ) + self.I[..., 2] = ( + -2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.sin(torch.deg2rad(torch.tensor(self.delta))) + ) + self.I[..., 3] = self.ax2 - self.ay2 + + def dop(self) -> None: + """Computes the degree of polarisation for each pixel.""" + mask = (self.ax2 + self.ay2) > 0 + + # apply polarisation_field to Q, U, and V only + self.I[..., 1] *= self.polarisation_field + self.I[..., 2] *= self.polarisation_field + self.I[..., 3] *= self.polarisation_field + + dop_I = self.I[..., 0].clone() + dop_I[~mask] = float("nan") + dop_Q = self.I[..., 1].clone() + dop_Q[~mask] = float("nan") + dop_U = self.I[..., 2].clone() + dop_U[~mask] = float("nan") + dop_V = self.I[..., 3].clone() + dop_V[~mask] = float("nan") + + self.lin_dop = torch.sqrt(dop_Q**2 + dop_U**2) / dop_I + self.circ_dop = torch.abs(dop_V) / dop_I + + del dop_I, dop_Q, dop_U, dop_V + + def stokes_matrix(self) -> tuple: + """Computes and returns the 2 x 2 stokes matrix B. + + Returns + ------- + B : torch.tensor + 2 x 2 stokes brightness matrix. Either for linear, + circular or no polarisation. + mask : torch.tensor + Mask of the sensitivity cut (Keep all px > sensitivity_cut). + lin_dop : torch.tensor + Degree of linear polarisation of every pixel in the sky. + circ_dop : torch.tensor + Degree of circular polarisation of every pixel in the sky. + """ + # define 2 x 2 Stokes matrix + B = torch.zeros( + (self.SI.shape[0], self.SI.shape[1], 2, 2), dtype=torch.cdouble + ).to(torch.device(self.device)) + + if self.polarisation == "linear": + self.linear() + self.dop() + + B[..., 0, 0] = self.I[..., 0] + self.I[..., 3] + B[..., 0, 1] = self.I[..., 1] + 1j * self.I[..., 2] + B[..., 1, 0] = self.I[..., 1] - 1j * self.I[..., 2] + B[..., 1, 1] = self.I[..., 0] - self.I[..., 3] + + elif self.polarisation == "circular": + self.circular() + self.dop() + + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + + else: + # No polarisation applied + self.I[..., 0] = self.SI[..., 0] + self.I[..., 1] = self.SI[..., 0] + self.I[..., 2] = self.SI[..., 0] + self.I[..., 3] = self.SI[..., 0] + + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + + # calculations only for px > sensitivity cut + mask = (self.SI >= self.sensitivity_cut)[..., 0] + B = B[mask] + + return B, mask, self.lin_dop, self.circ_dop + + def rand_polarisation_field( + self, + shape: list[int, int] | int, + order: list[int, int] | int = 1, + random_state: int = None, + scale: list = [0, 1], + threshold: float = None, + ) -> torch.tensor: + """ + Generates a random noise mask for polarisation. + + Parameters + ---------- + shape : array_like (M, N), or int + The size of the sky image. + order : array_like (M, N) or int, optional + Morphology of the random noise. Higher values create + more and smaller fluctuations. Default: 1. + random_state : int, optional + Random state for the random number generator. If None, + a random entropy is pulled from the OS. Default: None. + scale : array_like, optional + Scaling of the distribution of the image. Default: [0, 1] + threshold : float, optional + If not None, an upper threshold is applied to the image. + Default: None + + Returns + ------- + im : torch.tensor + An array containing random noise values between + scale[0] and scale[1]. + """ + if random_state: + torch.random.manual_seed(random_state) + + if not isinstance(shape, list): + shape = list(shape) + + if len(shape) < 2: + shape *= 2 + elif len(shape) > 2: + raise ValueError("Only 2d shapes are allowed!") + + if not isinstance(order, list): + order = list(order) + + if len(order) < 2: + order *= 2 + elif len(order) > 2: + raise ValueError("Only 2d shapes are allowed!") + + sigma = torch.mean(torch.tensor(shape).double()) / (40 * torch.tensor(order)) + + im = torch.rand(shape) + im = scipy.ndimage.gaussian_filter(im, sigma=sigma.numpy()) + + if scale is None: + scale = [im.min(), im.max()] + + im_flatten = torch.from_numpy(im.flatten()) + im_argsort = torch.argsort(torch.argsort(im_flatten)) + im_linspace = torch.linspace(*scale, im_argsort.size()[0]) + uniform_flatten = im_linspace[im_argsort] + + im = torch.reshape(uniform_flatten, im.shape) + + if threshold: + im = im < threshold + + return im + + def vis_loop( - obs, - SI, - num_threads=10, - noisy=True, - mode="full", - batch_size=100, - show_progress=False, -): + obs: "Observation", + SI: torch.tensor, + num_threads: int = 10, + noisy: bool = True, + mode: str = "full", + batch_size: int = 100, + polarisation: str = "linear", + delta: float = 0, + amp_ratio: float = None, + random_state: int = 42, + show_progress: bool = False, +) -> Visibilities: + r"""Computes the visibilities of an observation. + + Parameters + ---------- + obs : Observation class object + Observation class object generated by the + `~pyvisgen.simulation.Observation` class. + SI : torch.tensor + Tensor containing the sky intensity distribution. + num_threads : int, optional + Number of threads used for intraoperative parallelism + on the CPU. See `~torch.set_num_threads`. Default: 10 + noisy : bool, optional + If `True`, generate and add additional noise to + the simulated measurements. Default: True + mode : str, optional + Select one of `'full'`, `'grid'`, or `'dense'` to get + all valid baselines, a grid of unique baselines, or + dense baselines. Default: 'full' + batch_size : int, optional + Batch size for iteration over baselines. Default: 100 + polarisation : str, optional + Choose between `'linear'` or `'circular'` or `None` to + simulate different types of polarisations or disable + the simulation of polarisation. Default: 'linear' + delta : float, optional + Sets the phase difference of the amplitudes $A_x$ and $A_y$ + of the sky distribution. Defines the measure of ellipticity. + Default: 0 + amp_ratio : float, optional + Sets the ratio of $A_{x/r}$. The ratio of $A_{y/l}$ is calculated + as `1 - amp_ratio`. If set to `None`, a random value is drawn + from a uniform distribution. See also: `random_state`. Default: None + random_state : int, optional + Random state used when drawing `amp_ratio` and during the generation + of the random polarisation field. Default: 42 + show_progress : bool, optional + If `True`, show a progress bar during the iteration over the + batches of baselines. Default: False + + Returns + ------- + visibilities : Visibilities + Dataclass object containing visibilities and baselines. + """ torch.set_num_threads(num_threads) torch._dynamo.config.suppress_errors = True - # define unpolarized sky distribution - SI = SI.permute(dims=(1, 2, 0)) - I = torch.zeros((SI.shape[0], SI.shape[1], 4), dtype=torch.cdouble) - I[..., 0] = SI[..., 0] - - # define 2 x 2 Stokes matrix ((I + Q, iU + V), (iU -V, I - Q)) - B = torch.zeros((SI.shape[0], SI.shape[1], 2, 2), dtype=torch.cdouble).to( - torch.device(obs.device) + pol = Polarisation( + SI, + sensitivity_cut=obs.sensitivity_cut, + amp_ratio=amp_ratio, + delta=delta, + polarisation=polarisation, + random_state=random_state, + device=obs.device, ) - B[:, :, 0, 0] = I[:, :, 0] + I[:, :, 1] - B[:, :, 0, 1] = I[:, :, 2] + 1j * I[:, :, 3] - B[:, :, 1, 0] = I[:, :, 2] - 1j * I[:, :, 3] - B[:, :, 1, 1] = I[:, :, 0] - I[:, :, 1] - - # calculations only for px > sensitivity cut - mask = (SI >= obs.sensitivity_cut)[..., 0] - B = B[mask] + + B, mask, lin_dop, circ_dop = pol.stokes_matrix() + lm = obs.lm[mask] rd = obs.rd[mask] @@ -86,6 +413,8 @@ def vis_loop( torch.tensor([]), torch.tensor([]), torch.tensor([]), + torch.tensor([]), + torch.tensor([]), ) vis_num = torch.zeros(1) if mode == "full": @@ -149,7 +478,11 @@ def vis_loop( bas_p[5].cpu(), bas_p[8].cpu(), bas_p[10].cpu(), + torch.tensor([]), + torch.tensor([]), ) + visibilities.linear_dop = lin_dop.cpu() + visibilities.circular_dop = circ_dop.cpu() visibilities.add(vis) del int_values From 1f456e700ca79d72382ac6157324e07f45f40541 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Wed, 14 Aug 2024 11:09:19 +0200 Subject: [PATCH 02/52] Move arguments for polarisation to Observation class --- pyvisgen/simulation/observation.py | 122 ++++++++++++++++++++++++----- pyvisgen/simulation/visibility.py | 32 ++------ 2 files changed, 110 insertions(+), 44 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 814aad4..f7e4e27 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, fields +from datetime import datetime from math import pi import astropy.constants as const @@ -155,24 +156,104 @@ def _lexsort(self, a, dim=-1): class Observation: def __init__( self, - src_ra, - src_dec, - start_time, - scan_duration, - num_scans, - scan_separation, - integration_time, - ref_frequency, - frequency_offsets, - bandwidths, - fov, - image_size, - array_layout, - corrupted, - device, - dense=False, - sensitivity_cut=1e-6, - ): + src_ra: float, + src_dec: float, + start_time: datetime, + scan_duration: int, + num_scans: int, + scan_separation: int, + integration_time: int, + ref_frequency: float, + frequency_offsets: list, + bandwidths: list, + fov: float, + image_size: int, + array_layout: str, + corrupted: bool, + device: str, + dense: bool = False, + sensitivity_cut: float = 1e-6, + polarisation: str = None, + pol_kwargs: dict = { + "delta": 0, + "amp_ratio": 0.5, + "random_state": 42, + }, + field_kwargs: dict = { + "order": [1, 1], + "scale": [0, 1], + "threshold": None, + "random_state": 42, + }, + ) -> None: + """Sets up the observation class. + + Parameters + ---------- + src_ra : float + Source right ascension coordinate. + src_dec : float + Source declination coordinate. + start_time : datetime + Observation start time. + scan_duration : int + Scan duration. + num_scans : int + Number of scans. + scan_separation : int + Scan separation. + integration_time : int + Integration time. + ref_frequency : float + Reference frequency. + frequency_offsets : list + Frequency offsets. + bandwidths : list + Frequency bandwidth. + fov : float + Field of view. + image_size : int + Image size of the sky distribution. + array_layout : str + Name of an existing array layout. See `~pyvisgen.layouts`. + corrupted : bool + If `True`, apply corruption during the vis loop. + device : str + Torch device to select for computation. + dense : bool, optional + If `True`, apply dense baseline calculation of a perfect + interferometer. Default: `False` + sensitivity_cut : float, optional + Sensitivity threshold, where only pixels above the value + are kept. Default: 1e-6 + polarisation : str, optional + Choose between `'linear'` or `'circular'` or `None` to + simulate different types of polarisations or disable + the simulation of polarisation. Default: `None` + pol_kwargs : dict, optional + Additional keyword arguments for the simulation + of polarisation. Default: `{ + "delta": 0, + "amp_ratio": 0.5, + "random_state": 42, + } + field_kwargs : dict, optional + Additional keyword arguments for the random polarisation + field that is applied when simulating polarisation. + Default: `{ + "order": [1, 1], + "scale": [0, 1], + "threshold": None, + "random_state": 42 + }` + + Notes + ----- + See `~pyvisgen.simulation.visibility.Polarisation` and + `~pyvisgen.simulation.visibility.Polarisation.rand_polarisation_field` + for more information on the keyword arguments in `pol_kwargs` + and `field_kwargs`, respectively. + """ self.ra = torch.tensor(src_ra).double() self.dec = torch.tensor(src_dec).double() @@ -229,6 +310,11 @@ def __init__( self.rd = self.create_rd_grid() self.lm = self.create_lm_grid() + # polarisation + self.polarisation = polarisation + self.pol_kwargs = pol_kwargs + self.field_kwargs = field_kwargs + def calc_dense_baselines(self): N = self.img_size fov = self.fov * pi / (3600 * 180) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index e817857..c4a635b 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -51,6 +51,7 @@ def __init__( amp_ratio: float, delta: float, polarisation: str, + field_kwargs: dict, random_state: int, device: str, ) -> None: @@ -91,21 +92,16 @@ def __init__( if random_state: torch.manual_seed(random_state) - else: - torch.seed() if self.polarisation: self.polarisation_field = self.rand_polarisation_field( [SI.shape[0], SI.shape[1]], - random_state=random_state, - order=[1, 1], - scale=[0, 1], - threshold=None, + **field_kwargs, ) self.delta = delta - if amp_ratio and amp_ratio >= 0: + if amp_ratio and (amp_ratio >= 0): ax2 = amp_ratio else: ax2 = torch.rand(1) @@ -233,9 +229,6 @@ def stokes_matrix(self) -> tuple: else: # No polarisation applied self.I[..., 0] = self.SI[..., 0] - self.I[..., 1] = self.SI[..., 0] - self.I[..., 2] = self.SI[..., 0] - self.I[..., 3] = self.SI[..., 0] B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] @@ -328,10 +321,6 @@ def vis_loop( noisy: bool = True, mode: str = "full", batch_size: int = 100, - polarisation: str = "linear", - delta: float = 0, - amp_ratio: float = None, - random_state: int = 42, show_progress: bool = False, ) -> Visibilities: r"""Computes the visibilities of an observation. @@ -359,14 +348,6 @@ def vis_loop( Choose between `'linear'` or `'circular'` or `None` to simulate different types of polarisations or disable the simulation of polarisation. Default: 'linear' - delta : float, optional - Sets the phase difference of the amplitudes $A_x$ and $A_y$ - of the sky distribution. Defines the measure of ellipticity. - Default: 0 - amp_ratio : float, optional - Sets the ratio of $A_{x/r}$. The ratio of $A_{y/l}$ is calculated - as `1 - amp_ratio`. If set to `None`, a random value is drawn - from a uniform distribution. See also: `random_state`. Default: None random_state : int, optional Random state used when drawing `amp_ratio` and during the generation of the random polarisation field. Default: 42 @@ -385,11 +366,10 @@ def vis_loop( pol = Polarisation( SI, sensitivity_cut=obs.sensitivity_cut, - amp_ratio=amp_ratio, - delta=delta, - polarisation=polarisation, - random_state=random_state, + polarisation=obs.polarisation, device=obs.device, + field_kwargs=obs.field_kwargs, + **obs.pol_kwargs, ) B, mask, lin_dop, circ_dop = pol.stokes_matrix() From 15c3434066e230a3afb5f68cf38d237facd53a52 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Wed, 14 Aug 2024 13:39:14 +0200 Subject: [PATCH 03/52] Catch case where order arg is int in Polarisation class --- pyvisgen/simulation/visibility.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index c4a635b..300445f 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -285,6 +285,9 @@ def rand_polarisation_field( elif len(shape) > 2: raise ValueError("Only 2d shapes are allowed!") + if isinstance(order, int): + order = [order] + if not isinstance(order, list): order = list(order) From 12cd0578e3c1cf688e4b0a717783d94235bcc60a Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Mon, 9 Sep 2024 14:35:27 +0200 Subject: [PATCH 04/52] Update FITS header writer - Add numeric codes for stokes parameters - Add header comments for stokes parameters - Add comment for visibilities (real, complex, weight) --- pyvisgen/fits/writer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pyvisgen/fits/writer.py b/pyvisgen/fits/writer.py index f17f97d..11cdcd5 100644 --- a/pyvisgen/fits/writer.py +++ b/pyvisgen/fits/writer.py @@ -42,9 +42,17 @@ def create_vis_hdu(data, obs, layout="vlba", source_name="sim-source-0"): freq_d = obs.bandwidths[0].cpu().numpy().item() ws = wcs.WCS(naxis=7) + + crval_stokes = -1 + stokes_comment = "-1=RR, -2=LL, -3=RL, -4=LR" + if obs.polarisation == "linear": + crval_stokes = -5 + stokes_comment = "-5=XX, -6=YY, -7=XY, -8=YX" + stokes_comment += " or -5=VV, -6=HH, -7=VH, -8=HV" + ws.wcs.crpix = [1, 1, 1, 1, 1, 1, 1] ws.wcs.cdelt = np.array([1, 1, -1, freq_d, 1, 1, 1]) - ws.wcs.crval = [1, 1, -1, freq, 1, ra, dec] + ws.wcs.crval = [1, 1, crval_stokes, freq, 1, ra, dec] ws.wcs.ctype = ["", "COMPLEX", "STOKES", "FREQ", "IF", "RA", "DEC"] h = ws.to_header() @@ -79,6 +87,8 @@ def create_vis_hdu(data, obs, layout="vlba", source_name="sim-source-0"): hdu_vis.header["PZERO" + str(i + 1)] = parbzeros[i] # add comments + hdu_vis.header.comments["CTYPE2"] = "1=real, 2=imag, 3=weight" + hdu_vis.header.comments["CTYPE3"] = stokes_comment hdu_vis.header.comments["PTYPE1"] = "u baseline coordinate in light seconds" hdu_vis.header.comments["PTYPE2"] = "v baseline coordinate in light seconds" hdu_vis.header.comments["PTYPE3"] = "w baseline coordinate in light seconds" From 56dda4b7d53eae7c2040b84018776fab2c537bf7 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Mon, 9 Sep 2024 14:38:01 +0200 Subject: [PATCH 05/52] Fix observation gridding and Baseline dataclass - Fixes image rotation due to rd grid indexing - Fixes Baselines dataclass attribute sequence, where `num_baseline` would previously be saved as first element --- pyvisgen/simulation/observation.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index f7e4e27..6f74c95 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -65,7 +65,6 @@ def get_valid_subset(self, num_baselines, device): date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device) return ValidBaselineSubset( - baseline_nums, u_start, u_stop, u_valid, @@ -75,13 +74,13 @@ def get_valid_subset(self, num_baselines, device): w_start, w_stop, w_valid, + baseline_nums, date, ) @dataclass() class ValidBaselineSubset: - baseline_nums: torch.tensor u_start: torch.tensor u_stop: torch.tensor u_valid: torch.tensor @@ -91,6 +90,7 @@ class ValidBaselineSubset: w_start: torch.tensor w_stop: torch.tensor w_valid: torch.tensor + baseline_nums: torch.tensor date: torch.tensor def __getitem__(self, i): @@ -456,9 +456,10 @@ def create_rd_grid(self): - self.img_size / 2 ) * res + dec - _, R = torch.meshgrid((r, r), indexing="ij") - D, _ = torch.meshgrid((d, d), indexing="ij") + R, _ = torch.meshgrid((r, r), indexing="ij") + _, D = torch.meshgrid((d, d), indexing="ij") rd_grid = torch.cat([R[..., None], D[..., None]], dim=2) + return rd_grid def create_lm_grid(self): @@ -479,11 +480,11 @@ def create_lm_grid(self): dec = torch.deg2rad(self.dec) lm_grid = torch.zeros(self.rd.shape, device=self.device, dtype=torch.float64) - lm_grid[:, :, 0] = (torch.cos(self.rd[..., 1]) * torch.sin(self.rd[..., 0])).T - lm_grid[:, :, 1] = ( - torch.sin(self.rd[..., 1]) * torch.cos(dec) - - torch.cos(self.rd[..., 1]) * torch.sin(dec) * torch.cos(self.rd[..., 0]) - ).T + lm_grid[..., 0] = torch.cos(self.rd[..., 1]) * torch.sin(self.rd[..., 0]) + lm_grid[..., 1] = torch.sin(self.rd[..., 1]) * torch.cos(dec) - torch.cos( + self.rd[..., 1] + ) * torch.sin(dec) * torch.cos(self.rd[..., 0]) + return lm_grid def get_baselines(self, times): From b8917b615ee955de1e9011e5be866f7cbc2cebda Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Mon, 9 Sep 2024 14:42:09 +0200 Subject: [PATCH 06/52] Fix stokes calc, fix sequence of visibilities - Fix calculation of Stokes parameters in Polarisation class - Rename visibilities I -> V_11, Q -> V_12, U -> V_21, V -> V_22. The previous names are confusing, as the quantities that are returned are not the Stokes parameters but calculated from the Stokes parameters (i.e. V_11 = I + V for circular polarization) - Fix visbilities sequence -> V_11, V_22, V_12, V_21 (previously V_11, V_12, V_21, V22), i.e. main diagonal of the Jones matrix first, then subdiagonals --- pyvisgen/simulation/visibility.py | 67 +++++++++++++++++-------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 300445f..34c3bdb 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -9,10 +9,10 @@ @dataclass class Visibilities: - SI: torch.tensor - SQ: torch.tensor - SU: torch.tensor - SV: torch.tensor + V_11: torch.tensor + V_22: torch.tensor + V_12: torch.tensor + V_21: torch.tensor num: torch.tensor base_num: torch.tensor u: torch.tensor @@ -27,7 +27,7 @@ def __getitem__(self, i): def get_values(self): return torch.cat( - [self.SI[None], self.SQ[None], self.SU[None], self.SV[None]], dim=0 + [self.V_11[None], self.V_22[None], self.V_12[None], self.V_21[None]], dim=0 ).permute(1, 2, 0) def add(self, visibilities): @@ -95,7 +95,7 @@ def __init__( if self.polarisation: self.polarisation_field = self.rand_polarisation_field( - [SI.shape[0], SI.shape[1]], + [self.SI.shape[0], self.SI.shape[1]], **field_kwargs, ) @@ -108,8 +108,11 @@ def __init__( ay2 = 1 - ax2 - self.ax2 = self.SI[..., 0] * ax2 - self.ay2 = self.SI[..., 0] * ay2 + self.ax2 = self.SI[..., 0].clone() * ax2 + self.ay2 = self.SI[..., 0].clone() * ay2 + else: + self.ax2 = self.SI[..., 0] + self.ay2 = torch.zeros_like(self.ax2) self.I = torch.zeros( (self.SI.shape[0], self.SI.shape[1], 4), dtype=torch.cdouble @@ -174,13 +177,13 @@ def dop(self) -> None: self.I[..., 2] *= self.polarisation_field self.I[..., 3] *= self.polarisation_field - dop_I = self.I[..., 0].clone() + dop_I = self.I[..., 0].real.clone() dop_I[~mask] = float("nan") - dop_Q = self.I[..., 1].clone() + dop_Q = self.I[..., 1].real.clone() dop_Q[~mask] = float("nan") - dop_U = self.I[..., 2].clone() + dop_U = self.I[..., 2].real.clone() dop_U[~mask] = float("nan") - dop_V = self.I[..., 3].clone() + dop_V = self.I[..., 3].real.clone() dop_V[~mask] = float("nan") self.lin_dop = torch.sqrt(dop_Q**2 + dop_U**2) / dop_I @@ -212,28 +215,30 @@ def stokes_matrix(self) -> tuple: self.linear() self.dop() - B[..., 0, 0] = self.I[..., 0] + self.I[..., 3] - B[..., 0, 1] = self.I[..., 1] + 1j * self.I[..., 2] - B[..., 1, 0] = self.I[..., 1] - 1j * self.I[..., 2] - B[..., 1, 1] = self.I[..., 0] - self.I[..., 3] + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] # I + Q + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] # U + iV + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] # U - iV + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] # I - Q elif self.polarisation == "circular": self.circular() self.dop() - B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] - B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] - B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] - B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + B[..., 0, 0] = self.I[..., 0] + self.I[..., 3] # I + V + B[..., 0, 1] = self.I[..., 1] + 1j * self.I[..., 2] # Q + iU + B[..., 1, 0] = self.I[..., 1] - 1j * self.I[..., 2] # Q - iU + B[..., 1, 1] = self.I[..., 0] - self.I[..., 3] # I - V else: # No polarisation applied self.I[..., 0] = self.SI[..., 0] + self.polarisation_field = torch.ones_like(self.I[..., 0]) + self.dop() - B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] - B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] - B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] - B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] # I + Q + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] # U + iV + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] # U - iV + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] # I - Q # calculations only for px > sensitivity cut mask = (self.SI >= self.sensitivity_cut)[..., 0] @@ -451,10 +456,10 @@ def vis_loop( vis_num = torch.arange(int_values.shape[0]) + 1 + vis_num.max() vis = Visibilities( - int_values[:, :, 0, 0].cpu(), - int_values[:, :, 0, 1].cpu(), - int_values[:, :, 1, 0].cpu(), - int_values[:, :, 1, 1].cpu(), + int_values[..., 0, 0].cpu(), # V_11 + int_values[..., 1, 1].cpu(), # V_22 + int_values[..., 0, 1].cpu(), # V_12 + int_values[..., 1, 0].cpu(), # V_21 vis_num, bas_p[9].cpu(), bas_p[2].cpu(), @@ -464,11 +469,13 @@ def vis_loop( torch.tensor([]), torch.tensor([]), ) - visibilities.linear_dop = lin_dop.cpu() - visibilities.circular_dop = circ_dop.cpu() visibilities.add(vis) del int_values + + visibilities.linear_dop = lin_dop.cpu() + visibilities.circular_dop = circ_dop.cpu() + return visibilities From d12e8737be1ff8be3b297e5e91f24e5f8cd1d05c Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Thu, 12 Sep 2024 17:29:49 +0200 Subject: [PATCH 07/52] Change type hint from str to torch.device for device --- pyvisgen/simulation/visibility.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 34c3bdb..8966f45 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -53,7 +53,7 @@ def __init__( polarisation: str, field_kwargs: dict, random_state: int, - device: str, + device: torch.device, ) -> None: """Creates the 2 x 2 stokes matrix and simulates polarisation if `polarisation` is either 'linear' @@ -81,7 +81,7 @@ def __init__( random_state : int Random state used when drawing `amp_ratio` and during the generation of the random polarisation field. - device : str + device : torch.device Torch device to select for computation. """ self.sensitivity_cut = sensitivity_cut From a619a3a3c9326652c0d7c146cd5768a60644afbb Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 20 Sep 2024 16:53:41 +0200 Subject: [PATCH 08/52] Flip input image before creating stokes matrix --- pyvisgen/simulation/visibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 8966f45..d0551ca 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -372,7 +372,7 @@ def vis_loop( torch._dynamo.config.suppress_errors = True pol = Polarisation( - SI, + torch.flip(SI, dims=[1]), sensitivity_cut=obs.sensitivity_cut, polarisation=obs.polarisation, device=obs.device, From 66b5d229f1aff8568d0e49404b4862da72d6786c Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 1 Oct 2024 10:04:34 +0200 Subject: [PATCH 09/52] Add changelog --- docs/changes/39.bugfix.rst | 3 +++ docs/changes/39.feature.rst | 4 ++++ docs/changes/39.maintenance.rst | 6 ++++++ 3 files changed, 13 insertions(+) create mode 100644 docs/changes/39.bugfix.rst create mode 100644 docs/changes/39.feature.rst create mode 100644 docs/changes/39.maintenance.rst diff --git a/docs/changes/39.bugfix.rst b/docs/changes/39.bugfix.rst new file mode 100644 index 0000000..9a98182 --- /dev/null +++ b/docs/changes/39.bugfix.rst @@ -0,0 +1,3 @@ +- Fix gridding in `pyvisgen.simulation.observation.Observation` methods `create_rd_grid` and `create_lm_grid` + methods resulting in rotated images +- Fix `pyvisgen.simulation.observation.ValidBaselineSubset` dataclass field order diff --git a/docs/changes/39.feature.rst b/docs/changes/39.feature.rst new file mode 100644 index 0000000..d2cda72 --- /dev/null +++ b/docs/changes/39.feature.rst @@ -0,0 +1,4 @@ +- Add class `Polarisation` to `pyvisgen.simulation.visibility` that is called in `vis_loop` + - Add linear, circular, and no polarisation options +- Update `pyvisgen.simulation.visibility.Visibilities` dataclass to also store polarisation degree tensors +- Add keyword arguments for polarisation simulation to `pyvisgen.simulation.observation.Observation` class diff --git a/docs/changes/39.maintenance.rst b/docs/changes/39.maintenance.rst new file mode 100644 index 0000000..83a2365 --- /dev/null +++ b/docs/changes/39.maintenance.rst @@ -0,0 +1,6 @@ +- Change pyvisgen.simulation.visibility.Visibilities dataclass component names from stokes components (I , Q, U, and V) + to visibilities constructed from the stokes components (`V_11`, `V_22`, `V_12`, `V_21`) +- Change indices for stokes components according to AIPS Memo 114 + - Indices will be set automatically depending on simulated polarisation +- Update comment strings in FITS files +- Update docstrings accordingly in `pyvisgen.simulation.visibility.vis_loop` and `pyvisgen.simulation.observation.Observation` From 2a5e07c0c54ee08bd1437145252994debb8ce362 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 1 Oct 2024 10:08:23 +0200 Subject: [PATCH 10/52] Fix test failing because of api change --- tests/test_simulation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 4227924..132df39 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -42,10 +42,10 @@ def test_vis_loop(): SI = torch.tensor(data[0])[None] vis_data = vis_loop(obs, SI, noisy=conf["noisy"], mode=conf["mode"]) - assert (vis_data[0].SI[0]).dtype == torch.complex128 - assert (vis_data[0].SQ[0]).dtype == torch.complex128 - assert (vis_data[0].SU[0]).dtype == torch.complex128 - assert (vis_data[0].SV[0]).dtype == torch.complex128 + assert (vis_data[0].V_11[0]).dtype == torch.complex128 + assert (vis_data[0].V_22[0]).dtype == torch.complex128 + assert (vis_data[0].V_12[0]).dtype == torch.complex128 + assert (vis_data[0].V_21[0]).dtype == torch.complex128 assert (vis_data[0].num).dtype == torch.float32 assert (vis_data[0].base_num).dtype == torch.float64 assert torch.is_tensor(vis_data[0].u) From 3ae5f722a2c60fe8093f18043a9ee20f745b9a1f Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 1 Oct 2024 10:09:12 +0200 Subject: [PATCH 11/52] Update bugfix changelog --- docs/changes/39.bugfix.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changes/39.bugfix.rst b/docs/changes/39.bugfix.rst index 9a98182..67473db 100644 --- a/docs/changes/39.bugfix.rst +++ b/docs/changes/39.bugfix.rst @@ -1,3 +1,4 @@ - Fix gridding in `pyvisgen.simulation.observation.Observation` methods `create_rd_grid` and `create_lm_grid` methods resulting in rotated images - Fix `pyvisgen.simulation.observation.ValidBaselineSubset` dataclass field order +- Fix test failing because of api change From be9bc138015644aed959788a97dd94350b5203c0 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 14:08:16 +0200 Subject: [PATCH 12/52] Add tests for Polarisation class --- tests/test_simulation.py | 181 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 132df39..17ff26a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -1,6 +1,7 @@ from pathlib import Path import torch +from numpy.testing import assert_array_equal, assert_raises from pyvisgen.utils.config import read_data_set_conf @@ -64,3 +65,183 @@ def test_vis_loop(): out = out_path / Path("vis_0.fits") hdu_list = writer.create_hdu_list(vis_data, obs) hdu_list.writeto(out, overwrite=True) + + +class TestPolarisation: + """Unit test class for ``pyvisgen.simulation.visibility.Polarisation``.""" + + def setup_class(self): + """Set up common objects and variables for the following tests.""" + from pyvisgen.simulation.data_set import create_observation + from pyvisgen.simulation.visibility import Polarisation + + self.obs = create_observation(conf) + + self.SI = torch.zeros((100, 100)) + self.SI[25::25, 25::25] = 1 + self.SI = self.SI[None, ...] + + self.si_shape = self.SI.shape + self.im_shape = self.si_shape[1], self.si_shape[2] + + self.obs.img_size = self.im_shape[0] + + self.pol = Polarisation( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation=self.obs.polarisation, + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **self.obs.pol_kwargs, + ) + + def test_polarisation_circular(self): + """Test circular polarisation.""" + + self.pol.__init__( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation="circular", + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **self.obs.pol_kwargs, + ) + + assert self.pol.delta == 0 + assert self.pol.ax2.sum() == self.SI.sum() * 0.5 + assert self.pol.ay2.sum() == self.SI.sum() * 0.5 + + B, mask, lin_dop, circ_dop = self.pol.stokes_matrix() + + assert mask.sum() == 9 + assert B.shape == torch.Size([9, 2, 2]) + assert mask.shape == self.im_shape + assert lin_dop.shape == self.im_shape + assert lin_dop.shape == self.im_shape + + def test_polarisation_linear(self): + """Test linear polarisation.""" + + self.pol.__init__( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation="linear", + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **self.obs.pol_kwargs, + ) + + assert self.pol.delta == 0 + assert self.pol.ax2.sum() == self.SI.sum() * 0.5 + assert self.pol.ay2.sum() == self.SI.sum() * 0.5 + + B, mask, lin_dop, circ_dop = self.pol.stokes_matrix() + + assert mask.sum() == 9 + assert B.shape == torch.Size([9, 2, 2]) + assert mask.shape == self.im_shape + assert lin_dop.shape == self.im_shape + assert lin_dop.shape == self.im_shape + + def test_polarisation_field(self): + """Test Polarisation.rand_polarisation_field method.""" + pf = self.pol.rand_polarisation_field(shape=self.im_shape) + + assert pf.shape == torch.Size([100, 100]) + + def test_polarisation_field_random_state(self): + """Test polarisation field method for a given random_state""" + random_state = 42 + + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=random_state, + ) + + assert torch.random.initial_seed() == random_state + assert pf.shape == torch.Size([100, 100]) + + def test_polarisation_field_shape_int(self): + """Test polarisation field method for type(shape) = int.""" + pf = self.pol.rand_polarisation_field( + shape=self.im_shape[0], + ) + + assert pf.shape == torch.Size([100, 100]) + + def test_polarisation_field_order(self): + """Test polarisation field method for different orders.""" + + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + order=[1, 1], + ) + + assert pf.shape == torch.Size([100, 100]) + # assert order = 1 and order = [1, 1] yield same images + assert_array_equal(pf, pf_ref, strict=True) + + # assert different order creates different image + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, random_state=42, order=[10, 10] + ) + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) + + # assert len(order) > 2 raises ValueError + assert_raises( + ValueError, + self.pol.rand_polarisation_field, + shape=self.im_shape, + random_state=42, + order=[10, 10, 10], + ) + + def test_polarisation_field_scale(self): + """Test polarisation field method for different scales.""" + + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + + # scale = None + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + scale=None, + ) + + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) + + # scale = [0.25, 0.25] + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, random_state=42, scale=[0.25, 0.25] + ) + + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) + + def test_polarisation_field_threshold(self): + """Test polarisation field method for different threshold.""" + + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + threshold=0.5, + ) + + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) From 12c9d56628c670cd3bf95d3f48f10bc976e85f08 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 15:08:51 +0200 Subject: [PATCH 13/52] Fix bug in Polarisation class --- pyvisgen/simulation/visibility.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index c4e6c6e..259454d 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -93,7 +93,7 @@ def __init__( if random_state: torch.manual_seed(random_state) - if self.polarisation: + if self.polarisation and self.polarisation in ["circular", "linear"]: self.polarisation_field = self.rand_polarisation_field( [self.SI.shape[0], self.SI.shape[1]], **field_kwargs, @@ -282,6 +282,9 @@ def rand_polarisation_field( if random_state: torch.random.manual_seed(random_state) + if isinstance(shape, int): + shape = [shape] + if not isinstance(shape, list): shape = list(shape) From d4522f97cdee690dfec1daf960b3d6582c0961d3 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 15:46:49 +0200 Subject: [PATCH 14/52] Add tests for remaining edge cases --- tests/test_simulation.py | 41 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 17ff26a..fda7698 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -143,6 +143,22 @@ def test_polarisation_linear(self): assert lin_dop.shape == self.im_shape assert lin_dop.shape == self.im_shape + def test_polarisation_amplitude(self): + """Test random amplitude.""" + pol_kwargs = {"delta": 0, "amp_ratio": None, "random_state": 42} + + self.pol.__init__( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation="linear", + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **pol_kwargs, + ) + + assert self.pol.ax2.sum() <= 9 + assert self.pol.ay2.sum() <= 9 + def test_polarisation_field(self): """Test Polarisation.rand_polarisation_field method.""" pf = self.pol.rand_polarisation_field(shape=self.im_shape) @@ -161,13 +177,28 @@ def test_polarisation_field_random_state(self): assert torch.random.initial_seed() == random_state assert pf.shape == torch.Size([100, 100]) - def test_polarisation_field_shape_int(self): + def test_polarisation_field_shape(self): """Test polarisation field method for type(shape) = int.""" + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + pf = self.pol.rand_polarisation_field( shape=self.im_shape[0], + random_state=42, ) assert pf.shape == torch.Size([100, 100]) + assert_array_equal(pf, pf_ref, strict=True) + + # assert len(shape) > 2 raises ValueError + assert_raises( + ValueError, + self.pol.rand_polarisation_field, + shape=[100, 100, 100], + random_state=42, + ) def test_polarisation_field_order(self): """Test polarisation field method for different orders.""" @@ -187,6 +218,14 @@ def test_polarisation_field_order(self): # assert order = 1 and order = [1, 1] yield same images assert_array_equal(pf, pf_ref, strict=True) + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + order=[1], + ) + # assert order = [1] and order = [1, 1] yield same images + assert_array_equal(pf, pf_ref, strict=True) + # assert different order creates different image pf = self.pol.rand_polarisation_field( shape=self.im_shape, random_state=42, order=[10, 10] From 4edfd60846b807fbc15951e1b4bef9d13e138cbd Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 17:55:47 +0200 Subject: [PATCH 15/52] Add test for order as tuple --- tests/test_simulation.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index fda7698..585ba85 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -218,6 +218,14 @@ def test_polarisation_field_order(self): # assert order = 1 and order = [1, 1] yield same images assert_array_equal(pf, pf_ref, strict=True) + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + order=(1, 1), + ) + # assert order = (1, 1) and order = [1, 1] yield same images + assert_array_equal(pf, pf_ref, strict=True) + pf = self.pol.rand_polarisation_field( shape=self.im_shape, random_state=42, From dbe802f35b042543d5dd8c01de38a6e23408420c Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Thu, 9 Jan 2025 17:59:41 +0100 Subject: [PATCH 16/52] Add parallactic angle/feed rotation calculation --- pyvisgen/simulation/observation.py | 117 ++++++++++++++++++++++++++--- 1 file changed, 106 insertions(+), 11 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index ad0e3cc..1eb6b3a 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -7,7 +7,7 @@ import numpy as np import torch from astropy.constants import c -from astropy.coordinates import AltAz, Angle, EarthLocation, SkyCoord +from astropy.coordinates import AltAz, Angle, EarthLocation, SkyCoord, Longitude from astropy.time import Time from pyvisgen.layouts import layouts @@ -23,6 +23,9 @@ class Baselines: w: torch.tensor valid: torch.tensor time: torch.tensor + q_all: torch.tensor + q1: torch.tensor + q2: torch.tensor def __getitem__(self, i): return Baselines(*[getattr(self, f.name)[i] for f in fields(self)]) @@ -61,6 +64,9 @@ def get_valid_subset(self, num_baselines, device): v_valid = (v_start + v_stop) / 2 w_valid = (w_start + w_stop) / 2 + q1_valid = bas_reshaped.q1[mask].to(device) + q2_valid = bas_reshaped.q2[mask].to(device) + t = Time(bas_reshaped.time / (60 * 60 * 24), format="mjd").jd date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device) @@ -76,6 +82,8 @@ def get_valid_subset(self, num_baselines, device): w_valid, baseline_nums, date, + q1_valid, + q2_valid, ) @@ -92,6 +100,8 @@ class ValidBaselineSubset: w_valid: torch.tensor baseline_nums: torch.tensor date: torch.tensor + q1_valid: torch.tensor + q2_valid: torch.tensor def __getitem__(self, i): return torch.stack( @@ -107,6 +117,8 @@ def __getitem__(self, i): self.w_valid, self.baseline_nums, self.date, + self.q1_valid, + self.q2_valid, ] ) @@ -236,7 +248,7 @@ def __init__( "delta": 0, "amp_ratio": 0.5, "random_state": 42, - } + }` field_kwargs : dict, optional Additional keyword arguments for the random polarisation field that is applied when simulating polarisation. @@ -265,7 +277,8 @@ def __init__( self.times, self.times_mjd = self.calc_time_steps() self.scans = torch.stack( torch.split( - torch.arange(len(self.times)), (len(self.times) // self.num_scans) + torch.arange(self.times.size), + (self.times.size // self.num_scans), ), dim=0, ) @@ -290,6 +303,9 @@ def __init__( self.layout = array_layout self.array = layouts.get_array_layout(array_layout) + self.array_earth_loc = EarthLocation.from_geocentric( + self.array.x, self.array.y, self.array.z, unit=un.m + ) self.num_baselines = int( len(self.array.st_num) * (len(self.array.st_num) - 1) / 2 ) @@ -303,7 +319,7 @@ def __init__( else: self.calc_baselines() self.baselines.num = int( - len(self.array.st_num) * (len(self.array.st_num) - 1) / 2 + self.array.st_num.size(dim=0) * (self.array.st_num.size(dim=0) - 1) / 2 ) self.baselines.times_unique = torch.unique(self.baselines.time) @@ -365,7 +381,15 @@ def calc_baselines(self): torch.tensor([]), torch.tensor([]), torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), ) + self.q_comb_l = [] + self.q_all_l = [] + + self.n = 1 + for scan in self.scans: bas = self.get_baselines(self.times[scan]) self.baselines.add_baseline(bas) @@ -389,12 +413,25 @@ def calc_ref_elev(self, time=None): time = self.times if time.shape == (): time = time[None] + src_crd = SkyCoord(ra=self.ra, dec=self.dec, unit=(un.deg, un.deg)) # Calculate for all times # calculate GHA, Greenwich as reference - ha_all = Angle( + GHA = Angle( [t.sidereal_time("apparent", "greenwich") - src_crd.ra for t in time] ) + self.ha_all = GHA + + # calculate local sidereal time and HA at each antenna + lst = un.Quantity( + [ + Time(time, location=loc).sidereal_time("mean") + for loc in self.array_earth_loc + ] + ) + ha_local = torch.from_numpy( + (lst - Longitude(self.ra.item(), unit=un.deg)).radian + ).T # calculate elevations el_st_all = src_crd.transform_to( @@ -408,8 +445,38 @@ def calc_ref_elev(self, time=None): ), ) ) - assert len(ha_all.value) == len(el_st_all) - return torch.tensor(ha_all.deg), torch.tensor(el_st_all.alt.degree) + assert len(GHA.value) == len(el_st_all) + return torch.tensor(GHA.deg), ha_local, torch.tensor(el_st_all.alt.degree) + + def calc_feed_rotation(self, ha: Angle) -> Angle: + r"""Calculates feed rotation for every antenna at every time step. + + Notes + ----- + The calculation is based on Equation (13.1) of Meeus' + Astronomical Algorithms: + + .. math:: + + q = \frac{\sin h}{\cos\delta \tan\varphi - \sin\delta \cos h, + + where $h$ is the local hour angle, $\varphi$ the geographical latitude + of the observer, and $\delta$ the declination of the source. + """ + # We need to create a tensor from the EarthLocation object + # and save only the geographical latitude of each antenna + ant_lat = torch.tensor(self.array_earth_loc.lat) + + # Eqn (13.1) of Meeus' Astronomical Algorithms + q = torch.arctan2( + torch.sin(ha), + ( + torch.tan(ant_lat) * torch.cos(self.dec) + - torch.sin(self.dec) * torch.cos(ha) + ), + ) + + return q def test_active_telescopes(self): _, el_st_0 = self.calc_ref_elev(self.times[0]) @@ -501,9 +568,10 @@ def get_baselines(self, times): dataclass object baselines between telescopes with visibility flags """ - # Calculate for all times - # calculate GHA, Greenwich as reference - ha_all, el_st_all = self.calc_ref_elev(time=times) + # calculate GHA, local HA, and station elevation for all times. + GHA, ha_local, el_st_all = self.calc_ref_elev(time=times) + + self.el_st_all = el_st_all ar = Array(self.array) delta_x, delta_y, delta_z = ar.calc_relative_pos @@ -518,8 +586,31 @@ def get_baselines(self, times): torch.tensor([]), torch.tensor([]), torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), ) - for ha, el_st, time in zip(ha_all, el_st_all, times): + q_all = self.calc_feed_rotation(ha_local) + q_comb = torch.vstack([torch.combinations(qi) for qi in q_all]) + q_comb = q_comb.reshape(-1, int(q_comb.shape[0] / times.shape[0]), 2) + + self.q_comb_l.append(q_comb) + self.q_all_l.append(q_all) + + print( + GHA.shape, + el_st_all.shape, + times.shape, + q_all.shape, + q_comb.shape, + ) + + self.GHA = GHA + self.delx = delta_x + self.dely = delta_y + self.delz = delta_z + + for ha, el_st, time, q, qc in zip(GHA, el_st_all, times, q_all, q_comb): u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z) # calc current elevations @@ -535,6 +626,7 @@ def get_baselines(self, times): time_mjd = torch.repeat_interleave( torch.tensor(time.mjd) * (24 * 60 * 60), len(valid) ) + # collect baselines base = Baselines( st_num_pairs[:, 0], @@ -544,6 +636,9 @@ def get_baselines(self, times): w, valid, time_mjd, + q, + qc[..., 0].ravel(), + qc[..., 1].ravel(), ) baselines.add_baseline(base) return baselines From 489ba3bf4248815a640a081b39fe4816a2ff2769 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Thu, 9 Jan 2025 18:00:14 +0100 Subject: [PATCH 17/52] Add parallactic angle/feed rotation matrices to RIME --- pyvisgen/simulation/scan.py | 50 ++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/pyvisgen/simulation/scan.py b/pyvisgen/simulation/scan.py index 2b92e9e..d3077c6 100644 --- a/pyvisgen/simulation/scan.py +++ b/pyvisgen/simulation/scan.py @@ -6,7 +6,19 @@ @torch.compile -def rime(img, bas, lm, rd, ra, dec, ant_diam, spw_low, spw_high, corrupted=False): +def rime( + img, + bas, + lm, + rd, + ra, + dec, + ant_diam, + spw_low, + spw_high, + polarisation, + corrupted=False, +): """Calculates visibilities using RIME Parameters @@ -21,6 +33,8 @@ def rime(img, bas, lm, rd, ra, dec, ant_diam, spw_low, spw_high, corrupted=False lower wavelength spw_high : float higher wavelength + polarisation : str + Type of polarisation. Returns ------- @@ -29,8 +43,11 @@ def rime(img, bas, lm, rd, ra, dec, ant_diam, spw_low, spw_high, corrupted=False """ with torch.no_grad(): X1, X2 = calc_fourier(img, bas, lm, spw_low, spw_high) + print(X1.shape) if corrupted: X1, X2 = calc_beam(X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high) + + X1, X2 = calc_feed_rotation(X1, X2, bas.q1, bas.q2, polarisation) vis = integrate(X1, X2) return vis @@ -77,6 +94,37 @@ def calc_fourier(img, bas, lm, spw_low, spw_high): return img * K1, img * K2 +@torch.compile +def calc_feed_rotation(X1, X2, q1, q2, polarisation): + """ """ + P1 = torch.ones_like(X1) + P2 = torch.ones_like(X2) + + if polarisation == "linear": + P1[..., 0, 0] = torch.cos(q1) + P1[..., 0, 1] = torch.sin(q1) + P1[..., 1, 0] = -torch.sin(q1) + P1[..., 1, 1] = torch.cos(q1) + + P2[..., 0, 0] = torch.cos(q2) + P2[..., 0, 1] = torch.sin(q2) + P2[..., 1, 0] = -torch.sin(q2) + P2[..., 1, 1] = torch.cos(q2) + + if polarisation == "circular": + P1[..., 0, 0] = torch.exp(1j * q1) + P1[..., 0, 1] = 0 + P1[..., 1, 0] = 0 + P1[..., 1, 1] = torch.exp(-1j * q1) + + P2[..., 0, 0] = torch.exp(1j * q2) + P2[..., 0, 1] = 0 + P2[..., 1, 0] = 0 + P2[..., 1, 1] = torch.exp(-1j * q2) + + return img * P1, img * P2 + + @torch.compile def calc_beam(X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high): diameters = ant_diam.to(rd.device) From 011670f898028882cc074412c987bd178fe0c34f Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Thu, 9 Jan 2025 18:00:57 +0100 Subject: [PATCH 18/52] Pass current polarisation from vis_loop to RIME --- pyvisgen/simulation/visibility.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 259454d..4fdd40b 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -442,6 +442,7 @@ def vis_loop( torch.unique(obs.array.diam), wave_low, wave_high, + obs.polarisation, corrupted=obs.corrupted, )[None] for wave_low, wave_high in zip(obs.waves_low, obs.waves_high) From 2f79cd928b9da78a37b1881941a34c1936fd720d Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 10 Jan 2025 18:42:17 +0100 Subject: [PATCH 19/52] Restructure pyvisgen.simulation.observation.Observation --- pyvisgen/simulation/observation.py | 275 +++++++++++++++-------------- 1 file changed, 147 insertions(+), 128 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 1eb6b3a..a27cce8 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -274,6 +274,7 @@ def __init__( self.num_scans = num_scans self.int_time = integration_time self.scan_separation = scan_separation + self.times, self.times_mjd = self.calc_time_steps() self.scans = torch.stack( torch.split( @@ -286,6 +287,7 @@ def __init__( self.ref_frequency = torch.tensor(ref_frequency) self.bandwidths = torch.tensor(bandwidths) self.frequency_offsets = torch.tensor(frequency_offsets) + self.waves_low = ( self.ref_frequency + self.frequency_offsets ) - self.bandwidths / 2 @@ -331,6 +333,30 @@ def __init__( self.pol_kwargs = pol_kwargs self.field_kwargs = field_kwargs + def calc_time_steps(self): + """Computes the time steps of the observation. + + Returns + ------- + time : array_like + Array of time steps. + time.mjd : array_like + Time steps in mjd format. + """ + time_lst = [ + self.start + + self.scan_separation * i * un.second + + i * self.scan_duration * un.second + + j * self.int_time * un.second + for i in range(self.num_scans) + for j in range(int(self.scan_duration / self.int_time) + 1) + ] + # +1 because t_1 is the stop time of t_0 + # in order to save computing power we take one time more to complete interval + time = Time(time_lst) + + return time, time.mjd * (60 * 60 * 24) + def calc_dense_baselines(self): N = self.img_size fov = self.fov * pi / (3600 * 180) @@ -394,19 +420,96 @@ def calc_baselines(self): bas = self.get_baselines(self.times[scan]) self.baselines.add_baseline(bas) - def calc_time_steps(self): - time_lst = [ - self.start - + self.scan_separation * i * un.second - + i * self.scan_duration * un.second - + j * self.int_time * un.second - for i in range(self.num_scans) - for j in range(int(self.scan_duration / self.int_time) + 1) - ] - # +1 because t_1 is the stop time of t_0 - # in order to save computing power we take one time more to complete interval - time = Time(time_lst) - return time, time.mjd * (60 * 60 * 24) + def get_baselines(self, times): + """Calculates baselines from source coordinates and time of observation for + every antenna station in array_layout. + + Parameters + ---------- + times : time object + time of observation + + Returns + ------- + dataclass object + baselines between telescopes with visibility flags + """ + # calculate GHA, local HA, and station elevation for all times. + GHA, ha_local, el_st_all = self.calc_ref_elev(time=times) + + self.el_st_all = el_st_all + + ar = Array(self.array) + delta_x, delta_y, delta_z = ar.calc_relative_pos + st_num_pairs, els_low_pairs, els_high_pairs = ar.calc_ant_pair_vals + + # Loop over ha and el_st + baselines = Baselines( + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + ) + q_all = self.calc_feed_rotation(ha_local) + q_comb = torch.vstack([torch.combinations(qi) for qi in q_all]) + q_comb = q_comb.reshape(-1, int(q_comb.shape[0] / times.shape[0]), 2) + + self.q_comb_l.append(q_comb) + self.q_all_l.append(q_all) + + print( + GHA.shape, + el_st_all.shape, + times.shape, + q_all.shape, + q_comb.shape, + ) + + self.GHA = GHA + self.delx = delta_x + self.dely = delta_y + self.delz = delta_z + + for ha, el_st, time, q, qc in zip(GHA, el_st_all, times, q_all, q_comb): + u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z) + + # calc current elevations + cur_el_st = torch.combinations(el_st) + + # calc valid baselines + m1 = (cur_el_st < els_low_pairs).any(axis=1) + m2 = (cur_el_st > els_high_pairs).any(axis=1) + + valid = torch.ones(u.shape).bool() + valid_mask = torch.logical_or(m1, m2) + valid[valid_mask] = False + + time_mjd = torch.repeat_interleave( + torch.tensor(time.mjd) * (24 * 60 * 60), len(valid) + ) + + # collect baselines + base = Baselines( + st_num_pairs[:, 0], + st_num_pairs[:, 1], + u, + v, + w, + valid, + time_mjd, + q, + qc[..., 0].ravel(), + qc[..., 1].ravel(), + ) + baselines.add_baseline(base) + + return baselines def calc_ref_elev(self, time=None): if time is None: @@ -446,6 +549,13 @@ def calc_ref_elev(self, time=None): ) ) assert len(GHA.value) == len(el_st_all) + + if not len(GHA.value) == len(el_st_all): + raise ValueError( + "Expected GHA and el_st_all to have the same length" + f"{len(GHA.value)} and {len(el_st_all)}" + ) + return torch.tensor(GHA.deg), ha_local, torch.tensor(el_st_all.alt.degree) def calc_feed_rotation(self, ha: Angle) -> Angle: @@ -478,14 +588,29 @@ def calc_feed_rotation(self, ha: Angle) -> Angle: return q - def test_active_telescopes(self): - _, el_st_0 = self.calc_ref_elev(self.times[0]) - _, el_st_1 = self.calc_ref_elev(self.times[1]) - el_min = 15 - el_max = 85 - active_telescopes_0 = np.sum((el_st_0 >= el_min) & (el_st_0 <= el_max)) - active_telescopes_1 = np.sum((el_st_1 >= el_min) & (el_st_1 <= el_max)) - return min(active_telescopes_0, active_telescopes_1) + def calc_direction_cosines(self, ha, el_st, delta_x, delta_y, delta_z): + src_dec = torch.deg2rad(self.dec) + ha = torch.deg2rad(ha) + + u = (torch.sin(ha) * delta_x + torch.cos(ha) * delta_y).reshape(-1) + v = ( + -torch.sin(src_dec) * torch.cos(ha) * delta_x + + torch.sin(src_dec) * torch.sin(ha) * delta_y + + torch.cos(src_dec) * delta_z + ).reshape(-1) + w = ( + torch.cos(src_dec) * torch.cos(ha) * delta_x + - torch.cos(src_dec) * torch.sin(ha) * delta_y + + torch.sin(src_dec) * delta_z + ).reshape(-1) + + if not (u.shape == v.shape == w.shape): + raise ValueError( + "u, v, w array shapes are not the same: " + f"{u.shape}, {v.shape}, {w.shape}" + ) + + return u, v, w def create_rd_grid(self): """Calculates RA and Dec values for a given fov around a source position @@ -542,7 +667,7 @@ def create_lm_grid(self): Returns ------- lm_grid : 3d array - Returns a 3d array with every pixel containing a l and m value + Returns a 3d array with every pixel containing an l and m value """ dec = torch.deg2rad(self.dec) @@ -553,109 +678,3 @@ def create_lm_grid(self): ) * torch.sin(dec) * torch.cos(self.rd[..., 0]) return lm_grid - - def get_baselines(self, times): - """Calculates baselines from source coordinates and time of observation for - every antenna station in array_layout. - - Parameters - ---------- - times : time object - time of observation - - Returns - ------- - dataclass object - baselines between telescopes with visibility flags - """ - # calculate GHA, local HA, and station elevation for all times. - GHA, ha_local, el_st_all = self.calc_ref_elev(time=times) - - self.el_st_all = el_st_all - - ar = Array(self.array) - delta_x, delta_y, delta_z = ar.calc_relative_pos - st_num_pairs, els_low_pairs, els_high_pairs = ar.calc_ant_pair_vals - - # Loop over ha and el_st - baselines = Baselines( - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - ) - q_all = self.calc_feed_rotation(ha_local) - q_comb = torch.vstack([torch.combinations(qi) for qi in q_all]) - q_comb = q_comb.reshape(-1, int(q_comb.shape[0] / times.shape[0]), 2) - - self.q_comb_l.append(q_comb) - self.q_all_l.append(q_all) - - print( - GHA.shape, - el_st_all.shape, - times.shape, - q_all.shape, - q_comb.shape, - ) - - self.GHA = GHA - self.delx = delta_x - self.dely = delta_y - self.delz = delta_z - - for ha, el_st, time, q, qc in zip(GHA, el_st_all, times, q_all, q_comb): - u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z) - - # calc current elevations - cur_el_st = torch.combinations(el_st) - - # calc valid baselines - valid = torch.ones(u.shape).bool() - m1 = (cur_el_st < els_low_pairs).any(axis=1) - m2 = (cur_el_st > els_high_pairs).any(axis=1) - valid_mask = torch.logical_or(m1, m2) - valid[valid_mask] = False - - time_mjd = torch.repeat_interleave( - torch.tensor(time.mjd) * (24 * 60 * 60), len(valid) - ) - - # collect baselines - base = Baselines( - st_num_pairs[:, 0], - st_num_pairs[:, 1], - u, - v, - w, - valid, - time_mjd, - q, - qc[..., 0].ravel(), - qc[..., 1].ravel(), - ) - baselines.add_baseline(base) - return baselines - - def calc_direction_cosines(self, ha, el_st, delta_x, delta_y, delta_z): - src_dec = torch.deg2rad(self.dec) - ha = torch.deg2rad(ha) - u = (torch.sin(ha) * delta_x + torch.cos(ha) * delta_y).reshape(-1) - v = ( - -torch.sin(src_dec) * torch.cos(ha) * delta_x - + torch.sin(src_dec) * torch.sin(ha) * delta_y - + torch.cos(src_dec) * delta_z - ).reshape(-1) - w = ( - torch.cos(src_dec) * torch.cos(ha) * delta_x - - torch.cos(src_dec) * torch.sin(ha) * delta_y - + torch.sin(src_dec) * delta_z - ).reshape(-1) - assert u.shape == v.shape == w.shape - return u, v, w From c2de45927907cd236e0f80e4446712c60294afb5 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 14 Jan 2025 15:44:04 +0100 Subject: [PATCH 20/52] Fix valid q1, q2 computation in Baselines class --- pyvisgen/simulation/observation.py | 46 ++++++++++++++++-------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index a27cce8..9ad5d00 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -23,7 +23,6 @@ class Baselines: w: torch.tensor valid: torch.tensor time: torch.tensor - q_all: torch.tensor q1: torch.tensor q2: torch.tensor @@ -64,8 +63,14 @@ def get_valid_subset(self, num_baselines, device): v_valid = (v_start + v_stop) / 2 w_valid = (w_start + w_stop) / 2 - q1_valid = bas_reshaped.q1[mask].to(device) - q2_valid = bas_reshaped.q2[mask].to(device) + q1_start = bas_reshaped.q1[:-1][mask].to(device) + q2_start = bas_reshaped.q2[:-1][mask].to(device) + + q1_stop = bas_reshaped.q1[1:][mask].to(device) + q2_stop = bas_reshaped.q2[1:][mask].to(device) + + q1_valid = (q1_start + q1_stop) / 2 + q2_valid = (q2_start + q2_stop) / 2 t = Time(bas_reshaped.time / (60 * 60 * 24), format="mjd").jd date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device) @@ -82,7 +87,11 @@ def get_valid_subset(self, num_baselines, device): w_valid, baseline_nums, date, + q1_start, + q1_stop, q1_valid, + q2_start, + q2_stop, q2_valid, ) @@ -100,7 +109,11 @@ class ValidBaselineSubset: w_valid: torch.tensor baseline_nums: torch.tensor date: torch.tensor + q1_start: torch.tensor + q1_stop: torch.tensor q1_valid: torch.tensor + q2_start: torch.tensor + q2_stop: torch.tensor q2_valid: torch.tensor def __getitem__(self, i): @@ -117,7 +130,11 @@ def __getitem__(self, i): self.w_valid, self.baseline_nums, self.date, + self.q1_start, + self.q1_stop, self.q1_valid, + self.q2_start, + self.q2_stop, self.q2_valid, ] ) @@ -129,6 +146,8 @@ def get_timerange(self, t_start, t_stop): def get_unique_grid(self, fov_size, ref_frequency, img_size, device): uv = torch.cat([self.u_valid[None], self.v_valid[None]], dim=0) + q = torch.cat([self.q1_valid[None], self.q2_valid[None]], dim=0) + fov = fov_size * pi / (3600 * 180) delta = 1 / fov * const.c.value.item() / ref_frequency bins = ( @@ -140,8 +159,10 @@ def get_unique_grid(self, fov_size, ref_frequency, img_size, device): ) + delta / 2 ) + if len(bins) - 1 > img_size: bins = bins[:-1] + indices_bucket = torch.bucketize(uv, bins) indices_bucket_sort, indices_bucket_inv = self._lexsort(indices_bucket) indices_unique, indices_unique_inv, counts = torch.unique_consecutive( @@ -409,7 +430,6 @@ def calc_baselines(self): torch.tensor([]), torch.tensor([]), torch.tensor([]), - torch.tensor([]), ) self.q_comb_l = [] self.q_all_l = [] @@ -454,28 +474,11 @@ def get_baselines(self, times): torch.tensor([]), torch.tensor([]), torch.tensor([]), - torch.tensor([]), ) q_all = self.calc_feed_rotation(ha_local) q_comb = torch.vstack([torch.combinations(qi) for qi in q_all]) q_comb = q_comb.reshape(-1, int(q_comb.shape[0] / times.shape[0]), 2) - self.q_comb_l.append(q_comb) - self.q_all_l.append(q_all) - - print( - GHA.shape, - el_st_all.shape, - times.shape, - q_all.shape, - q_comb.shape, - ) - - self.GHA = GHA - self.delx = delta_x - self.dely = delta_y - self.delz = delta_z - for ha, el_st, time, q, qc in zip(GHA, el_st_all, times, q_all, q_comb): u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z) @@ -503,7 +506,6 @@ def get_baselines(self, times): w, valid, time_mjd, - q, qc[..., 0].ravel(), qc[..., 1].ravel(), ) From 401c6a006c7881ada2877dcc5908fc521ed20e3d Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 14 Jan 2025 15:51:55 +0100 Subject: [PATCH 21/52] Remove debug output, set defaults --- pyvisgen/simulation/observation.py | 43 +++++++++++++----------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 9ad5d00..2a423fe 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -14,6 +14,20 @@ from pyvisgen.simulation.array import Array +DEFAULT_POL_KWARGS = { + "delta": 0, + "amp_ratio": 0.5, + "random_state": 42, +} + +DEFAULT_FIELD_KWARGS = { + "order": [1, 1], + "scale": [0, 1], + "threshold": None, + "random_state": 42, +} + + @dataclass class Baselines: st1: torch.tensor @@ -146,7 +160,6 @@ def get_timerange(self, t_start, t_stop): def get_unique_grid(self, fov_size, ref_frequency, img_size, device): uv = torch.cat([self.u_valid[None], self.v_valid[None]], dim=0) - q = torch.cat([self.q1_valid[None], self.q2_valid[None]], dim=0) fov = fov_size * pi / (3600 * 180) delta = 1 / fov * const.c.value.item() / ref_frequency @@ -207,17 +220,8 @@ def __init__( dense: bool = False, sensitivity_cut: float = 1e-6, polarisation: str = None, - pol_kwargs: dict = { - "delta": 0, - "amp_ratio": 0.5, - "random_state": 42, - }, - field_kwargs: dict = { - "order": [1, 1], - "scale": [0, 1], - "threshold": None, - "random_state": 42, - }, + pol_kwargs: dict = DEFAULT_POL_KWARGS, + field_kwargs: dict = DEFAULT_FIELD_KWARGS, ) -> None: """Sets up the observation class. @@ -431,10 +435,6 @@ def calc_baselines(self): torch.tensor([]), torch.tensor([]), ) - self.q_comb_l = [] - self.q_all_l = [] - - self.n = 1 for scan in self.scans: bas = self.get_baselines(self.times[scan]) @@ -457,8 +457,6 @@ def get_baselines(self, times): # calculate GHA, local HA, and station elevation for all times. GHA, ha_local, el_st_all = self.calc_ref_elev(time=times) - self.el_st_all = el_st_all - ar = Array(self.array) delta_x, delta_y, delta_z = ar.calc_relative_pos st_num_pairs, els_low_pairs, els_high_pairs = ar.calc_ant_pair_vals @@ -499,8 +497,8 @@ def get_baselines(self, times): # collect baselines base = Baselines( - st_num_pairs[:, 0], - st_num_pairs[:, 1], + st_num_pairs[..., 0], + st_num_pairs[..., 1], u, v, w, @@ -525,7 +523,6 @@ def calc_ref_elev(self, time=None): GHA = Angle( [t.sidereal_time("apparent", "greenwich") - src_crd.ra for t in time] ) - self.ha_all = GHA # calculate local sidereal time and HA at each antenna lst = un.Quantity( @@ -550,8 +547,6 @@ def calc_ref_elev(self, time=None): ), ) ) - assert len(GHA.value) == len(el_st_all) - if not len(GHA.value) == len(el_st_all): raise ValueError( "Expected GHA and el_st_all to have the same length" @@ -608,7 +603,7 @@ def calc_direction_cosines(self, ha, el_st, delta_x, delta_y, delta_z): if not (u.shape == v.shape == w.shape): raise ValueError( - "u, v, w array shapes are not the same: " + "Expected u, v, and w to have the same shapes: " f"{u.shape}, {v.shape}, {w.shape}" ) From 7eb878911582d8aed10a492c8477d9d057fe515b Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 17 Jan 2025 10:09:22 +0100 Subject: [PATCH 22/52] Add optional progress bar --- pyvisgen/simulation/observation.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 2a423fe..e49d45a 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -4,11 +4,11 @@ import astropy.constants as const import astropy.units as un -import numpy as np import torch from astropy.constants import c from astropy.coordinates import AltAz, Angle, EarthLocation, SkyCoord, Longitude from astropy.time import Time +from tqdm import tqdm from pyvisgen.layouts import layouts from pyvisgen.simulation.array import Array @@ -189,6 +189,7 @@ def get_unique_grid(self, fov_size, ref_frequency, img_size, device): cum_sum = counts.cumsum(0) cum_sum = torch.cat((torch.tensor([0], device=device), cum_sum[:-1])) first_indices = ind_sorted[cum_sum] + return self[:][:, indices_bucket_sort[first_indices]] def _lexsort(self, a, dim=-1): @@ -222,6 +223,7 @@ def __init__( polarisation: str = None, pol_kwargs: dict = DEFAULT_POL_KWARGS, field_kwargs: dict = DEFAULT_FIELD_KWARGS, + show_progress: bool = False, ) -> None: """Sets up the observation class. @@ -283,6 +285,9 @@ def __init__( "threshold": None, "random_state": 42 }` + show_progress : bool, optional + If `True`, show a progress bar during the iteration over the + scans. Default: False Notes ----- @@ -337,6 +342,8 @@ def __init__( len(self.array.st_num) * (len(self.array.st_num) - 1) / 2 ) + self.show_progress = show_progress + if dense: self.waves_low = [self.ref_frequency] self.waves_high = [self.ref_frequency] @@ -424,6 +431,10 @@ def calc_dense_baselines(self): ) def calc_baselines(self): + """Initializes Baselines dataclass object and + calls self.get_baselines to compute the contents of + the Baselines dataclass. + """ self.baselines = Baselines( torch.tensor([]), torch.tensor([]), @@ -436,6 +447,9 @@ def calc_baselines(self): torch.tensor([]), ) + if self.show_progress: + self.scans = tqdm(self.scans) + for scan in self.scans: bas = self.get_baselines(self.times[scan]) self.baselines.add_baseline(bas) @@ -553,7 +567,11 @@ def calc_ref_elev(self, time=None): f"{len(GHA.value)} and {len(el_st_all)}" ) - return torch.tensor(GHA.deg), ha_local, torch.tensor(el_st_all.alt.degree) + return ( + torch.tensor(GHA.deg), + torch.tensor(ha_local), + torch.tensor(el_st_all.alt.degree), + ) def calc_feed_rotation(self, ha: Angle) -> Angle: r"""Calculates feed rotation for every antenna at every time step. From 2c17d85c10cff4c386fd4cb15d1d6ec3fbbd31ac Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 17 Jan 2025 10:09:54 +0100 Subject: [PATCH 23/52] Fix feed rotation computation --- pyvisgen/simulation/scan.py | 46 ++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/pyvisgen/simulation/scan.py b/pyvisgen/simulation/scan.py index d3077c6..2c7dff4 100644 --- a/pyvisgen/simulation/scan.py +++ b/pyvisgen/simulation/scan.py @@ -43,11 +43,11 @@ def rime( """ with torch.no_grad(): X1, X2 = calc_fourier(img, bas, lm, spw_low, spw_high) - print(X1.shape) + if corrupted: X1, X2 = calc_beam(X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high) - X1, X2 = calc_feed_rotation(X1, X2, bas.q1, bas.q2, polarisation) + X1, X2 = calc_feed_rotation(X1, X2, bas, polarisation) vis = integrate(X1, X2) return vis @@ -79,7 +79,7 @@ def calc_fourier(img, bas, lm, spw_low, spw_high): v_cmplt = torch.cat((bas[3], bas[4])) w_cmplt = torch.cat((bas[6], bas[7])) - l = lm[..., 0] + l = lm[..., 0] # noqa: E741 m = lm[..., 1] n = torch.sqrt(1 - l**2 - m**2) @@ -95,34 +95,34 @@ def calc_fourier(img, bas, lm, spw_low, spw_high): @torch.compile -def calc_feed_rotation(X1, X2, q1, q2, polarisation): +def calc_feed_rotation(X1, X2, bas, polarisation): """ """ - P1 = torch.ones_like(X1) - P2 = torch.ones_like(X2) + q1 = torch.cat((bas[11], bas[12]))[..., None] + q2 = torch.cat((bas[14], bas[15]))[..., None] if polarisation == "linear": - P1[..., 0, 0] = torch.cos(q1) - P1[..., 0, 1] = torch.sin(q1) - P1[..., 1, 0] = -torch.sin(q1) - P1[..., 1, 1] = torch.cos(q1) + X1[..., 0, 0] *= torch.cos(q1) + X1[..., 0, 1] *= torch.sin(q1) + X1[..., 1, 0] *= -torch.sin(q1) + X1[..., 1, 1] *= torch.cos(q1) - P2[..., 0, 0] = torch.cos(q2) - P2[..., 0, 1] = torch.sin(q2) - P2[..., 1, 0] = -torch.sin(q2) - P2[..., 1, 1] = torch.cos(q2) + X2[..., 0, 0] *= torch.cos(q2) + X2[..., 0, 1] *= torch.sin(q2) + X2[..., 1, 0] *= -torch.sin(q2) + X2[..., 1, 1] *= torch.cos(q2) if polarisation == "circular": - P1[..., 0, 0] = torch.exp(1j * q1) - P1[..., 0, 1] = 0 - P1[..., 1, 0] = 0 - P1[..., 1, 1] = torch.exp(-1j * q1) + X1[..., 0, 0] *= torch.exp(1j * q1) + X1[..., 0, 1] *= 0 + X1[..., 1, 0] *= 0 + X1[..., 1, 1] *= torch.exp(-1j * q1) - P2[..., 0, 0] = torch.exp(1j * q2) - P2[..., 0, 1] = 0 - P2[..., 1, 0] = 0 - P2[..., 1, 1] = torch.exp(-1j * q2) + X2[..., 0, 0] *= torch.exp(1j * q2) + X2[..., 0, 1] *= 0 + X2[..., 1, 0] *= 0 + X2[..., 1, 1] *= torch.exp(-1j * q2) - return img * P1, img * P2 + return X1, X2 @torch.compile From ee83f36972064184feb44b53f5a5cc3fa91e93c9 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 21 Jan 2025 14:04:36 +0100 Subject: [PATCH 24/52] Add first implementation of polarisation simulation in the vis loop --- pyvisgen/simulation/visibility.py | 388 +++++++++++++++++++++++++++--- 1 file changed, 357 insertions(+), 31 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 5da0452..cfcfb56 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -1,8 +1,9 @@ from dataclasses import dataclass, fields +from tqdm.autonotebook import tqdm import torch import toma -from tqdm.autonotebook import tqdm +import scipy.ndimage import pyvisgen.simulation.scan as scan @@ -19,6 +20,8 @@ class Visibilities: v: torch.tensor w: torch.tensor date: torch.tensor + linear_dop: torch.tensor + circular_dop: torch.tensor def __getitem__(self, i): return Visibilities(*[getattr(self, f.name)[i] for f in fields(self)]) @@ -39,43 +42,359 @@ def add(self, visibilities): ] +class Polarisation: + """Simulation of polarisation.""" + + def __init__( + self, + SI: torch.tensor, + sensitivity_cut: float, + amp_ratio: float, + delta: float, + polarisation: str, + random_state: int, + device: str, + ) -> None: + """Creates the 2 x 2 stokes matrix and simulates + polarisation if `polarisation` is either 'linear' + or 'circular'. Also computes the degree of polarisation. + + Parameters + ---------- + SI : torch.tensor + Stokes I component, i.e. intensity distribution + of the sky. + sensitivity_cut : float + Sensitivity cut, where only pixels above the value + are kept. + amp_ratio : float + Sets the ratio of $A_{x/r}$. The ratio of $A_{y/l}$ is calculated + as `1 - amp_ratio`. If set to `None`, a random value is drawn + from a uniform distribution. See also: `random_state`. + delta : float + Sets the phase difference of the amplitudes $A_x$ and $A_y$ + of the sky distribution. Defines the measure of ellipticity. + polarisation : str + Choose between `'linear'` or `'circular'` or `None` to + simulate different types of polarisations or disable + the simulation of polarisation. + random_state : int + Random state used when drawing `amp_ratio` and during the generation + of the random polarisation field. + device : str + Torch device to select for computation. + """ + self.sensitivity_cut = sensitivity_cut + self.polarisation = polarisation + self.device = device + + self.SI = SI.permute(dims=(1, 2, 0)) + + if random_state: + torch.manual_seed(random_state) + else: + torch.seed() + + if self.polarisation: + self.polarisation_field = self.rand_polarisation_field( + [SI.shape[0], SI.shape[1]], + random_state=random_state, + order=[1, 1], + scale=[0, 1], + threshold=None, + ) + + self.delta = delta + + if amp_ratio and amp_ratio >= 0: + ax2 = amp_ratio + else: + ax2 = torch.rand(1) + + ay2 = 1 - ax2 + + self.ax2 = self.SI[..., 0] * ax2 + self.ay2 = self.SI[..., 0] * ay2 + + self.I = torch.zeros( + (self.SI.shape[0], self.SI.shape[1], 4), dtype=torch.cdouble + ) # noqa: E741 + + def linear(self) -> None: + """Computes the stokes parameters I, Q, U, and V + for linear polarisation. + + .. math:: + I = A_x^2 + A_y^2 + Q = A_r^2 - A_l^2 + U = 2A_x A_y \cos\delta_{xy} + V = -2A_x A_y \sin\delta_{xy} + """ + self.I[..., 0] = self.ax2 + self.ay2 + self.I[..., 1] = self.ax2 - self.ay2 + self.I[..., 2] = ( + 2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.cos(torch.deg2rad(torch.tensor(self.delta))) + ) + self.I[..., 3] = ( + -2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.sin(torch.deg2rad(torch.tensor(self.delta))) + ) + + def circular(self) -> None: + """Computes the stokes parameters I, Q, U, and V + for circular polarisation. + + .. math:: + I = A_r^2 + A_l^2 + Q = 2A_r A_l \cos\delta_{rl} + U = -2A_r A_l \sin\delta_{rl} + V = A_r^2 - A_l^2 + """ + self.I[..., 0] = self.ax2 + self.ay2 + self.I[..., 1] = ( + 2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.cos(torch.deg2rad(torch.tensor(self.delta))) + ) + self.I[..., 2] = ( + -2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.sin(torch.deg2rad(torch.tensor(self.delta))) + ) + self.I[..., 3] = self.ax2 - self.ay2 + + def dop(self) -> None: + """Computes the degree of polarisation for each pixel.""" + mask = (self.ax2 + self.ay2) > 0 + + # apply polarisation_field to Q, U, and V only + self.I[..., 1] *= self.polarisation_field + self.I[..., 2] *= self.polarisation_field + self.I[..., 3] *= self.polarisation_field + + dop_I = self.I[..., 0].clone() + dop_I[~mask] = float("nan") + dop_Q = self.I[..., 1].clone() + dop_Q[~mask] = float("nan") + dop_U = self.I[..., 2].clone() + dop_U[~mask] = float("nan") + dop_V = self.I[..., 3].clone() + dop_V[~mask] = float("nan") + + self.lin_dop = torch.sqrt(dop_Q**2 + dop_U**2) / dop_I + self.circ_dop = torch.abs(dop_V) / dop_I + + del dop_I, dop_Q, dop_U, dop_V + + def stokes_matrix(self) -> tuple: + """Computes and returns the 2 x 2 stokes matrix B. + + Returns + ------- + B : torch.tensor + 2 x 2 stokes brightness matrix. Either for linear, + circular or no polarisation. + mask : torch.tensor + Mask of the sensitivity cut (Keep all px > sensitivity_cut). + lin_dop : torch.tensor + Degree of linear polarisation of every pixel in the sky. + circ_dop : torch.tensor + Degree of circular polarisation of every pixel in the sky. + """ + # define 2 x 2 Stokes matrix + B = torch.zeros( + (self.SI.shape[0], self.SI.shape[1], 2, 2), dtype=torch.cdouble + ).to(torch.device(self.device)) + + if self.polarisation == "linear": + self.linear() + self.dop() + + B[..., 0, 0] = self.I[..., 0] + self.I[..., 3] + B[..., 0, 1] = self.I[..., 1] + 1j * self.I[..., 2] + B[..., 1, 0] = self.I[..., 1] - 1j * self.I[..., 2] + B[..., 1, 1] = self.I[..., 0] - self.I[..., 3] + + elif self.polarisation == "circular": + self.circular() + self.dop() + + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + + else: + # No polarisation applied + self.I[..., 0] = self.SI[..., 0] + self.I[..., 1] = self.SI[..., 0] + self.I[..., 2] = self.SI[..., 0] + self.I[..., 3] = self.SI[..., 0] + + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + + # calculations only for px > sensitivity cut + mask = (self.SI >= self.sensitivity_cut)[..., 0] + B = B[mask] + + return B, mask, self.lin_dop, self.circ_dop + + def rand_polarisation_field( + self, + shape: list[int, int] | int, + order: list[int, int] | int = 1, + random_state: int = None, + scale: list = [0, 1], + threshold: float = None, + ) -> torch.tensor: + """ + Generates a random noise mask for polarisation. + + Parameters + ---------- + shape : array_like (M, N), or int + The size of the sky image. + order : array_like (M, N) or int, optional + Morphology of the random noise. Higher values create + more and smaller fluctuations. Default: 1. + random_state : int, optional + Random state for the random number generator. If None, + a random entropy is pulled from the OS. Default: None. + scale : array_like, optional + Scaling of the distribution of the image. Default: [0, 1] + threshold : float, optional + If not None, an upper threshold is applied to the image. + Default: None + + Returns + ------- + im : torch.tensor + An array containing random noise values between + scale[0] and scale[1]. + """ + if random_state: + torch.random.manual_seed(random_state) + + if not isinstance(shape, list): + shape = list(shape) + + if len(shape) < 2: + shape *= 2 + elif len(shape) > 2: + raise ValueError("Only 2d shapes are allowed!") + + if not isinstance(order, list): + order = list(order) + + if len(order) < 2: + order *= 2 + elif len(order) > 2: + raise ValueError("Only 2d shapes are allowed!") + + sigma = torch.mean(torch.tensor(shape).double()) / (40 * torch.tensor(order)) + + im = torch.rand(shape) + im = scipy.ndimage.gaussian_filter(im, sigma=sigma.numpy()) + + if scale is None: + scale = [im.min(), im.max()] + + im_flatten = torch.from_numpy(im.flatten()) + im_argsort = torch.argsort(torch.argsort(im_flatten)) + im_linspace = torch.linspace(*scale, im_argsort.size()[0]) + uniform_flatten = im_linspace[im_argsort] + + im = torch.reshape(uniform_flatten, im.shape) + + if threshold: + im = im < threshold + + return im + + def vis_loop( obs, - SI, - num_threads=10, - noisy=True, - mode="full", - batch_size="auto", - show_progress=False, -): + SI: torch.tensor, + num_threads: int = 10, + noisy: bool = True, + mode: str = "full", + batch_size: int = 100, + polarisation: str = "linear", + delta: float = 0, + amp_ratio: float = None, + random_state: int = 42, + show_progress: bool = False, +) -> Visibilities: + r"""Computes the visibilities of an observation. + + Parameters + ---------- + obs : Observation class object + Observation class object generated by the + `~pyvisgen.simulation.Observation` class. + SI : torch.tensor + Tensor containing the sky intensity distribution. + num_threads : int, optional + Number of threads used for intraoperative parallelism + on the CPU. See `~torch.set_num_threads`. Default: 10 + noisy : bool, optional + If `True`, generate and add additional noise to + the simulated measurements. Default: True + mode : str, optional + Select one of `'full'`, `'grid'`, or `'dense'` to get + all valid baselines, a grid of unique baselines, or + dense baselines. Default: 'full' + batch_size : int, optional + Batch size for iteration over baselines. Default: 100 + polarisation : str, optional + Choose between `'linear'` or `'circular'` or `None` to + simulate different types of polarisations or disable + the simulation of polarisation. Default: 'linear' + delta : float, optional + Sets the phase difference of the amplitudes $A_x$ and $A_y$ + of the sky distribution. Defines the measure of ellipticity. + Default: 0 + amp_ratio : float, optional + Sets the ratio of $A_{x/r}$. The ratio of $A_{y/l}$ is calculated + as `1 - amp_ratio`. If set to `None`, a random value is drawn + from a uniform distribution. See also: `random_state`. Default: None + random_state : int, optional + Random state used when drawing `amp_ratio` and during the generation + of the random polarisation field. Default: 42 + show_progress : bool, optional + If `True`, show a progress bar during the iteration over the + batches of baselines. Default: False + + Returns + ------- + visibilities : Visibilities + Dataclass object containing visibilities and baselines. + """ torch.set_num_threads(num_threads) torch._dynamo.config.suppress_errors = True - if not ( - isinstance(batch_size, int) - or (isinstance(batch_size, str) and batch_size == "auto") - ): - raise ValueError("Expected batch_size to be 'auto' or of type int") - - SI = torch.flip(SI, dims=[1]) + pol = Polarisation( + SI, + sensitivity_cut=obs.sensitivity_cut, + amp_ratio=amp_ratio, + delta=delta, + polarisation=polarisation, + random_state=random_state, + device=obs.device, + ) - # define unpolarized sky distribution - SI = SI.permute(dims=(1, 2, 0)) - I = torch.zeros((SI.shape[0], SI.shape[1], 4), dtype=torch.cdouble) - I[..., 0] = SI[..., 0] + B, mask, lin_dop, circ_dop = pol.stokes_matrix() - # define 2 x 2 Stokes matrix ((I + Q, iU + V), (iU -V, I - Q)) - B = torch.zeros((SI.shape[0], SI.shape[1], 2, 2), dtype=torch.cdouble).to( - torch.device(obs.device) - ) - B[:, :, 0, 0] = I[:, :, 0] + I[:, :, 1] - B[:, :, 0, 1] = I[:, :, 2] + 1j * I[:, :, 3] - B[:, :, 1, 0] = I[:, :, 2] - 1j * I[:, :, 3] - B[:, :, 1, 1] = I[:, :, 0] - I[:, :, 1] - - # calculations only for px > sensitivity cut - mask = (SI >= obs.sensitivity_cut)[..., 0] - B = B[mask] lm = obs.lm[mask] rd = obs.rd[mask] @@ -95,6 +414,8 @@ def vis_loop( torch.tensor([]), torch.tensor([]), torch.tensor([]), + torch.tensor([]), + torch.tensor([]), ) vis_num = torch.zeros(1) if mode == "full": @@ -128,6 +449,9 @@ def vis_loop( show_progress, ) + visibilities.linear_dop = lin_dop.cpu() + visibilities.circular_dop = circ_dop.cpu() + return visibilities @@ -226,6 +550,8 @@ def _batch_loop( bas_p[5].cpu(), bas_p[8].cpu(), bas_p[10].cpu(), + torch.tensor([]), + torch.tensor([]), ) visibilities.add(vis) From 8619a4532ec18ac5dc08c83a38a73926b53fced0 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Wed, 14 Aug 2024 11:09:19 +0200 Subject: [PATCH 25/52] Move arguments for polarisation to Observation class --- pyvisgen/simulation/observation.py | 122 ++++++++++++++++++++++++----- pyvisgen/simulation/visibility.py | 32 ++------ 2 files changed, 110 insertions(+), 44 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 9fb8c9e..ad0e3cc 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, fields +from datetime import datetime from math import pi import astropy.constants as const @@ -155,24 +156,104 @@ def _lexsort(self, a, dim=-1): class Observation: def __init__( self, - src_ra, - src_dec, - start_time, - scan_duration, - num_scans, - scan_separation, - integration_time, - ref_frequency, - frequency_offsets, - bandwidths, - fov, - image_size, - array_layout, - corrupted, - device, - dense=False, - sensitivity_cut=1e-6, - ): + src_ra: float, + src_dec: float, + start_time: datetime, + scan_duration: int, + num_scans: int, + scan_separation: int, + integration_time: int, + ref_frequency: float, + frequency_offsets: list, + bandwidths: list, + fov: float, + image_size: int, + array_layout: str, + corrupted: bool, + device: str, + dense: bool = False, + sensitivity_cut: float = 1e-6, + polarisation: str = None, + pol_kwargs: dict = { + "delta": 0, + "amp_ratio": 0.5, + "random_state": 42, + }, + field_kwargs: dict = { + "order": [1, 1], + "scale": [0, 1], + "threshold": None, + "random_state": 42, + }, + ) -> None: + """Sets up the observation class. + + Parameters + ---------- + src_ra : float + Source right ascension coordinate. + src_dec : float + Source declination coordinate. + start_time : datetime + Observation start time. + scan_duration : int + Scan duration. + num_scans : int + Number of scans. + scan_separation : int + Scan separation. + integration_time : int + Integration time. + ref_frequency : float + Reference frequency. + frequency_offsets : list + Frequency offsets. + bandwidths : list + Frequency bandwidth. + fov : float + Field of view. + image_size : int + Image size of the sky distribution. + array_layout : str + Name of an existing array layout. See `~pyvisgen.layouts`. + corrupted : bool + If `True`, apply corruption during the vis loop. + device : str + Torch device to select for computation. + dense : bool, optional + If `True`, apply dense baseline calculation of a perfect + interferometer. Default: `False` + sensitivity_cut : float, optional + Sensitivity threshold, where only pixels above the value + are kept. Default: 1e-6 + polarisation : str, optional + Choose between `'linear'` or `'circular'` or `None` to + simulate different types of polarisations or disable + the simulation of polarisation. Default: `None` + pol_kwargs : dict, optional + Additional keyword arguments for the simulation + of polarisation. Default: `{ + "delta": 0, + "amp_ratio": 0.5, + "random_state": 42, + } + field_kwargs : dict, optional + Additional keyword arguments for the random polarisation + field that is applied when simulating polarisation. + Default: `{ + "order": [1, 1], + "scale": [0, 1], + "threshold": None, + "random_state": 42 + }` + + Notes + ----- + See `~pyvisgen.simulation.visibility.Polarisation` and + `~pyvisgen.simulation.visibility.Polarisation.rand_polarisation_field` + for more information on the keyword arguments in `pol_kwargs` + and `field_kwargs`, respectively. + """ self.ra = torch.tensor(src_ra).double() self.dec = torch.tensor(src_dec).double() @@ -229,6 +310,11 @@ def __init__( self.rd = self.create_rd_grid() self.lm = self.create_lm_grid() + # polarisation + self.polarisation = polarisation + self.pol_kwargs = pol_kwargs + self.field_kwargs = field_kwargs + def calc_dense_baselines(self): N = self.img_size fov = self.fov * pi / (3600 * 180) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index cfcfb56..de3ffb4 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -52,6 +52,7 @@ def __init__( amp_ratio: float, delta: float, polarisation: str, + field_kwargs: dict, random_state: int, device: str, ) -> None: @@ -92,21 +93,16 @@ def __init__( if random_state: torch.manual_seed(random_state) - else: - torch.seed() if self.polarisation: self.polarisation_field = self.rand_polarisation_field( [SI.shape[0], SI.shape[1]], - random_state=random_state, - order=[1, 1], - scale=[0, 1], - threshold=None, + **field_kwargs, ) self.delta = delta - if amp_ratio and amp_ratio >= 0: + if amp_ratio and (amp_ratio >= 0): ax2 = amp_ratio else: ax2 = torch.rand(1) @@ -234,9 +230,6 @@ def stokes_matrix(self) -> tuple: else: # No polarisation applied self.I[..., 0] = self.SI[..., 0] - self.I[..., 1] = self.SI[..., 0] - self.I[..., 2] = self.SI[..., 0] - self.I[..., 3] = self.SI[..., 0] B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] @@ -329,10 +322,6 @@ def vis_loop( noisy: bool = True, mode: str = "full", batch_size: int = 100, - polarisation: str = "linear", - delta: float = 0, - amp_ratio: float = None, - random_state: int = 42, show_progress: bool = False, ) -> Visibilities: r"""Computes the visibilities of an observation. @@ -360,14 +349,6 @@ def vis_loop( Choose between `'linear'` or `'circular'` or `None` to simulate different types of polarisations or disable the simulation of polarisation. Default: 'linear' - delta : float, optional - Sets the phase difference of the amplitudes $A_x$ and $A_y$ - of the sky distribution. Defines the measure of ellipticity. - Default: 0 - amp_ratio : float, optional - Sets the ratio of $A_{x/r}$. The ratio of $A_{y/l}$ is calculated - as `1 - amp_ratio`. If set to `None`, a random value is drawn - from a uniform distribution. See also: `random_state`. Default: None random_state : int, optional Random state used when drawing `amp_ratio` and during the generation of the random polarisation field. Default: 42 @@ -386,11 +367,10 @@ def vis_loop( pol = Polarisation( SI, sensitivity_cut=obs.sensitivity_cut, - amp_ratio=amp_ratio, - delta=delta, - polarisation=polarisation, - random_state=random_state, + polarisation=obs.polarisation, device=obs.device, + field_kwargs=obs.field_kwargs, + **obs.pol_kwargs, ) B, mask, lin_dop, circ_dop = pol.stokes_matrix() From 718471970d5e7f56c31a1fc55dfe41f67392b5cc Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Wed, 14 Aug 2024 13:39:14 +0200 Subject: [PATCH 26/52] Catch case where order arg is int in Polarisation class --- pyvisgen/simulation/visibility.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index de3ffb4..12a16a1 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -286,6 +286,9 @@ def rand_polarisation_field( elif len(shape) > 2: raise ValueError("Only 2d shapes are allowed!") + if isinstance(order, int): + order = [order] + if not isinstance(order, list): order = list(order) From 45b5f7c090bd22a1630a2a3b8b01434a8117b509 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Mon, 9 Sep 2024 14:35:27 +0200 Subject: [PATCH 27/52] Update FITS header writer - Add numeric codes for stokes parameters - Add header comments for stokes parameters - Add comment for visibilities (real, complex, weight) --- pyvisgen/fits/writer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pyvisgen/fits/writer.py b/pyvisgen/fits/writer.py index ee004d6..152e07c 100644 --- a/pyvisgen/fits/writer.py +++ b/pyvisgen/fits/writer.py @@ -42,9 +42,17 @@ def create_vis_hdu(data, obs, source_name="sim-source-0"): freq_d = obs.bandwidths[0].cpu().numpy().item() ws = wcs.WCS(naxis=7) + + crval_stokes = -1 + stokes_comment = "-1=RR, -2=LL, -3=RL, -4=LR" + if obs.polarisation == "linear": + crval_stokes = -5 + stokes_comment = "-5=XX, -6=YY, -7=XY, -8=YX" + stokes_comment += " or -5=VV, -6=HH, -7=VH, -8=HV" + ws.wcs.crpix = [1, 1, 1, 1, 1, 1, 1] ws.wcs.cdelt = np.array([1, 1, -1, freq_d, 1, 1, 1]) - ws.wcs.crval = [1, 1, -1, freq, 1, ra, dec] + ws.wcs.crval = [1, 1, crval_stokes, freq, 1, ra, dec] ws.wcs.ctype = ["", "COMPLEX", "STOKES", "FREQ", "IF", "RA", "DEC"] h = ws.to_header() @@ -79,6 +87,8 @@ def create_vis_hdu(data, obs, source_name="sim-source-0"): hdu_vis.header["PZERO" + str(i + 1)] = parbzeros[i] # add comments + hdu_vis.header.comments["CTYPE2"] = "1=real, 2=imag, 3=weight" + hdu_vis.header.comments["CTYPE3"] = stokes_comment hdu_vis.header.comments["PTYPE1"] = "u baseline coordinate in light seconds" hdu_vis.header.comments["PTYPE2"] = "v baseline coordinate in light seconds" hdu_vis.header.comments["PTYPE3"] = "w baseline coordinate in light seconds" From a86f72ae9b2aabcdc93f081a70e335fe8e7c006f Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 21 Jan 2025 14:05:32 +0100 Subject: [PATCH 28/52] Fix stokes calc, fix sequence of visibilities - Fix calculation of Stokes parameters in Polarisation class - Rename visibilities I -> V_11, Q -> V_12, U -> V_21, V -> V_22. The previous names are confusing, as the quantities that are returned are not the Stokes parameters but calculated from the Stokes parameters (i.e. V_11 = I + V for circular polarization) - Fix visbilities sequence -> V_11, V_22, V_12, V_21 (previously V_11, V_12, V_21, V22), i.e. main diagonal of the Jones matrix first, then subdiagonals --- pyvisgen/simulation/visibility.py | 61 +++++++++++++++++-------------- 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 12a16a1..3209fe9 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -10,10 +10,10 @@ @dataclass class Visibilities: - SI: torch.tensor - SQ: torch.tensor - SU: torch.tensor - SV: torch.tensor + V_11: torch.tensor + V_22: torch.tensor + V_12: torch.tensor + V_21: torch.tensor num: torch.tensor base_num: torch.tensor u: torch.tensor @@ -28,7 +28,7 @@ def __getitem__(self, i): def get_values(self): return torch.cat( - [self.SI[None], self.SQ[None], self.SU[None], self.SV[None]], dim=0 + [self.V_11[None], self.V_22[None], self.V_12[None], self.V_21[None]], dim=0 ).permute(1, 2, 0) def add(self, visibilities): @@ -96,7 +96,7 @@ def __init__( if self.polarisation: self.polarisation_field = self.rand_polarisation_field( - [SI.shape[0], SI.shape[1]], + [self.SI.shape[0], self.SI.shape[1]], **field_kwargs, ) @@ -109,8 +109,11 @@ def __init__( ay2 = 1 - ax2 - self.ax2 = self.SI[..., 0] * ax2 - self.ay2 = self.SI[..., 0] * ay2 + self.ax2 = self.SI[..., 0].clone() * ax2 + self.ay2 = self.SI[..., 0].clone() * ay2 + else: + self.ax2 = self.SI[..., 0] + self.ay2 = torch.zeros_like(self.ax2) self.I = torch.zeros( (self.SI.shape[0], self.SI.shape[1], 4), dtype=torch.cdouble @@ -175,13 +178,13 @@ def dop(self) -> None: self.I[..., 2] *= self.polarisation_field self.I[..., 3] *= self.polarisation_field - dop_I = self.I[..., 0].clone() + dop_I = self.I[..., 0].real.clone() dop_I[~mask] = float("nan") - dop_Q = self.I[..., 1].clone() + dop_Q = self.I[..., 1].real.clone() dop_Q[~mask] = float("nan") - dop_U = self.I[..., 2].clone() + dop_U = self.I[..., 2].real.clone() dop_U[~mask] = float("nan") - dop_V = self.I[..., 3].clone() + dop_V = self.I[..., 3].real.clone() dop_V[~mask] = float("nan") self.lin_dop = torch.sqrt(dop_Q**2 + dop_U**2) / dop_I @@ -213,28 +216,30 @@ def stokes_matrix(self) -> tuple: self.linear() self.dop() - B[..., 0, 0] = self.I[..., 0] + self.I[..., 3] - B[..., 0, 1] = self.I[..., 1] + 1j * self.I[..., 2] - B[..., 1, 0] = self.I[..., 1] - 1j * self.I[..., 2] - B[..., 1, 1] = self.I[..., 0] - self.I[..., 3] + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] # I + Q + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] # U + iV + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] # U - iV + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] # I - Q elif self.polarisation == "circular": self.circular() self.dop() - B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] - B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] - B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] - B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + B[..., 0, 0] = self.I[..., 0] + self.I[..., 3] # I + V + B[..., 0, 1] = self.I[..., 1] + 1j * self.I[..., 2] # Q + iU + B[..., 1, 0] = self.I[..., 1] - 1j * self.I[..., 2] # Q - iU + B[..., 1, 1] = self.I[..., 0] - self.I[..., 3] # I - V else: # No polarisation applied self.I[..., 0] = self.SI[..., 0] + self.polarisation_field = torch.ones_like(self.I[..., 0]) + self.dop() - B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] - B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] - B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] - B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] # I + Q + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] # U + iV + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] # U - iV + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] # I - Q # calculations only for px > sensitivity cut mask = (self.SI >= self.sensitivity_cut)[..., 0] @@ -523,10 +528,10 @@ def _batch_loop( vis_num = torch.arange(int_values.shape[0]) + 1 + vis_num.max() vis = Visibilities( - int_values[:, :, 0, 0].cpu(), - int_values[:, :, 0, 1].cpu(), - int_values[:, :, 1, 0].cpu(), - int_values[:, :, 1, 1].cpu(), + int_values[..., 0, 0].cpu(), # V_11 + int_values[..., 1, 1].cpu(), # V_22 + int_values[..., 0, 1].cpu(), # V_12 + int_values[..., 1, 0].cpu(), # V_21 vis_num, bas_p[9].cpu(), bas_p[2].cpu(), From 9f442f15eb61353c89236587a070a177334f463d Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Thu, 12 Sep 2024 17:29:49 +0200 Subject: [PATCH 29/52] Change type hint from str to torch.device for device --- pyvisgen/simulation/visibility.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 3209fe9..720f688 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -54,7 +54,7 @@ def __init__( polarisation: str, field_kwargs: dict, random_state: int, - device: str, + device: torch.device, ) -> None: """Creates the 2 x 2 stokes matrix and simulates polarisation if `polarisation` is either 'linear' @@ -82,7 +82,7 @@ def __init__( random_state : int Random state used when drawing `amp_ratio` and during the generation of the random polarisation field. - device : str + device : torch.device Torch device to select for computation. """ self.sensitivity_cut = sensitivity_cut From ed843fb722c1024eccb2a18808471caedd00b00b Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 20 Sep 2024 16:53:41 +0200 Subject: [PATCH 30/52] Flip input image before creating stokes matrix --- pyvisgen/simulation/visibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 720f688..a655173 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -373,7 +373,7 @@ def vis_loop( torch._dynamo.config.suppress_errors = True pol = Polarisation( - SI, + torch.flip(SI, dims=[1]), sensitivity_cut=obs.sensitivity_cut, polarisation=obs.polarisation, device=obs.device, From 43b8b582815db9323d770280a2733b64ff71e75d Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 1 Oct 2024 10:04:34 +0200 Subject: [PATCH 31/52] Add changelog --- docs/changes/39.bugfix.rst | 3 +++ docs/changes/39.feature.rst | 4 ++++ docs/changes/39.maintenance.rst | 6 ++++++ 3 files changed, 13 insertions(+) create mode 100644 docs/changes/39.bugfix.rst create mode 100644 docs/changes/39.feature.rst create mode 100644 docs/changes/39.maintenance.rst diff --git a/docs/changes/39.bugfix.rst b/docs/changes/39.bugfix.rst new file mode 100644 index 0000000..9a98182 --- /dev/null +++ b/docs/changes/39.bugfix.rst @@ -0,0 +1,3 @@ +- Fix gridding in `pyvisgen.simulation.observation.Observation` methods `create_rd_grid` and `create_lm_grid` + methods resulting in rotated images +- Fix `pyvisgen.simulation.observation.ValidBaselineSubset` dataclass field order diff --git a/docs/changes/39.feature.rst b/docs/changes/39.feature.rst new file mode 100644 index 0000000..d2cda72 --- /dev/null +++ b/docs/changes/39.feature.rst @@ -0,0 +1,4 @@ +- Add class `Polarisation` to `pyvisgen.simulation.visibility` that is called in `vis_loop` + - Add linear, circular, and no polarisation options +- Update `pyvisgen.simulation.visibility.Visibilities` dataclass to also store polarisation degree tensors +- Add keyword arguments for polarisation simulation to `pyvisgen.simulation.observation.Observation` class diff --git a/docs/changes/39.maintenance.rst b/docs/changes/39.maintenance.rst new file mode 100644 index 0000000..83a2365 --- /dev/null +++ b/docs/changes/39.maintenance.rst @@ -0,0 +1,6 @@ +- Change pyvisgen.simulation.visibility.Visibilities dataclass component names from stokes components (I , Q, U, and V) + to visibilities constructed from the stokes components (`V_11`, `V_22`, `V_12`, `V_21`) +- Change indices for stokes components according to AIPS Memo 114 + - Indices will be set automatically depending on simulated polarisation +- Update comment strings in FITS files +- Update docstrings accordingly in `pyvisgen.simulation.visibility.vis_loop` and `pyvisgen.simulation.observation.Observation` From 0c8d7de06793288c72ce62e3f5959cbc5e4b7e84 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 1 Oct 2024 10:08:23 +0200 Subject: [PATCH 32/52] Fix test failing because of api change --- tests/test_simulation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 57613ba..ca59365 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -55,10 +55,10 @@ def test_vis_loop(): SI = torch.tensor(data[0])[None] vis_data = vis_loop(obs, SI, noisy=conf["noisy"], mode=conf["mode"]) - assert (vis_data[0].SI[0]).dtype == torch.complex128 - assert (vis_data[0].SQ[0]).dtype == torch.complex128 - assert (vis_data[0].SU[0]).dtype == torch.complex128 - assert (vis_data[0].SV[0]).dtype == torch.complex128 + assert (vis_data[0].V_11[0]).dtype == torch.complex128 + assert (vis_data[0].V_22[0]).dtype == torch.complex128 + assert (vis_data[0].V_12[0]).dtype == torch.complex128 + assert (vis_data[0].V_21[0]).dtype == torch.complex128 assert (vis_data[0].num).dtype == torch.float32 assert (vis_data[0].base_num).dtype == torch.float64 assert torch.is_tensor(vis_data[0].u) From 697c439313a264bd1550e4226ab95a8a838ff28f Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 1 Oct 2024 10:09:12 +0200 Subject: [PATCH 33/52] Update bugfix changelog --- docs/changes/39.bugfix.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changes/39.bugfix.rst b/docs/changes/39.bugfix.rst index 9a98182..67473db 100644 --- a/docs/changes/39.bugfix.rst +++ b/docs/changes/39.bugfix.rst @@ -1,3 +1,4 @@ - Fix gridding in `pyvisgen.simulation.observation.Observation` methods `create_rd_grid` and `create_lm_grid` methods resulting in rotated images - Fix `pyvisgen.simulation.observation.ValidBaselineSubset` dataclass field order +- Fix test failing because of api change From 22a53fe819f3b8a38f04b22b4217e559c78637e3 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 21 Jan 2025 14:07:02 +0100 Subject: [PATCH 34/52] Add tests for Polarisation class --- tests/test_simulation.py | 186 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 183 insertions(+), 3 deletions(-) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index ca59365..794b97a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -1,7 +1,7 @@ from pathlib import Path import torch -from numpy.testing import assert_raises +from numpy.testing import assert_array_equal, assert_raises from pyvisgen.utils.config import read_data_set_conf @@ -143,9 +143,189 @@ def test_vis_loop_batch_size_invalid(): mode=conf["mode"], batch_size=20.0, ) - - + + def test_simulate_data_set_no_slurm(): from pyvisgen.simulation.data_set import simulate_data_set simulate_data_set(config) + + +class TestPolarisation: + """Unit test class for ``pyvisgen.simulation.visibility.Polarisation``.""" + + def setup_class(self): + """Set up common objects and variables for the following tests.""" + from pyvisgen.simulation.data_set import create_observation + from pyvisgen.simulation.visibility import Polarisation + + self.obs = create_observation(conf) + + self.SI = torch.zeros((100, 100)) + self.SI[25::25, 25::25] = 1 + self.SI = self.SI[None, ...] + + self.si_shape = self.SI.shape + self.im_shape = self.si_shape[1], self.si_shape[2] + + self.obs.img_size = self.im_shape[0] + + self.pol = Polarisation( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation=self.obs.polarisation, + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **self.obs.pol_kwargs, + ) + + def test_polarisation_circular(self): + """Test circular polarisation.""" + + self.pol.__init__( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation="circular", + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **self.obs.pol_kwargs, + ) + + assert self.pol.delta == 0 + assert self.pol.ax2.sum() == self.SI.sum() * 0.5 + assert self.pol.ay2.sum() == self.SI.sum() * 0.5 + + B, mask, lin_dop, circ_dop = self.pol.stokes_matrix() + + assert mask.sum() == 9 + assert B.shape == torch.Size([9, 2, 2]) + assert mask.shape == self.im_shape + assert lin_dop.shape == self.im_shape + assert lin_dop.shape == self.im_shape + + def test_polarisation_linear(self): + """Test linear polarisation.""" + + self.pol.__init__( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation="linear", + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **self.obs.pol_kwargs, + ) + + assert self.pol.delta == 0 + assert self.pol.ax2.sum() == self.SI.sum() * 0.5 + assert self.pol.ay2.sum() == self.SI.sum() * 0.5 + + B, mask, lin_dop, circ_dop = self.pol.stokes_matrix() + + assert mask.sum() == 9 + assert B.shape == torch.Size([9, 2, 2]) + assert mask.shape == self.im_shape + assert lin_dop.shape == self.im_shape + assert lin_dop.shape == self.im_shape + + def test_polarisation_field(self): + """Test Polarisation.rand_polarisation_field method.""" + pf = self.pol.rand_polarisation_field(shape=self.im_shape) + + assert pf.shape == torch.Size([100, 100]) + + def test_polarisation_field_random_state(self): + """Test polarisation field method for a given random_state""" + random_state = 42 + + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=random_state, + ) + + assert torch.random.initial_seed() == random_state + assert pf.shape == torch.Size([100, 100]) + + def test_polarisation_field_shape_int(self): + """Test polarisation field method for type(shape) = int.""" + pf = self.pol.rand_polarisation_field( + shape=self.im_shape[0], + ) + + assert pf.shape == torch.Size([100, 100]) + + def test_polarisation_field_order(self): + """Test polarisation field method for different orders.""" + + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + order=[1, 1], + ) + + assert pf.shape == torch.Size([100, 100]) + # assert order = 1 and order = [1, 1] yield same images + assert_array_equal(pf, pf_ref, strict=True) + + # assert different order creates different image + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, random_state=42, order=[10, 10] + ) + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) + + # assert len(order) > 2 raises ValueError + assert_raises( + ValueError, + self.pol.rand_polarisation_field, + shape=self.im_shape, + random_state=42, + order=[10, 10, 10], + ) + + def test_polarisation_field_scale(self): + """Test polarisation field method for different scales.""" + + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + + # scale = None + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + scale=None, + ) + + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) + + # scale = [0.25, 0.25] + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, random_state=42, scale=[0.25, 0.25] + ) + + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) + + def test_polarisation_field_threshold(self): + """Test polarisation field method for different threshold.""" + + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + threshold=0.5, + ) + + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) From 50760a255e78b7877fa764ae3ec5981b7503f239 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 15:08:51 +0200 Subject: [PATCH 35/52] Fix bug in Polarisation class --- pyvisgen/simulation/visibility.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index a655173..8b442ef 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -94,7 +94,7 @@ def __init__( if random_state: torch.manual_seed(random_state) - if self.polarisation: + if self.polarisation and self.polarisation in ["circular", "linear"]: self.polarisation_field = self.rand_polarisation_field( [self.SI.shape[0], self.SI.shape[1]], **field_kwargs, @@ -283,6 +283,9 @@ def rand_polarisation_field( if random_state: torch.random.manual_seed(random_state) + if isinstance(shape, int): + shape = [shape] + if not isinstance(shape, list): shape = list(shape) From 409e922390d5e491b1349481f92ba70cb8a52d90 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 15:46:49 +0200 Subject: [PATCH 36/52] Add tests for remaining edge cases --- tests/test_simulation.py | 41 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 794b97a..be46dae 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -227,6 +227,22 @@ def test_polarisation_linear(self): assert lin_dop.shape == self.im_shape assert lin_dop.shape == self.im_shape + def test_polarisation_amplitude(self): + """Test random amplitude.""" + pol_kwargs = {"delta": 0, "amp_ratio": None, "random_state": 42} + + self.pol.__init__( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation="linear", + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **pol_kwargs, + ) + + assert self.pol.ax2.sum() <= 9 + assert self.pol.ay2.sum() <= 9 + def test_polarisation_field(self): """Test Polarisation.rand_polarisation_field method.""" pf = self.pol.rand_polarisation_field(shape=self.im_shape) @@ -245,13 +261,28 @@ def test_polarisation_field_random_state(self): assert torch.random.initial_seed() == random_state assert pf.shape == torch.Size([100, 100]) - def test_polarisation_field_shape_int(self): + def test_polarisation_field_shape(self): """Test polarisation field method for type(shape) = int.""" + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + pf = self.pol.rand_polarisation_field( shape=self.im_shape[0], + random_state=42, ) assert pf.shape == torch.Size([100, 100]) + assert_array_equal(pf, pf_ref, strict=True) + + # assert len(shape) > 2 raises ValueError + assert_raises( + ValueError, + self.pol.rand_polarisation_field, + shape=[100, 100, 100], + random_state=42, + ) def test_polarisation_field_order(self): """Test polarisation field method for different orders.""" @@ -271,6 +302,14 @@ def test_polarisation_field_order(self): # assert order = 1 and order = [1, 1] yield same images assert_array_equal(pf, pf_ref, strict=True) + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + order=[1], + ) + # assert order = [1] and order = [1, 1] yield same images + assert_array_equal(pf, pf_ref, strict=True) + # assert different order creates different image pf = self.pol.rand_polarisation_field( shape=self.im_shape, random_state=42, order=[10, 10] From a2505cd1d35c5e3db9dc656d7c08697f30ab667f Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 17:55:47 +0200 Subject: [PATCH 37/52] Add test for order as tuple --- tests/test_simulation.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index be46dae..d2d277f 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -302,6 +302,14 @@ def test_polarisation_field_order(self): # assert order = 1 and order = [1, 1] yield same images assert_array_equal(pf, pf_ref, strict=True) + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + order=(1, 1), + ) + # assert order = (1, 1) and order = [1, 1] yield same images + assert_array_equal(pf, pf_ref, strict=True) + pf = self.pol.rand_polarisation_field( shape=self.im_shape, random_state=42, From 86f74ef1d941a5f77557e5cd3f1e267db79aba27 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Thu, 9 Jan 2025 17:59:41 +0100 Subject: [PATCH 38/52] Add parallactic angle/feed rotation calculation --- pyvisgen/simulation/observation.py | 117 ++++++++++++++++++++++++++--- 1 file changed, 106 insertions(+), 11 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index ad0e3cc..1eb6b3a 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -7,7 +7,7 @@ import numpy as np import torch from astropy.constants import c -from astropy.coordinates import AltAz, Angle, EarthLocation, SkyCoord +from astropy.coordinates import AltAz, Angle, EarthLocation, SkyCoord, Longitude from astropy.time import Time from pyvisgen.layouts import layouts @@ -23,6 +23,9 @@ class Baselines: w: torch.tensor valid: torch.tensor time: torch.tensor + q_all: torch.tensor + q1: torch.tensor + q2: torch.tensor def __getitem__(self, i): return Baselines(*[getattr(self, f.name)[i] for f in fields(self)]) @@ -61,6 +64,9 @@ def get_valid_subset(self, num_baselines, device): v_valid = (v_start + v_stop) / 2 w_valid = (w_start + w_stop) / 2 + q1_valid = bas_reshaped.q1[mask].to(device) + q2_valid = bas_reshaped.q2[mask].to(device) + t = Time(bas_reshaped.time / (60 * 60 * 24), format="mjd").jd date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device) @@ -76,6 +82,8 @@ def get_valid_subset(self, num_baselines, device): w_valid, baseline_nums, date, + q1_valid, + q2_valid, ) @@ -92,6 +100,8 @@ class ValidBaselineSubset: w_valid: torch.tensor baseline_nums: torch.tensor date: torch.tensor + q1_valid: torch.tensor + q2_valid: torch.tensor def __getitem__(self, i): return torch.stack( @@ -107,6 +117,8 @@ def __getitem__(self, i): self.w_valid, self.baseline_nums, self.date, + self.q1_valid, + self.q2_valid, ] ) @@ -236,7 +248,7 @@ def __init__( "delta": 0, "amp_ratio": 0.5, "random_state": 42, - } + }` field_kwargs : dict, optional Additional keyword arguments for the random polarisation field that is applied when simulating polarisation. @@ -265,7 +277,8 @@ def __init__( self.times, self.times_mjd = self.calc_time_steps() self.scans = torch.stack( torch.split( - torch.arange(len(self.times)), (len(self.times) // self.num_scans) + torch.arange(self.times.size), + (self.times.size // self.num_scans), ), dim=0, ) @@ -290,6 +303,9 @@ def __init__( self.layout = array_layout self.array = layouts.get_array_layout(array_layout) + self.array_earth_loc = EarthLocation.from_geocentric( + self.array.x, self.array.y, self.array.z, unit=un.m + ) self.num_baselines = int( len(self.array.st_num) * (len(self.array.st_num) - 1) / 2 ) @@ -303,7 +319,7 @@ def __init__( else: self.calc_baselines() self.baselines.num = int( - len(self.array.st_num) * (len(self.array.st_num) - 1) / 2 + self.array.st_num.size(dim=0) * (self.array.st_num.size(dim=0) - 1) / 2 ) self.baselines.times_unique = torch.unique(self.baselines.time) @@ -365,7 +381,15 @@ def calc_baselines(self): torch.tensor([]), torch.tensor([]), torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), ) + self.q_comb_l = [] + self.q_all_l = [] + + self.n = 1 + for scan in self.scans: bas = self.get_baselines(self.times[scan]) self.baselines.add_baseline(bas) @@ -389,12 +413,25 @@ def calc_ref_elev(self, time=None): time = self.times if time.shape == (): time = time[None] + src_crd = SkyCoord(ra=self.ra, dec=self.dec, unit=(un.deg, un.deg)) # Calculate for all times # calculate GHA, Greenwich as reference - ha_all = Angle( + GHA = Angle( [t.sidereal_time("apparent", "greenwich") - src_crd.ra for t in time] ) + self.ha_all = GHA + + # calculate local sidereal time and HA at each antenna + lst = un.Quantity( + [ + Time(time, location=loc).sidereal_time("mean") + for loc in self.array_earth_loc + ] + ) + ha_local = torch.from_numpy( + (lst - Longitude(self.ra.item(), unit=un.deg)).radian + ).T # calculate elevations el_st_all = src_crd.transform_to( @@ -408,8 +445,38 @@ def calc_ref_elev(self, time=None): ), ) ) - assert len(ha_all.value) == len(el_st_all) - return torch.tensor(ha_all.deg), torch.tensor(el_st_all.alt.degree) + assert len(GHA.value) == len(el_st_all) + return torch.tensor(GHA.deg), ha_local, torch.tensor(el_st_all.alt.degree) + + def calc_feed_rotation(self, ha: Angle) -> Angle: + r"""Calculates feed rotation for every antenna at every time step. + + Notes + ----- + The calculation is based on Equation (13.1) of Meeus' + Astronomical Algorithms: + + .. math:: + + q = \frac{\sin h}{\cos\delta \tan\varphi - \sin\delta \cos h, + + where $h$ is the local hour angle, $\varphi$ the geographical latitude + of the observer, and $\delta$ the declination of the source. + """ + # We need to create a tensor from the EarthLocation object + # and save only the geographical latitude of each antenna + ant_lat = torch.tensor(self.array_earth_loc.lat) + + # Eqn (13.1) of Meeus' Astronomical Algorithms + q = torch.arctan2( + torch.sin(ha), + ( + torch.tan(ant_lat) * torch.cos(self.dec) + - torch.sin(self.dec) * torch.cos(ha) + ), + ) + + return q def test_active_telescopes(self): _, el_st_0 = self.calc_ref_elev(self.times[0]) @@ -501,9 +568,10 @@ def get_baselines(self, times): dataclass object baselines between telescopes with visibility flags """ - # Calculate for all times - # calculate GHA, Greenwich as reference - ha_all, el_st_all = self.calc_ref_elev(time=times) + # calculate GHA, local HA, and station elevation for all times. + GHA, ha_local, el_st_all = self.calc_ref_elev(time=times) + + self.el_st_all = el_st_all ar = Array(self.array) delta_x, delta_y, delta_z = ar.calc_relative_pos @@ -518,8 +586,31 @@ def get_baselines(self, times): torch.tensor([]), torch.tensor([]), torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), ) - for ha, el_st, time in zip(ha_all, el_st_all, times): + q_all = self.calc_feed_rotation(ha_local) + q_comb = torch.vstack([torch.combinations(qi) for qi in q_all]) + q_comb = q_comb.reshape(-1, int(q_comb.shape[0] / times.shape[0]), 2) + + self.q_comb_l.append(q_comb) + self.q_all_l.append(q_all) + + print( + GHA.shape, + el_st_all.shape, + times.shape, + q_all.shape, + q_comb.shape, + ) + + self.GHA = GHA + self.delx = delta_x + self.dely = delta_y + self.delz = delta_z + + for ha, el_st, time, q, qc in zip(GHA, el_st_all, times, q_all, q_comb): u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z) # calc current elevations @@ -535,6 +626,7 @@ def get_baselines(self, times): time_mjd = torch.repeat_interleave( torch.tensor(time.mjd) * (24 * 60 * 60), len(valid) ) + # collect baselines base = Baselines( st_num_pairs[:, 0], @@ -544,6 +636,9 @@ def get_baselines(self, times): w, valid, time_mjd, + q, + qc[..., 0].ravel(), + qc[..., 1].ravel(), ) baselines.add_baseline(base) return baselines From bd3f75b417e73e31fcfeae65f82d48c0dae0b5a6 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Thu, 9 Jan 2025 18:00:14 +0100 Subject: [PATCH 39/52] Add parallactic angle/feed rotation matrices to RIME --- pyvisgen/simulation/scan.py | 50 ++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/pyvisgen/simulation/scan.py b/pyvisgen/simulation/scan.py index 2b92e9e..d3077c6 100644 --- a/pyvisgen/simulation/scan.py +++ b/pyvisgen/simulation/scan.py @@ -6,7 +6,19 @@ @torch.compile -def rime(img, bas, lm, rd, ra, dec, ant_diam, spw_low, spw_high, corrupted=False): +def rime( + img, + bas, + lm, + rd, + ra, + dec, + ant_diam, + spw_low, + spw_high, + polarisation, + corrupted=False, +): """Calculates visibilities using RIME Parameters @@ -21,6 +33,8 @@ def rime(img, bas, lm, rd, ra, dec, ant_diam, spw_low, spw_high, corrupted=False lower wavelength spw_high : float higher wavelength + polarisation : str + Type of polarisation. Returns ------- @@ -29,8 +43,11 @@ def rime(img, bas, lm, rd, ra, dec, ant_diam, spw_low, spw_high, corrupted=False """ with torch.no_grad(): X1, X2 = calc_fourier(img, bas, lm, spw_low, spw_high) + print(X1.shape) if corrupted: X1, X2 = calc_beam(X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high) + + X1, X2 = calc_feed_rotation(X1, X2, bas.q1, bas.q2, polarisation) vis = integrate(X1, X2) return vis @@ -77,6 +94,37 @@ def calc_fourier(img, bas, lm, spw_low, spw_high): return img * K1, img * K2 +@torch.compile +def calc_feed_rotation(X1, X2, q1, q2, polarisation): + """ """ + P1 = torch.ones_like(X1) + P2 = torch.ones_like(X2) + + if polarisation == "linear": + P1[..., 0, 0] = torch.cos(q1) + P1[..., 0, 1] = torch.sin(q1) + P1[..., 1, 0] = -torch.sin(q1) + P1[..., 1, 1] = torch.cos(q1) + + P2[..., 0, 0] = torch.cos(q2) + P2[..., 0, 1] = torch.sin(q2) + P2[..., 1, 0] = -torch.sin(q2) + P2[..., 1, 1] = torch.cos(q2) + + if polarisation == "circular": + P1[..., 0, 0] = torch.exp(1j * q1) + P1[..., 0, 1] = 0 + P1[..., 1, 0] = 0 + P1[..., 1, 1] = torch.exp(-1j * q1) + + P2[..., 0, 0] = torch.exp(1j * q2) + P2[..., 0, 1] = 0 + P2[..., 1, 0] = 0 + P2[..., 1, 1] = torch.exp(-1j * q2) + + return img * P1, img * P2 + + @torch.compile def calc_beam(X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high): diameters = ant_diam.to(rd.device) From cf80576bfd86e4b5a4ec31a562333e2d3cf83cc6 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Thu, 9 Jan 2025 18:00:57 +0100 Subject: [PATCH 40/52] Pass current polarisation from vis_loop to RIME --- pyvisgen/simulation/visibility.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 8b442ef..2b6018a 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -514,6 +514,7 @@ def _batch_loop( torch.unique(obs.array.diam), wave_low, wave_high, + obs.polarisation, corrupted=obs.corrupted, )[None] for wave_low, wave_high in zip(obs.waves_low, obs.waves_high) From c1ec878a45970b63384d32dcd468d2654993a5ec Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 10 Jan 2025 18:42:17 +0100 Subject: [PATCH 41/52] Restructure pyvisgen.simulation.observation.Observation --- pyvisgen/simulation/observation.py | 275 +++++++++++++++-------------- 1 file changed, 147 insertions(+), 128 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 1eb6b3a..a27cce8 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -274,6 +274,7 @@ def __init__( self.num_scans = num_scans self.int_time = integration_time self.scan_separation = scan_separation + self.times, self.times_mjd = self.calc_time_steps() self.scans = torch.stack( torch.split( @@ -286,6 +287,7 @@ def __init__( self.ref_frequency = torch.tensor(ref_frequency) self.bandwidths = torch.tensor(bandwidths) self.frequency_offsets = torch.tensor(frequency_offsets) + self.waves_low = ( self.ref_frequency + self.frequency_offsets ) - self.bandwidths / 2 @@ -331,6 +333,30 @@ def __init__( self.pol_kwargs = pol_kwargs self.field_kwargs = field_kwargs + def calc_time_steps(self): + """Computes the time steps of the observation. + + Returns + ------- + time : array_like + Array of time steps. + time.mjd : array_like + Time steps in mjd format. + """ + time_lst = [ + self.start + + self.scan_separation * i * un.second + + i * self.scan_duration * un.second + + j * self.int_time * un.second + for i in range(self.num_scans) + for j in range(int(self.scan_duration / self.int_time) + 1) + ] + # +1 because t_1 is the stop time of t_0 + # in order to save computing power we take one time more to complete interval + time = Time(time_lst) + + return time, time.mjd * (60 * 60 * 24) + def calc_dense_baselines(self): N = self.img_size fov = self.fov * pi / (3600 * 180) @@ -394,19 +420,96 @@ def calc_baselines(self): bas = self.get_baselines(self.times[scan]) self.baselines.add_baseline(bas) - def calc_time_steps(self): - time_lst = [ - self.start - + self.scan_separation * i * un.second - + i * self.scan_duration * un.second - + j * self.int_time * un.second - for i in range(self.num_scans) - for j in range(int(self.scan_duration / self.int_time) + 1) - ] - # +1 because t_1 is the stop time of t_0 - # in order to save computing power we take one time more to complete interval - time = Time(time_lst) - return time, time.mjd * (60 * 60 * 24) + def get_baselines(self, times): + """Calculates baselines from source coordinates and time of observation for + every antenna station in array_layout. + + Parameters + ---------- + times : time object + time of observation + + Returns + ------- + dataclass object + baselines between telescopes with visibility flags + """ + # calculate GHA, local HA, and station elevation for all times. + GHA, ha_local, el_st_all = self.calc_ref_elev(time=times) + + self.el_st_all = el_st_all + + ar = Array(self.array) + delta_x, delta_y, delta_z = ar.calc_relative_pos + st_num_pairs, els_low_pairs, els_high_pairs = ar.calc_ant_pair_vals + + # Loop over ha and el_st + baselines = Baselines( + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + torch.tensor([]), + ) + q_all = self.calc_feed_rotation(ha_local) + q_comb = torch.vstack([torch.combinations(qi) for qi in q_all]) + q_comb = q_comb.reshape(-1, int(q_comb.shape[0] / times.shape[0]), 2) + + self.q_comb_l.append(q_comb) + self.q_all_l.append(q_all) + + print( + GHA.shape, + el_st_all.shape, + times.shape, + q_all.shape, + q_comb.shape, + ) + + self.GHA = GHA + self.delx = delta_x + self.dely = delta_y + self.delz = delta_z + + for ha, el_st, time, q, qc in zip(GHA, el_st_all, times, q_all, q_comb): + u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z) + + # calc current elevations + cur_el_st = torch.combinations(el_st) + + # calc valid baselines + m1 = (cur_el_st < els_low_pairs).any(axis=1) + m2 = (cur_el_st > els_high_pairs).any(axis=1) + + valid = torch.ones(u.shape).bool() + valid_mask = torch.logical_or(m1, m2) + valid[valid_mask] = False + + time_mjd = torch.repeat_interleave( + torch.tensor(time.mjd) * (24 * 60 * 60), len(valid) + ) + + # collect baselines + base = Baselines( + st_num_pairs[:, 0], + st_num_pairs[:, 1], + u, + v, + w, + valid, + time_mjd, + q, + qc[..., 0].ravel(), + qc[..., 1].ravel(), + ) + baselines.add_baseline(base) + + return baselines def calc_ref_elev(self, time=None): if time is None: @@ -446,6 +549,13 @@ def calc_ref_elev(self, time=None): ) ) assert len(GHA.value) == len(el_st_all) + + if not len(GHA.value) == len(el_st_all): + raise ValueError( + "Expected GHA and el_st_all to have the same length" + f"{len(GHA.value)} and {len(el_st_all)}" + ) + return torch.tensor(GHA.deg), ha_local, torch.tensor(el_st_all.alt.degree) def calc_feed_rotation(self, ha: Angle) -> Angle: @@ -478,14 +588,29 @@ def calc_feed_rotation(self, ha: Angle) -> Angle: return q - def test_active_telescopes(self): - _, el_st_0 = self.calc_ref_elev(self.times[0]) - _, el_st_1 = self.calc_ref_elev(self.times[1]) - el_min = 15 - el_max = 85 - active_telescopes_0 = np.sum((el_st_0 >= el_min) & (el_st_0 <= el_max)) - active_telescopes_1 = np.sum((el_st_1 >= el_min) & (el_st_1 <= el_max)) - return min(active_telescopes_0, active_telescopes_1) + def calc_direction_cosines(self, ha, el_st, delta_x, delta_y, delta_z): + src_dec = torch.deg2rad(self.dec) + ha = torch.deg2rad(ha) + + u = (torch.sin(ha) * delta_x + torch.cos(ha) * delta_y).reshape(-1) + v = ( + -torch.sin(src_dec) * torch.cos(ha) * delta_x + + torch.sin(src_dec) * torch.sin(ha) * delta_y + + torch.cos(src_dec) * delta_z + ).reshape(-1) + w = ( + torch.cos(src_dec) * torch.cos(ha) * delta_x + - torch.cos(src_dec) * torch.sin(ha) * delta_y + + torch.sin(src_dec) * delta_z + ).reshape(-1) + + if not (u.shape == v.shape == w.shape): + raise ValueError( + "u, v, w array shapes are not the same: " + f"{u.shape}, {v.shape}, {w.shape}" + ) + + return u, v, w def create_rd_grid(self): """Calculates RA and Dec values for a given fov around a source position @@ -542,7 +667,7 @@ def create_lm_grid(self): Returns ------- lm_grid : 3d array - Returns a 3d array with every pixel containing a l and m value + Returns a 3d array with every pixel containing an l and m value """ dec = torch.deg2rad(self.dec) @@ -553,109 +678,3 @@ def create_lm_grid(self): ) * torch.sin(dec) * torch.cos(self.rd[..., 0]) return lm_grid - - def get_baselines(self, times): - """Calculates baselines from source coordinates and time of observation for - every antenna station in array_layout. - - Parameters - ---------- - times : time object - time of observation - - Returns - ------- - dataclass object - baselines between telescopes with visibility flags - """ - # calculate GHA, local HA, and station elevation for all times. - GHA, ha_local, el_st_all = self.calc_ref_elev(time=times) - - self.el_st_all = el_st_all - - ar = Array(self.array) - delta_x, delta_y, delta_z = ar.calc_relative_pos - st_num_pairs, els_low_pairs, els_high_pairs = ar.calc_ant_pair_vals - - # Loop over ha and el_st - baselines = Baselines( - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - ) - q_all = self.calc_feed_rotation(ha_local) - q_comb = torch.vstack([torch.combinations(qi) for qi in q_all]) - q_comb = q_comb.reshape(-1, int(q_comb.shape[0] / times.shape[0]), 2) - - self.q_comb_l.append(q_comb) - self.q_all_l.append(q_all) - - print( - GHA.shape, - el_st_all.shape, - times.shape, - q_all.shape, - q_comb.shape, - ) - - self.GHA = GHA - self.delx = delta_x - self.dely = delta_y - self.delz = delta_z - - for ha, el_st, time, q, qc in zip(GHA, el_st_all, times, q_all, q_comb): - u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z) - - # calc current elevations - cur_el_st = torch.combinations(el_st) - - # calc valid baselines - valid = torch.ones(u.shape).bool() - m1 = (cur_el_st < els_low_pairs).any(axis=1) - m2 = (cur_el_st > els_high_pairs).any(axis=1) - valid_mask = torch.logical_or(m1, m2) - valid[valid_mask] = False - - time_mjd = torch.repeat_interleave( - torch.tensor(time.mjd) * (24 * 60 * 60), len(valid) - ) - - # collect baselines - base = Baselines( - st_num_pairs[:, 0], - st_num_pairs[:, 1], - u, - v, - w, - valid, - time_mjd, - q, - qc[..., 0].ravel(), - qc[..., 1].ravel(), - ) - baselines.add_baseline(base) - return baselines - - def calc_direction_cosines(self, ha, el_st, delta_x, delta_y, delta_z): - src_dec = torch.deg2rad(self.dec) - ha = torch.deg2rad(ha) - u = (torch.sin(ha) * delta_x + torch.cos(ha) * delta_y).reshape(-1) - v = ( - -torch.sin(src_dec) * torch.cos(ha) * delta_x - + torch.sin(src_dec) * torch.sin(ha) * delta_y - + torch.cos(src_dec) * delta_z - ).reshape(-1) - w = ( - torch.cos(src_dec) * torch.cos(ha) * delta_x - - torch.cos(src_dec) * torch.sin(ha) * delta_y - + torch.sin(src_dec) * delta_z - ).reshape(-1) - assert u.shape == v.shape == w.shape - return u, v, w From 9dc6667875fa767ca9f93755d2b4152e5f91fa57 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 14 Jan 2025 15:44:04 +0100 Subject: [PATCH 42/52] Fix valid q1, q2 computation in Baselines class --- pyvisgen/simulation/observation.py | 46 ++++++++++++++++-------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index a27cce8..9ad5d00 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -23,7 +23,6 @@ class Baselines: w: torch.tensor valid: torch.tensor time: torch.tensor - q_all: torch.tensor q1: torch.tensor q2: torch.tensor @@ -64,8 +63,14 @@ def get_valid_subset(self, num_baselines, device): v_valid = (v_start + v_stop) / 2 w_valid = (w_start + w_stop) / 2 - q1_valid = bas_reshaped.q1[mask].to(device) - q2_valid = bas_reshaped.q2[mask].to(device) + q1_start = bas_reshaped.q1[:-1][mask].to(device) + q2_start = bas_reshaped.q2[:-1][mask].to(device) + + q1_stop = bas_reshaped.q1[1:][mask].to(device) + q2_stop = bas_reshaped.q2[1:][mask].to(device) + + q1_valid = (q1_start + q1_stop) / 2 + q2_valid = (q2_start + q2_stop) / 2 t = Time(bas_reshaped.time / (60 * 60 * 24), format="mjd").jd date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device) @@ -82,7 +87,11 @@ def get_valid_subset(self, num_baselines, device): w_valid, baseline_nums, date, + q1_start, + q1_stop, q1_valid, + q2_start, + q2_stop, q2_valid, ) @@ -100,7 +109,11 @@ class ValidBaselineSubset: w_valid: torch.tensor baseline_nums: torch.tensor date: torch.tensor + q1_start: torch.tensor + q1_stop: torch.tensor q1_valid: torch.tensor + q2_start: torch.tensor + q2_stop: torch.tensor q2_valid: torch.tensor def __getitem__(self, i): @@ -117,7 +130,11 @@ def __getitem__(self, i): self.w_valid, self.baseline_nums, self.date, + self.q1_start, + self.q1_stop, self.q1_valid, + self.q2_start, + self.q2_stop, self.q2_valid, ] ) @@ -129,6 +146,8 @@ def get_timerange(self, t_start, t_stop): def get_unique_grid(self, fov_size, ref_frequency, img_size, device): uv = torch.cat([self.u_valid[None], self.v_valid[None]], dim=0) + q = torch.cat([self.q1_valid[None], self.q2_valid[None]], dim=0) + fov = fov_size * pi / (3600 * 180) delta = 1 / fov * const.c.value.item() / ref_frequency bins = ( @@ -140,8 +159,10 @@ def get_unique_grid(self, fov_size, ref_frequency, img_size, device): ) + delta / 2 ) + if len(bins) - 1 > img_size: bins = bins[:-1] + indices_bucket = torch.bucketize(uv, bins) indices_bucket_sort, indices_bucket_inv = self._lexsort(indices_bucket) indices_unique, indices_unique_inv, counts = torch.unique_consecutive( @@ -409,7 +430,6 @@ def calc_baselines(self): torch.tensor([]), torch.tensor([]), torch.tensor([]), - torch.tensor([]), ) self.q_comb_l = [] self.q_all_l = [] @@ -454,28 +474,11 @@ def get_baselines(self, times): torch.tensor([]), torch.tensor([]), torch.tensor([]), - torch.tensor([]), ) q_all = self.calc_feed_rotation(ha_local) q_comb = torch.vstack([torch.combinations(qi) for qi in q_all]) q_comb = q_comb.reshape(-1, int(q_comb.shape[0] / times.shape[0]), 2) - self.q_comb_l.append(q_comb) - self.q_all_l.append(q_all) - - print( - GHA.shape, - el_st_all.shape, - times.shape, - q_all.shape, - q_comb.shape, - ) - - self.GHA = GHA - self.delx = delta_x - self.dely = delta_y - self.delz = delta_z - for ha, el_st, time, q, qc in zip(GHA, el_st_all, times, q_all, q_comb): u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z) @@ -503,7 +506,6 @@ def get_baselines(self, times): w, valid, time_mjd, - q, qc[..., 0].ravel(), qc[..., 1].ravel(), ) From 0b88cb76b3e5a6765cdca48db2d4f4438da27cf3 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 14 Jan 2025 15:51:55 +0100 Subject: [PATCH 43/52] Remove debug output, set defaults --- pyvisgen/simulation/observation.py | 43 +++++++++++++----------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 9ad5d00..2a423fe 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -14,6 +14,20 @@ from pyvisgen.simulation.array import Array +DEFAULT_POL_KWARGS = { + "delta": 0, + "amp_ratio": 0.5, + "random_state": 42, +} + +DEFAULT_FIELD_KWARGS = { + "order": [1, 1], + "scale": [0, 1], + "threshold": None, + "random_state": 42, +} + + @dataclass class Baselines: st1: torch.tensor @@ -146,7 +160,6 @@ def get_timerange(self, t_start, t_stop): def get_unique_grid(self, fov_size, ref_frequency, img_size, device): uv = torch.cat([self.u_valid[None], self.v_valid[None]], dim=0) - q = torch.cat([self.q1_valid[None], self.q2_valid[None]], dim=0) fov = fov_size * pi / (3600 * 180) delta = 1 / fov * const.c.value.item() / ref_frequency @@ -207,17 +220,8 @@ def __init__( dense: bool = False, sensitivity_cut: float = 1e-6, polarisation: str = None, - pol_kwargs: dict = { - "delta": 0, - "amp_ratio": 0.5, - "random_state": 42, - }, - field_kwargs: dict = { - "order": [1, 1], - "scale": [0, 1], - "threshold": None, - "random_state": 42, - }, + pol_kwargs: dict = DEFAULT_POL_KWARGS, + field_kwargs: dict = DEFAULT_FIELD_KWARGS, ) -> None: """Sets up the observation class. @@ -431,10 +435,6 @@ def calc_baselines(self): torch.tensor([]), torch.tensor([]), ) - self.q_comb_l = [] - self.q_all_l = [] - - self.n = 1 for scan in self.scans: bas = self.get_baselines(self.times[scan]) @@ -457,8 +457,6 @@ def get_baselines(self, times): # calculate GHA, local HA, and station elevation for all times. GHA, ha_local, el_st_all = self.calc_ref_elev(time=times) - self.el_st_all = el_st_all - ar = Array(self.array) delta_x, delta_y, delta_z = ar.calc_relative_pos st_num_pairs, els_low_pairs, els_high_pairs = ar.calc_ant_pair_vals @@ -499,8 +497,8 @@ def get_baselines(self, times): # collect baselines base = Baselines( - st_num_pairs[:, 0], - st_num_pairs[:, 1], + st_num_pairs[..., 0], + st_num_pairs[..., 1], u, v, w, @@ -525,7 +523,6 @@ def calc_ref_elev(self, time=None): GHA = Angle( [t.sidereal_time("apparent", "greenwich") - src_crd.ra for t in time] ) - self.ha_all = GHA # calculate local sidereal time and HA at each antenna lst = un.Quantity( @@ -550,8 +547,6 @@ def calc_ref_elev(self, time=None): ), ) ) - assert len(GHA.value) == len(el_st_all) - if not len(GHA.value) == len(el_st_all): raise ValueError( "Expected GHA and el_st_all to have the same length" @@ -608,7 +603,7 @@ def calc_direction_cosines(self, ha, el_st, delta_x, delta_y, delta_z): if not (u.shape == v.shape == w.shape): raise ValueError( - "u, v, w array shapes are not the same: " + "Expected u, v, and w to have the same shapes: " f"{u.shape}, {v.shape}, {w.shape}" ) From 522337566a18a4d4ce03bde474253de7a23ead2d Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 17 Jan 2025 10:09:22 +0100 Subject: [PATCH 44/52] Add optional progress bar --- pyvisgen/simulation/observation.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 2a423fe..e49d45a 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -4,11 +4,11 @@ import astropy.constants as const import astropy.units as un -import numpy as np import torch from astropy.constants import c from astropy.coordinates import AltAz, Angle, EarthLocation, SkyCoord, Longitude from astropy.time import Time +from tqdm import tqdm from pyvisgen.layouts import layouts from pyvisgen.simulation.array import Array @@ -189,6 +189,7 @@ def get_unique_grid(self, fov_size, ref_frequency, img_size, device): cum_sum = counts.cumsum(0) cum_sum = torch.cat((torch.tensor([0], device=device), cum_sum[:-1])) first_indices = ind_sorted[cum_sum] + return self[:][:, indices_bucket_sort[first_indices]] def _lexsort(self, a, dim=-1): @@ -222,6 +223,7 @@ def __init__( polarisation: str = None, pol_kwargs: dict = DEFAULT_POL_KWARGS, field_kwargs: dict = DEFAULT_FIELD_KWARGS, + show_progress: bool = False, ) -> None: """Sets up the observation class. @@ -283,6 +285,9 @@ def __init__( "threshold": None, "random_state": 42 }` + show_progress : bool, optional + If `True`, show a progress bar during the iteration over the + scans. Default: False Notes ----- @@ -337,6 +342,8 @@ def __init__( len(self.array.st_num) * (len(self.array.st_num) - 1) / 2 ) + self.show_progress = show_progress + if dense: self.waves_low = [self.ref_frequency] self.waves_high = [self.ref_frequency] @@ -424,6 +431,10 @@ def calc_dense_baselines(self): ) def calc_baselines(self): + """Initializes Baselines dataclass object and + calls self.get_baselines to compute the contents of + the Baselines dataclass. + """ self.baselines = Baselines( torch.tensor([]), torch.tensor([]), @@ -436,6 +447,9 @@ def calc_baselines(self): torch.tensor([]), ) + if self.show_progress: + self.scans = tqdm(self.scans) + for scan in self.scans: bas = self.get_baselines(self.times[scan]) self.baselines.add_baseline(bas) @@ -553,7 +567,11 @@ def calc_ref_elev(self, time=None): f"{len(GHA.value)} and {len(el_st_all)}" ) - return torch.tensor(GHA.deg), ha_local, torch.tensor(el_st_all.alt.degree) + return ( + torch.tensor(GHA.deg), + torch.tensor(ha_local), + torch.tensor(el_st_all.alt.degree), + ) def calc_feed_rotation(self, ha: Angle) -> Angle: r"""Calculates feed rotation for every antenna at every time step. From 84fbdeca37ba38e84780c88ba09289f5b68744e1 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 17 Jan 2025 10:09:54 +0100 Subject: [PATCH 45/52] Fix feed rotation computation --- pyvisgen/simulation/scan.py | 46 ++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/pyvisgen/simulation/scan.py b/pyvisgen/simulation/scan.py index d3077c6..2c7dff4 100644 --- a/pyvisgen/simulation/scan.py +++ b/pyvisgen/simulation/scan.py @@ -43,11 +43,11 @@ def rime( """ with torch.no_grad(): X1, X2 = calc_fourier(img, bas, lm, spw_low, spw_high) - print(X1.shape) + if corrupted: X1, X2 = calc_beam(X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high) - X1, X2 = calc_feed_rotation(X1, X2, bas.q1, bas.q2, polarisation) + X1, X2 = calc_feed_rotation(X1, X2, bas, polarisation) vis = integrate(X1, X2) return vis @@ -79,7 +79,7 @@ def calc_fourier(img, bas, lm, spw_low, spw_high): v_cmplt = torch.cat((bas[3], bas[4])) w_cmplt = torch.cat((bas[6], bas[7])) - l = lm[..., 0] + l = lm[..., 0] # noqa: E741 m = lm[..., 1] n = torch.sqrt(1 - l**2 - m**2) @@ -95,34 +95,34 @@ def calc_fourier(img, bas, lm, spw_low, spw_high): @torch.compile -def calc_feed_rotation(X1, X2, q1, q2, polarisation): +def calc_feed_rotation(X1, X2, bas, polarisation): """ """ - P1 = torch.ones_like(X1) - P2 = torch.ones_like(X2) + q1 = torch.cat((bas[11], bas[12]))[..., None] + q2 = torch.cat((bas[14], bas[15]))[..., None] if polarisation == "linear": - P1[..., 0, 0] = torch.cos(q1) - P1[..., 0, 1] = torch.sin(q1) - P1[..., 1, 0] = -torch.sin(q1) - P1[..., 1, 1] = torch.cos(q1) + X1[..., 0, 0] *= torch.cos(q1) + X1[..., 0, 1] *= torch.sin(q1) + X1[..., 1, 0] *= -torch.sin(q1) + X1[..., 1, 1] *= torch.cos(q1) - P2[..., 0, 0] = torch.cos(q2) - P2[..., 0, 1] = torch.sin(q2) - P2[..., 1, 0] = -torch.sin(q2) - P2[..., 1, 1] = torch.cos(q2) + X2[..., 0, 0] *= torch.cos(q2) + X2[..., 0, 1] *= torch.sin(q2) + X2[..., 1, 0] *= -torch.sin(q2) + X2[..., 1, 1] *= torch.cos(q2) if polarisation == "circular": - P1[..., 0, 0] = torch.exp(1j * q1) - P1[..., 0, 1] = 0 - P1[..., 1, 0] = 0 - P1[..., 1, 1] = torch.exp(-1j * q1) + X1[..., 0, 0] *= torch.exp(1j * q1) + X1[..., 0, 1] *= 0 + X1[..., 1, 0] *= 0 + X1[..., 1, 1] *= torch.exp(-1j * q1) - P2[..., 0, 0] = torch.exp(1j * q2) - P2[..., 0, 1] = 0 - P2[..., 1, 0] = 0 - P2[..., 1, 1] = torch.exp(-1j * q2) + X2[..., 0, 0] *= torch.exp(1j * q2) + X2[..., 0, 1] *= 0 + X2[..., 1, 0] *= 0 + X2[..., 1, 1] *= torch.exp(-1j * q2) - return img * P1, img * P2 + return X1, X2 @torch.compile From cb1c87de0b1cae241dcce70b41570c4a896dde50 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 21 Jan 2025 14:46:46 +0100 Subject: [PATCH 46/52] Fix batch size, tests --- pyvisgen/simulation/visibility.py | 11 +++++++++-- tests/test_simulation.py | 8 ++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index dd42ec4..04cd6c7 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -4,7 +4,6 @@ import scipy.ndimage import torch import toma -import scipy.ndimage import pyvisgen.simulation.scan as scan @@ -333,7 +332,7 @@ def vis_loop( num_threads: int = 10, noisy: bool = True, mode: str = "full", - batch_size: int = 100, + batch_size: int = "auto", show_progress: bool = False, ) -> Visibilities: r"""Computes the visibilities of an observation. @@ -376,6 +375,12 @@ def vis_loop( torch.set_num_threads(num_threads) torch._dynamo.config.suppress_errors = True + if not ( + isinstance(batch_size, int) + or (isinstance(batch_size, str) and batch_size == "auto") + ): + raise ValueError("Expected batch_size to be 'auto' or type int") + pol = Polarisation( torch.flip(SI, dims=[1]), sensitivity_cut=obs.sensitivity_cut, @@ -409,7 +414,9 @@ def vis_loop( torch.tensor([]), torch.tensor([]), ) + vis_num = torch.zeros(1) + if mode == "full": bas = obs.baselines.get_valid_subset(obs.num_baselines, obs.device) elif mode == "grid": diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 52d54fc..c384951 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -100,10 +100,10 @@ def test_vis_loop_batch_size_auto(): batch_size="auto", ) - assert (vis_data[0].SI[0]).dtype == torch.complex128 - assert (vis_data[0].SQ[0]).dtype == torch.complex128 - assert (vis_data[0].SU[0]).dtype == torch.complex128 - assert (vis_data[0].SV[0]).dtype == torch.complex128 + assert (vis_data[0].V_11[0]).dtype == torch.complex128 + assert (vis_data[0].V_22[0]).dtype == torch.complex128 + assert (vis_data[0].V_12[0]).dtype == torch.complex128 + assert (vis_data[0].V_21[0]).dtype == torch.complex128 assert (vis_data[0].num).dtype == torch.float32 assert (vis_data[0].base_num).dtype == torch.float64 assert torch.is_tensor(vis_data[0].u) From 4ac6f0b35de26f6628d45608c19218a9a6a0bfaa Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Wed, 22 Jan 2025 10:32:33 +0100 Subject: [PATCH 47/52] Add comments --- pyvisgen/simulation/observation.py | 40 +++++++++++++++--------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index e49d45a..7f75cea 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -436,15 +436,15 @@ def calc_baselines(self): the Baselines dataclass. """ self.baselines = Baselines( - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), + torch.tensor([]), # st1 + torch.tensor([]), # st2 + torch.tensor([]), # u + torch.tensor([]), # v + torch.tensor([]), # w + torch.tensor([]), # valid + torch.tensor([]), # time + torch.tensor([]), # q1 + torch.tensor([]), # q2 ) if self.show_progress: @@ -475,22 +475,22 @@ def get_baselines(self, times): delta_x, delta_y, delta_z = ar.calc_relative_pos st_num_pairs, els_low_pairs, els_high_pairs = ar.calc_ant_pair_vals - # Loop over ha and el_st baselines = Baselines( - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), - torch.tensor([]), + torch.tensor([]), # st1 + torch.tensor([]), # st2 + torch.tensor([]), # u + torch.tensor([]), # v + torch.tensor([]), # w + torch.tensor([]), # valid + torch.tensor([]), # time + torch.tensor([]), # q1 + torch.tensor([]), # q2 ) q_all = self.calc_feed_rotation(ha_local) q_comb = torch.vstack([torch.combinations(qi) for qi in q_all]) q_comb = q_comb.reshape(-1, int(q_comb.shape[0] / times.shape[0]), 2) + # Loop over ha, el_st, times, parallactic angles for ha, el_st, time, q, qc in zip(GHA, el_st_all, times, q_all, q_comb): u, v, w = self.calc_direction_cosines(ha, el_st, delta_x, delta_y, delta_z) @@ -583,7 +583,7 @@ def calc_feed_rotation(self, ha: Angle) -> Angle: .. math:: - q = \frac{\sin h}{\cos\delta \tan\varphi - \sin\delta \cos h, + q = \atan\left(\frac{\sin h}{\cos\delta \tan\varphi - \sin\delta \cos h\right), where $h$ is the local hour angle, $\varphi$ the geographical latitude of the observer, and $\delta$ the declination of the source. From 8c8dc842fa9bc61040c2b803d556a7681afd8ede Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 24 Jan 2025 14:57:56 +0100 Subject: [PATCH 48/52] Use tqdm.autonotebook, fix tensor --- pyvisgen/simulation/observation.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 7f75cea..2f3f673 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -8,7 +8,7 @@ from astropy.constants import c from astropy.coordinates import AltAz, Angle, EarthLocation, SkyCoord, Longitude from astropy.time import Time -from tqdm import tqdm +from tqdm.autonotebook import tqdm from pyvisgen.layouts import layouts from pyvisgen.simulation.array import Array @@ -447,8 +447,11 @@ def calc_baselines(self): torch.tensor([]), # q2 ) - if self.show_progress: - self.scans = tqdm(self.scans) + self.scans = tqdm( + self.scans, + disable=not self.show_progress, + desc="Computing scans", + ) for scan in self.scans: bas = self.get_baselines(self.times[scan]) @@ -569,7 +572,7 @@ def calc_ref_elev(self, time=None): return ( torch.tensor(GHA.deg), - torch.tensor(ha_local), + ha_local, torch.tensor(el_st_all.alt.degree), ) From a97fa5801e54f3de43c54e53512e2699dcafd1b6 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 24 Jan 2025 15:00:14 +0100 Subject: [PATCH 49/52] Handle mode='dense' for parallactic angle computation --- pyvisgen/simulation/scan.py | 5 ++++- pyvisgen/simulation/visibility.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pyvisgen/simulation/scan.py b/pyvisgen/simulation/scan.py index 2c7dff4..a4a7129 100644 --- a/pyvisgen/simulation/scan.py +++ b/pyvisgen/simulation/scan.py @@ -17,6 +17,7 @@ def rime( spw_low, spw_high, polarisation, + mode, corrupted=False, ): """Calculates visibilities using RIME @@ -44,10 +45,12 @@ def rime( with torch.no_grad(): X1, X2 = calc_fourier(img, bas, lm, spw_low, spw_high) + if mode != "dense": + X1, X2 = calc_feed_rotation(X1, X2, bas, polarisation) + if corrupted: X1, X2 = calc_beam(X1, X2, rd, ra, dec, ant_diam, spw_low, spw_high) - X1, X2 = calc_feed_rotation(X1, X2, bas, polarisation) vis = integrate(X1, X2) return vis diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 04cd6c7..9d7dc52 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -446,6 +446,7 @@ def vis_loop( rd, noisy, show_progress, + mode, ) visibilities.linear_dop = lin_dop.cpu() @@ -465,6 +466,7 @@ def _batch_loop( rd: torch.tensor, noisy: bool | float, show_progress: bool, + mode: str, ): """Main simulation loop of pyvisgen. Computes visibilities batchwise. @@ -523,6 +525,7 @@ def _batch_loop( wave_low, wave_high, obs.polarisation, + mode=mode, corrupted=obs.corrupted, )[None] for wave_low, wave_high in zip(obs.waves_low, obs.waves_high) From 08d08cb7518d04c46cee79a8945f54b95fdea567 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 24 Jan 2025 15:14:48 +0100 Subject: [PATCH 50/52] Update changelogs --- docs/changes/39.bugfix.rst | 6 +++--- docs/changes/39.feature.rst | 9 +++++---- docs/changes/39.maintenance.rst | 8 ++++---- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/docs/changes/39.bugfix.rst b/docs/changes/39.bugfix.rst index 67473db..4c3f16d 100644 --- a/docs/changes/39.bugfix.rst +++ b/docs/changes/39.bugfix.rst @@ -1,4 +1,4 @@ -- Fix gridding in `pyvisgen.simulation.observation.Observation` methods `create_rd_grid` and `create_lm_grid` +- Fixed gridding in `pyvisgen.simulation.observation.Observation` methods `create_rd_grid` and `create_lm_grid` methods resulting in rotated images -- Fix `pyvisgen.simulation.observation.ValidBaselineSubset` dataclass field order -- Fix test failing because of api change +- Fixed `pyvisgen.simulation.observation.ValidBaselineSubset` dataclass field order +- Fixed tests failing because of api change diff --git a/docs/changes/39.feature.rst b/docs/changes/39.feature.rst index d2cda72..dcce261 100644 --- a/docs/changes/39.feature.rst +++ b/docs/changes/39.feature.rst @@ -1,4 +1,5 @@ -- Add class `Polarisation` to `pyvisgen.simulation.visibility` that is called in `vis_loop` - - Add linear, circular, and no polarisation options -- Update `pyvisgen.simulation.visibility.Visibilities` dataclass to also store polarisation degree tensors -- Add keyword arguments for polarisation simulation to `pyvisgen.simulation.observation.Observation` class +- Added class `Polarisation` to `pyvisgen.simulation.visibility` that is called in `vis_loop` + - Added linear, circular, and no polarisation options +- Updated `pyvisgen.simulation.visibility.Visibilities` dataclass to also store polarisation degree tensors +- Added keyword arguments for polarisation simulation to `pyvisgen.simulation.observation.Observation` class +- Added parallactic angle computation diff --git a/docs/changes/39.maintenance.rst b/docs/changes/39.maintenance.rst index 83a2365..16beef1 100644 --- a/docs/changes/39.maintenance.rst +++ b/docs/changes/39.maintenance.rst @@ -1,6 +1,6 @@ -- Change pyvisgen.simulation.visibility.Visibilities dataclass component names from stokes components (I , Q, U, and V) +- Changed pyvisgen.simulation.visibility.Visibilities dataclass component names from stokes components (I , Q, U, and V) to visibilities constructed from the stokes components (`V_11`, `V_22`, `V_12`, `V_21`) -- Change indices for stokes components according to AIPS Memo 114 +- Changed indices for stokes components according to AIPS Memo 114 - Indices will be set automatically depending on simulated polarisation -- Update comment strings in FITS files -- Update docstrings accordingly in `pyvisgen.simulation.visibility.vis_loop` and `pyvisgen.simulation.observation.Observation` +- Updated comment strings in FITS files +- Updated docstrings accordingly in `pyvisgen.simulation.visibility.vis_loop` and `pyvisgen.simulation.observation.Observation` From 7a6dace3052783a95e2cf4f65890d19c1a91e5a4 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 24 Jan 2025 21:45:22 +0100 Subject: [PATCH 51/52] Fix tests --- pyvisgen/simulation/visibility.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 2aa9fd7..fb15eb0 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -1,9 +1,9 @@ from dataclasses import dataclass, fields -from tqdm.autonotebook import tqdm import scipy.ndimage -import torch import toma +import torch +from tqdm.autonotebook import tqdm import pyvisgen.simulation.scan as scan @@ -334,6 +334,7 @@ def vis_loop( mode: str = "full", batch_size: int = "auto", show_progress: bool = False, + normalize: bool = True, ) -> Visibilities: r"""Computes the visibilities of an observation. @@ -366,6 +367,9 @@ def vis_loop( show_progress : bool, optional If `True`, show a progress bar during the iteration over the batches of baselines. Default: False + normalize : bool, optional + If ``True``, normalize stokes matrix ``B`` by a factor 0.5. + Default: ``True`` Returns ------- From 2921c8d8e0113ac929756d4421fdc489ee9e387c Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Mon, 27 Jan 2025 09:48:19 +0100 Subject: [PATCH 52/52] Change changelog to present tense --- docs/changes/39.bugfix.rst | 6 +++--- docs/changes/39.feature.rst | 8 ++++---- docs/changes/39.maintenance.rst | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/changes/39.bugfix.rst b/docs/changes/39.bugfix.rst index 4c3f16d..cc5a43d 100644 --- a/docs/changes/39.bugfix.rst +++ b/docs/changes/39.bugfix.rst @@ -1,4 +1,4 @@ -- Fixed gridding in `pyvisgen.simulation.observation.Observation` methods `create_rd_grid` and `create_lm_grid` +- Fix gridding in `pyvisgen.simulation.observation.Observation` methods `create_rd_grid` and `create_lm_grid` methods resulting in rotated images -- Fixed `pyvisgen.simulation.observation.ValidBaselineSubset` dataclass field order -- Fixed tests failing because of api change +- Fix `pyvisgen.simulation.observation.ValidBaselineSubset` dataclass field order +- Fix tests failing because of api change diff --git a/docs/changes/39.feature.rst b/docs/changes/39.feature.rst index dcce261..d9d459a 100644 --- a/docs/changes/39.feature.rst +++ b/docs/changes/39.feature.rst @@ -1,5 +1,5 @@ -- Added class `Polarisation` to `pyvisgen.simulation.visibility` that is called in `vis_loop` +- Add class `Polarisation` to `pyvisgen.simulation.visibility` that is called in `vis_loop` - Added linear, circular, and no polarisation options -- Updated `pyvisgen.simulation.visibility.Visibilities` dataclass to also store polarisation degree tensors -- Added keyword arguments for polarisation simulation to `pyvisgen.simulation.observation.Observation` class -- Added parallactic angle computation +- Update `pyvisgen.simulation.visibility.Visibilities` dataclass to also store polarisation degree tensors +- Add keyword arguments for polarisation simulation to `pyvisgen.simulation.observation.Observation` class +- Add parallactic angle computation diff --git a/docs/changes/39.maintenance.rst b/docs/changes/39.maintenance.rst index 16beef1..83a2365 100644 --- a/docs/changes/39.maintenance.rst +++ b/docs/changes/39.maintenance.rst @@ -1,6 +1,6 @@ -- Changed pyvisgen.simulation.visibility.Visibilities dataclass component names from stokes components (I , Q, U, and V) +- Change pyvisgen.simulation.visibility.Visibilities dataclass component names from stokes components (I , Q, U, and V) to visibilities constructed from the stokes components (`V_11`, `V_22`, `V_12`, `V_21`) -- Changed indices for stokes components according to AIPS Memo 114 +- Change indices for stokes components according to AIPS Memo 114 - Indices will be set automatically depending on simulated polarisation -- Updated comment strings in FITS files -- Updated docstrings accordingly in `pyvisgen.simulation.visibility.vis_loop` and `pyvisgen.simulation.observation.Observation` +- Update comment strings in FITS files +- Update docstrings accordingly in `pyvisgen.simulation.visibility.vis_loop` and `pyvisgen.simulation.observation.Observation`