From be6d2edec55b5b0cc989b504f0a046d959da8254 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 13 Aug 2024 15:47:40 +0200 Subject: [PATCH 01/15] Add first implementation of polarisation simulation in the vis loop --- pyvisgen/simulation/visibility.py | 383 ++++++++++++++++++++++++++++-- 1 file changed, 358 insertions(+), 25 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 09f6896..e817857 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -1,7 +1,8 @@ from dataclasses import dataclass, fields +from tqdm import tqdm import torch -from tqdm import tqdm +import scipy.ndimage import pyvisgen.simulation.scan as scan @@ -18,6 +19,8 @@ class Visibilities: v: torch.tensor w: torch.tensor date: torch.tensor + linear_dop: torch.tensor + circular_dop: torch.tensor def __getitem__(self, i): return Visibilities(*[getattr(self, f.name)[i] for f in fields(self)]) @@ -38,35 +41,359 @@ def add(self, visibilities): ] +class Polarisation: + """Simulation of polarisation.""" + + def __init__( + self, + SI: torch.tensor, + sensitivity_cut: float, + amp_ratio: float, + delta: float, + polarisation: str, + random_state: int, + device: str, + ) -> None: + """Creates the 2 x 2 stokes matrix and simulates + polarisation if `polarisation` is either 'linear' + or 'circular'. Also computes the degree of polarisation. + + Parameters + ---------- + SI : torch.tensor + Stokes I component, i.e. intensity distribution + of the sky. + sensitivity_cut : float + Sensitivity cut, where only pixels above the value + are kept. + amp_ratio : float + Sets the ratio of $A_{x/r}$. The ratio of $A_{y/l}$ is calculated + as `1 - amp_ratio`. If set to `None`, a random value is drawn + from a uniform distribution. See also: `random_state`. + delta : float + Sets the phase difference of the amplitudes $A_x$ and $A_y$ + of the sky distribution. Defines the measure of ellipticity. + polarisation : str + Choose between `'linear'` or `'circular'` or `None` to + simulate different types of polarisations or disable + the simulation of polarisation. + random_state : int + Random state used when drawing `amp_ratio` and during the generation + of the random polarisation field. + device : str + Torch device to select for computation. + """ + self.sensitivity_cut = sensitivity_cut + self.polarisation = polarisation + self.device = device + + self.SI = SI.permute(dims=(1, 2, 0)) + + if random_state: + torch.manual_seed(random_state) + else: + torch.seed() + + if self.polarisation: + self.polarisation_field = self.rand_polarisation_field( + [SI.shape[0], SI.shape[1]], + random_state=random_state, + order=[1, 1], + scale=[0, 1], + threshold=None, + ) + + self.delta = delta + + if amp_ratio and amp_ratio >= 0: + ax2 = amp_ratio + else: + ax2 = torch.rand(1) + + ay2 = 1 - ax2 + + self.ax2 = self.SI[..., 0] * ax2 + self.ay2 = self.SI[..., 0] * ay2 + + self.I = torch.zeros( + (self.SI.shape[0], self.SI.shape[1], 4), dtype=torch.cdouble + ) # noqa: E741 + + def linear(self) -> None: + """Computes the stokes parameters I, Q, U, and V + for linear polarisation. + + .. math:: + I = A_x^2 + A_y^2 + Q = A_r^2 - A_l^2 + U = 2A_x A_y \cos\delta_{xy} + V = -2A_x A_y \sin\delta_{xy} + """ + self.I[..., 0] = self.ax2 + self.ay2 + self.I[..., 1] = self.ax2 - self.ay2 + self.I[..., 2] = ( + 2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.cos(torch.deg2rad(torch.tensor(self.delta))) + ) + self.I[..., 3] = ( + -2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.sin(torch.deg2rad(torch.tensor(self.delta))) + ) + + def circular(self) -> None: + """Computes the stokes parameters I, Q, U, and V + for circular polarisation. + + .. math:: + I = A_r^2 + A_l^2 + Q = 2A_r A_l \cos\delta_{rl} + U = -2A_r A_l \sin\delta_{rl} + V = A_r^2 - A_l^2 + """ + self.I[..., 0] = self.ax2 + self.ay2 + self.I[..., 1] = ( + 2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.cos(torch.deg2rad(torch.tensor(self.delta))) + ) + self.I[..., 2] = ( + -2 + * torch.sqrt(self.ax2) + * torch.sqrt(self.ay2) + * torch.sin(torch.deg2rad(torch.tensor(self.delta))) + ) + self.I[..., 3] = self.ax2 - self.ay2 + + def dop(self) -> None: + """Computes the degree of polarisation for each pixel.""" + mask = (self.ax2 + self.ay2) > 0 + + # apply polarisation_field to Q, U, and V only + self.I[..., 1] *= self.polarisation_field + self.I[..., 2] *= self.polarisation_field + self.I[..., 3] *= self.polarisation_field + + dop_I = self.I[..., 0].clone() + dop_I[~mask] = float("nan") + dop_Q = self.I[..., 1].clone() + dop_Q[~mask] = float("nan") + dop_U = self.I[..., 2].clone() + dop_U[~mask] = float("nan") + dop_V = self.I[..., 3].clone() + dop_V[~mask] = float("nan") + + self.lin_dop = torch.sqrt(dop_Q**2 + dop_U**2) / dop_I + self.circ_dop = torch.abs(dop_V) / dop_I + + del dop_I, dop_Q, dop_U, dop_V + + def stokes_matrix(self) -> tuple: + """Computes and returns the 2 x 2 stokes matrix B. + + Returns + ------- + B : torch.tensor + 2 x 2 stokes brightness matrix. Either for linear, + circular or no polarisation. + mask : torch.tensor + Mask of the sensitivity cut (Keep all px > sensitivity_cut). + lin_dop : torch.tensor + Degree of linear polarisation of every pixel in the sky. + circ_dop : torch.tensor + Degree of circular polarisation of every pixel in the sky. + """ + # define 2 x 2 Stokes matrix + B = torch.zeros( + (self.SI.shape[0], self.SI.shape[1], 2, 2), dtype=torch.cdouble + ).to(torch.device(self.device)) + + if self.polarisation == "linear": + self.linear() + self.dop() + + B[..., 0, 0] = self.I[..., 0] + self.I[..., 3] + B[..., 0, 1] = self.I[..., 1] + 1j * self.I[..., 2] + B[..., 1, 0] = self.I[..., 1] - 1j * self.I[..., 2] + B[..., 1, 1] = self.I[..., 0] - self.I[..., 3] + + elif self.polarisation == "circular": + self.circular() + self.dop() + + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + + else: + # No polarisation applied + self.I[..., 0] = self.SI[..., 0] + self.I[..., 1] = self.SI[..., 0] + self.I[..., 2] = self.SI[..., 0] + self.I[..., 3] = self.SI[..., 0] + + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + + # calculations only for px > sensitivity cut + mask = (self.SI >= self.sensitivity_cut)[..., 0] + B = B[mask] + + return B, mask, self.lin_dop, self.circ_dop + + def rand_polarisation_field( + self, + shape: list[int, int] | int, + order: list[int, int] | int = 1, + random_state: int = None, + scale: list = [0, 1], + threshold: float = None, + ) -> torch.tensor: + """ + Generates a random noise mask for polarisation. + + Parameters + ---------- + shape : array_like (M, N), or int + The size of the sky image. + order : array_like (M, N) or int, optional + Morphology of the random noise. Higher values create + more and smaller fluctuations. Default: 1. + random_state : int, optional + Random state for the random number generator. If None, + a random entropy is pulled from the OS. Default: None. + scale : array_like, optional + Scaling of the distribution of the image. Default: [0, 1] + threshold : float, optional + If not None, an upper threshold is applied to the image. + Default: None + + Returns + ------- + im : torch.tensor + An array containing random noise values between + scale[0] and scale[1]. + """ + if random_state: + torch.random.manual_seed(random_state) + + if not isinstance(shape, list): + shape = list(shape) + + if len(shape) < 2: + shape *= 2 + elif len(shape) > 2: + raise ValueError("Only 2d shapes are allowed!") + + if not isinstance(order, list): + order = list(order) + + if len(order) < 2: + order *= 2 + elif len(order) > 2: + raise ValueError("Only 2d shapes are allowed!") + + sigma = torch.mean(torch.tensor(shape).double()) / (40 * torch.tensor(order)) + + im = torch.rand(shape) + im = scipy.ndimage.gaussian_filter(im, sigma=sigma.numpy()) + + if scale is None: + scale = [im.min(), im.max()] + + im_flatten = torch.from_numpy(im.flatten()) + im_argsort = torch.argsort(torch.argsort(im_flatten)) + im_linspace = torch.linspace(*scale, im_argsort.size()[0]) + uniform_flatten = im_linspace[im_argsort] + + im = torch.reshape(uniform_flatten, im.shape) + + if threshold: + im = im < threshold + + return im + + def vis_loop( - obs, - SI, - num_threads=10, - noisy=True, - mode="full", - batch_size=100, - show_progress=False, -): + obs: "Observation", + SI: torch.tensor, + num_threads: int = 10, + noisy: bool = True, + mode: str = "full", + batch_size: int = 100, + polarisation: str = "linear", + delta: float = 0, + amp_ratio: float = None, + random_state: int = 42, + show_progress: bool = False, +) -> Visibilities: + r"""Computes the visibilities of an observation. + + Parameters + ---------- + obs : Observation class object + Observation class object generated by the + `~pyvisgen.simulation.Observation` class. + SI : torch.tensor + Tensor containing the sky intensity distribution. + num_threads : int, optional + Number of threads used for intraoperative parallelism + on the CPU. See `~torch.set_num_threads`. Default: 10 + noisy : bool, optional + If `True`, generate and add additional noise to + the simulated measurements. Default: True + mode : str, optional + Select one of `'full'`, `'grid'`, or `'dense'` to get + all valid baselines, a grid of unique baselines, or + dense baselines. Default: 'full' + batch_size : int, optional + Batch size for iteration over baselines. Default: 100 + polarisation : str, optional + Choose between `'linear'` or `'circular'` or `None` to + simulate different types of polarisations or disable + the simulation of polarisation. Default: 'linear' + delta : float, optional + Sets the phase difference of the amplitudes $A_x$ and $A_y$ + of the sky distribution. Defines the measure of ellipticity. + Default: 0 + amp_ratio : float, optional + Sets the ratio of $A_{x/r}$. The ratio of $A_{y/l}$ is calculated + as `1 - amp_ratio`. If set to `None`, a random value is drawn + from a uniform distribution. See also: `random_state`. Default: None + random_state : int, optional + Random state used when drawing `amp_ratio` and during the generation + of the random polarisation field. Default: 42 + show_progress : bool, optional + If `True`, show a progress bar during the iteration over the + batches of baselines. Default: False + + Returns + ------- + visibilities : Visibilities + Dataclass object containing visibilities and baselines. + """ torch.set_num_threads(num_threads) torch._dynamo.config.suppress_errors = True - # define unpolarized sky distribution - SI = SI.permute(dims=(1, 2, 0)) - I = torch.zeros((SI.shape[0], SI.shape[1], 4), dtype=torch.cdouble) - I[..., 0] = SI[..., 0] - - # define 2 x 2 Stokes matrix ((I + Q, iU + V), (iU -V, I - Q)) - B = torch.zeros((SI.shape[0], SI.shape[1], 2, 2), dtype=torch.cdouble).to( - torch.device(obs.device) + pol = Polarisation( + SI, + sensitivity_cut=obs.sensitivity_cut, + amp_ratio=amp_ratio, + delta=delta, + polarisation=polarisation, + random_state=random_state, + device=obs.device, ) - B[:, :, 0, 0] = I[:, :, 0] + I[:, :, 1] - B[:, :, 0, 1] = I[:, :, 2] + 1j * I[:, :, 3] - B[:, :, 1, 0] = I[:, :, 2] - 1j * I[:, :, 3] - B[:, :, 1, 1] = I[:, :, 0] - I[:, :, 1] - - # calculations only for px > sensitivity cut - mask = (SI >= obs.sensitivity_cut)[..., 0] - B = B[mask] + + B, mask, lin_dop, circ_dop = pol.stokes_matrix() + lm = obs.lm[mask] rd = obs.rd[mask] @@ -86,6 +413,8 @@ def vis_loop( torch.tensor([]), torch.tensor([]), torch.tensor([]), + torch.tensor([]), + torch.tensor([]), ) vis_num = torch.zeros(1) if mode == "full": @@ -149,7 +478,11 @@ def vis_loop( bas_p[5].cpu(), bas_p[8].cpu(), bas_p[10].cpu(), + torch.tensor([]), + torch.tensor([]), ) + visibilities.linear_dop = lin_dop.cpu() + visibilities.circular_dop = circ_dop.cpu() visibilities.add(vis) del int_values From 1f456e700ca79d72382ac6157324e07f45f40541 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Wed, 14 Aug 2024 11:09:19 +0200 Subject: [PATCH 02/15] Move arguments for polarisation to Observation class --- pyvisgen/simulation/observation.py | 122 ++++++++++++++++++++++++----- pyvisgen/simulation/visibility.py | 32 ++------ 2 files changed, 110 insertions(+), 44 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index 814aad4..f7e4e27 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, fields +from datetime import datetime from math import pi import astropy.constants as const @@ -155,24 +156,104 @@ def _lexsort(self, a, dim=-1): class Observation: def __init__( self, - src_ra, - src_dec, - start_time, - scan_duration, - num_scans, - scan_separation, - integration_time, - ref_frequency, - frequency_offsets, - bandwidths, - fov, - image_size, - array_layout, - corrupted, - device, - dense=False, - sensitivity_cut=1e-6, - ): + src_ra: float, + src_dec: float, + start_time: datetime, + scan_duration: int, + num_scans: int, + scan_separation: int, + integration_time: int, + ref_frequency: float, + frequency_offsets: list, + bandwidths: list, + fov: float, + image_size: int, + array_layout: str, + corrupted: bool, + device: str, + dense: bool = False, + sensitivity_cut: float = 1e-6, + polarisation: str = None, + pol_kwargs: dict = { + "delta": 0, + "amp_ratio": 0.5, + "random_state": 42, + }, + field_kwargs: dict = { + "order": [1, 1], + "scale": [0, 1], + "threshold": None, + "random_state": 42, + }, + ) -> None: + """Sets up the observation class. + + Parameters + ---------- + src_ra : float + Source right ascension coordinate. + src_dec : float + Source declination coordinate. + start_time : datetime + Observation start time. + scan_duration : int + Scan duration. + num_scans : int + Number of scans. + scan_separation : int + Scan separation. + integration_time : int + Integration time. + ref_frequency : float + Reference frequency. + frequency_offsets : list + Frequency offsets. + bandwidths : list + Frequency bandwidth. + fov : float + Field of view. + image_size : int + Image size of the sky distribution. + array_layout : str + Name of an existing array layout. See `~pyvisgen.layouts`. + corrupted : bool + If `True`, apply corruption during the vis loop. + device : str + Torch device to select for computation. + dense : bool, optional + If `True`, apply dense baseline calculation of a perfect + interferometer. Default: `False` + sensitivity_cut : float, optional + Sensitivity threshold, where only pixels above the value + are kept. Default: 1e-6 + polarisation : str, optional + Choose between `'linear'` or `'circular'` or `None` to + simulate different types of polarisations or disable + the simulation of polarisation. Default: `None` + pol_kwargs : dict, optional + Additional keyword arguments for the simulation + of polarisation. Default: `{ + "delta": 0, + "amp_ratio": 0.5, + "random_state": 42, + } + field_kwargs : dict, optional + Additional keyword arguments for the random polarisation + field that is applied when simulating polarisation. + Default: `{ + "order": [1, 1], + "scale": [0, 1], + "threshold": None, + "random_state": 42 + }` + + Notes + ----- + See `~pyvisgen.simulation.visibility.Polarisation` and + `~pyvisgen.simulation.visibility.Polarisation.rand_polarisation_field` + for more information on the keyword arguments in `pol_kwargs` + and `field_kwargs`, respectively. + """ self.ra = torch.tensor(src_ra).double() self.dec = torch.tensor(src_dec).double() @@ -229,6 +310,11 @@ def __init__( self.rd = self.create_rd_grid() self.lm = self.create_lm_grid() + # polarisation + self.polarisation = polarisation + self.pol_kwargs = pol_kwargs + self.field_kwargs = field_kwargs + def calc_dense_baselines(self): N = self.img_size fov = self.fov * pi / (3600 * 180) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index e817857..c4a635b 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -51,6 +51,7 @@ def __init__( amp_ratio: float, delta: float, polarisation: str, + field_kwargs: dict, random_state: int, device: str, ) -> None: @@ -91,21 +92,16 @@ def __init__( if random_state: torch.manual_seed(random_state) - else: - torch.seed() if self.polarisation: self.polarisation_field = self.rand_polarisation_field( [SI.shape[0], SI.shape[1]], - random_state=random_state, - order=[1, 1], - scale=[0, 1], - threshold=None, + **field_kwargs, ) self.delta = delta - if amp_ratio and amp_ratio >= 0: + if amp_ratio and (amp_ratio >= 0): ax2 = amp_ratio else: ax2 = torch.rand(1) @@ -233,9 +229,6 @@ def stokes_matrix(self) -> tuple: else: # No polarisation applied self.I[..., 0] = self.SI[..., 0] - self.I[..., 1] = self.SI[..., 0] - self.I[..., 2] = self.SI[..., 0] - self.I[..., 3] = self.SI[..., 0] B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] @@ -328,10 +321,6 @@ def vis_loop( noisy: bool = True, mode: str = "full", batch_size: int = 100, - polarisation: str = "linear", - delta: float = 0, - amp_ratio: float = None, - random_state: int = 42, show_progress: bool = False, ) -> Visibilities: r"""Computes the visibilities of an observation. @@ -359,14 +348,6 @@ def vis_loop( Choose between `'linear'` or `'circular'` or `None` to simulate different types of polarisations or disable the simulation of polarisation. Default: 'linear' - delta : float, optional - Sets the phase difference of the amplitudes $A_x$ and $A_y$ - of the sky distribution. Defines the measure of ellipticity. - Default: 0 - amp_ratio : float, optional - Sets the ratio of $A_{x/r}$. The ratio of $A_{y/l}$ is calculated - as `1 - amp_ratio`. If set to `None`, a random value is drawn - from a uniform distribution. See also: `random_state`. Default: None random_state : int, optional Random state used when drawing `amp_ratio` and during the generation of the random polarisation field. Default: 42 @@ -385,11 +366,10 @@ def vis_loop( pol = Polarisation( SI, sensitivity_cut=obs.sensitivity_cut, - amp_ratio=amp_ratio, - delta=delta, - polarisation=polarisation, - random_state=random_state, + polarisation=obs.polarisation, device=obs.device, + field_kwargs=obs.field_kwargs, + **obs.pol_kwargs, ) B, mask, lin_dop, circ_dop = pol.stokes_matrix() From 15c3434066e230a3afb5f68cf38d237facd53a52 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Wed, 14 Aug 2024 13:39:14 +0200 Subject: [PATCH 03/15] Catch case where order arg is int in Polarisation class --- pyvisgen/simulation/visibility.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index c4a635b..300445f 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -285,6 +285,9 @@ def rand_polarisation_field( elif len(shape) > 2: raise ValueError("Only 2d shapes are allowed!") + if isinstance(order, int): + order = [order] + if not isinstance(order, list): order = list(order) From 12cd0578e3c1cf688e4b0a717783d94235bcc60a Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Mon, 9 Sep 2024 14:35:27 +0200 Subject: [PATCH 04/15] Update FITS header writer - Add numeric codes for stokes parameters - Add header comments for stokes parameters - Add comment for visibilities (real, complex, weight) --- pyvisgen/fits/writer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pyvisgen/fits/writer.py b/pyvisgen/fits/writer.py index f17f97d..11cdcd5 100644 --- a/pyvisgen/fits/writer.py +++ b/pyvisgen/fits/writer.py @@ -42,9 +42,17 @@ def create_vis_hdu(data, obs, layout="vlba", source_name="sim-source-0"): freq_d = obs.bandwidths[0].cpu().numpy().item() ws = wcs.WCS(naxis=7) + + crval_stokes = -1 + stokes_comment = "-1=RR, -2=LL, -3=RL, -4=LR" + if obs.polarisation == "linear": + crval_stokes = -5 + stokes_comment = "-5=XX, -6=YY, -7=XY, -8=YX" + stokes_comment += " or -5=VV, -6=HH, -7=VH, -8=HV" + ws.wcs.crpix = [1, 1, 1, 1, 1, 1, 1] ws.wcs.cdelt = np.array([1, 1, -1, freq_d, 1, 1, 1]) - ws.wcs.crval = [1, 1, -1, freq, 1, ra, dec] + ws.wcs.crval = [1, 1, crval_stokes, freq, 1, ra, dec] ws.wcs.ctype = ["", "COMPLEX", "STOKES", "FREQ", "IF", "RA", "DEC"] h = ws.to_header() @@ -79,6 +87,8 @@ def create_vis_hdu(data, obs, layout="vlba", source_name="sim-source-0"): hdu_vis.header["PZERO" + str(i + 1)] = parbzeros[i] # add comments + hdu_vis.header.comments["CTYPE2"] = "1=real, 2=imag, 3=weight" + hdu_vis.header.comments["CTYPE3"] = stokes_comment hdu_vis.header.comments["PTYPE1"] = "u baseline coordinate in light seconds" hdu_vis.header.comments["PTYPE2"] = "v baseline coordinate in light seconds" hdu_vis.header.comments["PTYPE3"] = "w baseline coordinate in light seconds" From 56dda4b7d53eae7c2040b84018776fab2c537bf7 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Mon, 9 Sep 2024 14:38:01 +0200 Subject: [PATCH 05/15] Fix observation gridding and Baseline dataclass - Fixes image rotation due to rd grid indexing - Fixes Baselines dataclass attribute sequence, where `num_baseline` would previously be saved as first element --- pyvisgen/simulation/observation.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/pyvisgen/simulation/observation.py b/pyvisgen/simulation/observation.py index f7e4e27..6f74c95 100644 --- a/pyvisgen/simulation/observation.py +++ b/pyvisgen/simulation/observation.py @@ -65,7 +65,6 @@ def get_valid_subset(self, num_baselines, device): date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device) return ValidBaselineSubset( - baseline_nums, u_start, u_stop, u_valid, @@ -75,13 +74,13 @@ def get_valid_subset(self, num_baselines, device): w_start, w_stop, w_valid, + baseline_nums, date, ) @dataclass() class ValidBaselineSubset: - baseline_nums: torch.tensor u_start: torch.tensor u_stop: torch.tensor u_valid: torch.tensor @@ -91,6 +90,7 @@ class ValidBaselineSubset: w_start: torch.tensor w_stop: torch.tensor w_valid: torch.tensor + baseline_nums: torch.tensor date: torch.tensor def __getitem__(self, i): @@ -456,9 +456,10 @@ def create_rd_grid(self): - self.img_size / 2 ) * res + dec - _, R = torch.meshgrid((r, r), indexing="ij") - D, _ = torch.meshgrid((d, d), indexing="ij") + R, _ = torch.meshgrid((r, r), indexing="ij") + _, D = torch.meshgrid((d, d), indexing="ij") rd_grid = torch.cat([R[..., None], D[..., None]], dim=2) + return rd_grid def create_lm_grid(self): @@ -479,11 +480,11 @@ def create_lm_grid(self): dec = torch.deg2rad(self.dec) lm_grid = torch.zeros(self.rd.shape, device=self.device, dtype=torch.float64) - lm_grid[:, :, 0] = (torch.cos(self.rd[..., 1]) * torch.sin(self.rd[..., 0])).T - lm_grid[:, :, 1] = ( - torch.sin(self.rd[..., 1]) * torch.cos(dec) - - torch.cos(self.rd[..., 1]) * torch.sin(dec) * torch.cos(self.rd[..., 0]) - ).T + lm_grid[..., 0] = torch.cos(self.rd[..., 1]) * torch.sin(self.rd[..., 0]) + lm_grid[..., 1] = torch.sin(self.rd[..., 1]) * torch.cos(dec) - torch.cos( + self.rd[..., 1] + ) * torch.sin(dec) * torch.cos(self.rd[..., 0]) + return lm_grid def get_baselines(self, times): From b8917b615ee955de1e9011e5be866f7cbc2cebda Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Mon, 9 Sep 2024 14:42:09 +0200 Subject: [PATCH 06/15] Fix stokes calc, fix sequence of visibilities - Fix calculation of Stokes parameters in Polarisation class - Rename visibilities I -> V_11, Q -> V_12, U -> V_21, V -> V_22. The previous names are confusing, as the quantities that are returned are not the Stokes parameters but calculated from the Stokes parameters (i.e. V_11 = I + V for circular polarization) - Fix visbilities sequence -> V_11, V_22, V_12, V_21 (previously V_11, V_12, V_21, V22), i.e. main diagonal of the Jones matrix first, then subdiagonals --- pyvisgen/simulation/visibility.py | 67 +++++++++++++++++-------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 300445f..34c3bdb 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -9,10 +9,10 @@ @dataclass class Visibilities: - SI: torch.tensor - SQ: torch.tensor - SU: torch.tensor - SV: torch.tensor + V_11: torch.tensor + V_22: torch.tensor + V_12: torch.tensor + V_21: torch.tensor num: torch.tensor base_num: torch.tensor u: torch.tensor @@ -27,7 +27,7 @@ def __getitem__(self, i): def get_values(self): return torch.cat( - [self.SI[None], self.SQ[None], self.SU[None], self.SV[None]], dim=0 + [self.V_11[None], self.V_22[None], self.V_12[None], self.V_21[None]], dim=0 ).permute(1, 2, 0) def add(self, visibilities): @@ -95,7 +95,7 @@ def __init__( if self.polarisation: self.polarisation_field = self.rand_polarisation_field( - [SI.shape[0], SI.shape[1]], + [self.SI.shape[0], self.SI.shape[1]], **field_kwargs, ) @@ -108,8 +108,11 @@ def __init__( ay2 = 1 - ax2 - self.ax2 = self.SI[..., 0] * ax2 - self.ay2 = self.SI[..., 0] * ay2 + self.ax2 = self.SI[..., 0].clone() * ax2 + self.ay2 = self.SI[..., 0].clone() * ay2 + else: + self.ax2 = self.SI[..., 0] + self.ay2 = torch.zeros_like(self.ax2) self.I = torch.zeros( (self.SI.shape[0], self.SI.shape[1], 4), dtype=torch.cdouble @@ -174,13 +177,13 @@ def dop(self) -> None: self.I[..., 2] *= self.polarisation_field self.I[..., 3] *= self.polarisation_field - dop_I = self.I[..., 0].clone() + dop_I = self.I[..., 0].real.clone() dop_I[~mask] = float("nan") - dop_Q = self.I[..., 1].clone() + dop_Q = self.I[..., 1].real.clone() dop_Q[~mask] = float("nan") - dop_U = self.I[..., 2].clone() + dop_U = self.I[..., 2].real.clone() dop_U[~mask] = float("nan") - dop_V = self.I[..., 3].clone() + dop_V = self.I[..., 3].real.clone() dop_V[~mask] = float("nan") self.lin_dop = torch.sqrt(dop_Q**2 + dop_U**2) / dop_I @@ -212,28 +215,30 @@ def stokes_matrix(self) -> tuple: self.linear() self.dop() - B[..., 0, 0] = self.I[..., 0] + self.I[..., 3] - B[..., 0, 1] = self.I[..., 1] + 1j * self.I[..., 2] - B[..., 1, 0] = self.I[..., 1] - 1j * self.I[..., 2] - B[..., 1, 1] = self.I[..., 0] - self.I[..., 3] + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] # I + Q + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] # U + iV + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] # U - iV + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] # I - Q elif self.polarisation == "circular": self.circular() self.dop() - B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] - B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] - B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] - B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + B[..., 0, 0] = self.I[..., 0] + self.I[..., 3] # I + V + B[..., 0, 1] = self.I[..., 1] + 1j * self.I[..., 2] # Q + iU + B[..., 1, 0] = self.I[..., 1] - 1j * self.I[..., 2] # Q - iU + B[..., 1, 1] = self.I[..., 0] - self.I[..., 3] # I - V else: # No polarisation applied self.I[..., 0] = self.SI[..., 0] + self.polarisation_field = torch.ones_like(self.I[..., 0]) + self.dop() - B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] - B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] - B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] - B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] + B[..., 0, 0] = self.I[..., 0] + self.I[..., 1] # I + Q + B[..., 0, 1] = self.I[..., 2] + 1j * self.I[..., 3] # U + iV + B[..., 1, 0] = self.I[..., 2] - 1j * self.I[..., 3] # U - iV + B[..., 1, 1] = self.I[..., 0] - self.I[..., 1] # I - Q # calculations only for px > sensitivity cut mask = (self.SI >= self.sensitivity_cut)[..., 0] @@ -451,10 +456,10 @@ def vis_loop( vis_num = torch.arange(int_values.shape[0]) + 1 + vis_num.max() vis = Visibilities( - int_values[:, :, 0, 0].cpu(), - int_values[:, :, 0, 1].cpu(), - int_values[:, :, 1, 0].cpu(), - int_values[:, :, 1, 1].cpu(), + int_values[..., 0, 0].cpu(), # V_11 + int_values[..., 1, 1].cpu(), # V_22 + int_values[..., 0, 1].cpu(), # V_12 + int_values[..., 1, 0].cpu(), # V_21 vis_num, bas_p[9].cpu(), bas_p[2].cpu(), @@ -464,11 +469,13 @@ def vis_loop( torch.tensor([]), torch.tensor([]), ) - visibilities.linear_dop = lin_dop.cpu() - visibilities.circular_dop = circ_dop.cpu() visibilities.add(vis) del int_values + + visibilities.linear_dop = lin_dop.cpu() + visibilities.circular_dop = circ_dop.cpu() + return visibilities From d12e8737be1ff8be3b297e5e91f24e5f8cd1d05c Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Thu, 12 Sep 2024 17:29:49 +0200 Subject: [PATCH 07/15] Change type hint from str to torch.device for device --- pyvisgen/simulation/visibility.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 34c3bdb..8966f45 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -53,7 +53,7 @@ def __init__( polarisation: str, field_kwargs: dict, random_state: int, - device: str, + device: torch.device, ) -> None: """Creates the 2 x 2 stokes matrix and simulates polarisation if `polarisation` is either 'linear' @@ -81,7 +81,7 @@ def __init__( random_state : int Random state used when drawing `amp_ratio` and during the generation of the random polarisation field. - device : str + device : torch.device Torch device to select for computation. """ self.sensitivity_cut = sensitivity_cut From a619a3a3c9326652c0d7c146cd5768a60644afbb Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 20 Sep 2024 16:53:41 +0200 Subject: [PATCH 08/15] Flip input image before creating stokes matrix --- pyvisgen/simulation/visibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index 8966f45..d0551ca 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -372,7 +372,7 @@ def vis_loop( torch._dynamo.config.suppress_errors = True pol = Polarisation( - SI, + torch.flip(SI, dims=[1]), sensitivity_cut=obs.sensitivity_cut, polarisation=obs.polarisation, device=obs.device, From 66b5d229f1aff8568d0e49404b4862da72d6786c Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 1 Oct 2024 10:04:34 +0200 Subject: [PATCH 09/15] Add changelog --- docs/changes/39.bugfix.rst | 3 +++ docs/changes/39.feature.rst | 4 ++++ docs/changes/39.maintenance.rst | 6 ++++++ 3 files changed, 13 insertions(+) create mode 100644 docs/changes/39.bugfix.rst create mode 100644 docs/changes/39.feature.rst create mode 100644 docs/changes/39.maintenance.rst diff --git a/docs/changes/39.bugfix.rst b/docs/changes/39.bugfix.rst new file mode 100644 index 0000000..9a98182 --- /dev/null +++ b/docs/changes/39.bugfix.rst @@ -0,0 +1,3 @@ +- Fix gridding in `pyvisgen.simulation.observation.Observation` methods `create_rd_grid` and `create_lm_grid` + methods resulting in rotated images +- Fix `pyvisgen.simulation.observation.ValidBaselineSubset` dataclass field order diff --git a/docs/changes/39.feature.rst b/docs/changes/39.feature.rst new file mode 100644 index 0000000..d2cda72 --- /dev/null +++ b/docs/changes/39.feature.rst @@ -0,0 +1,4 @@ +- Add class `Polarisation` to `pyvisgen.simulation.visibility` that is called in `vis_loop` + - Add linear, circular, and no polarisation options +- Update `pyvisgen.simulation.visibility.Visibilities` dataclass to also store polarisation degree tensors +- Add keyword arguments for polarisation simulation to `pyvisgen.simulation.observation.Observation` class diff --git a/docs/changes/39.maintenance.rst b/docs/changes/39.maintenance.rst new file mode 100644 index 0000000..83a2365 --- /dev/null +++ b/docs/changes/39.maintenance.rst @@ -0,0 +1,6 @@ +- Change pyvisgen.simulation.visibility.Visibilities dataclass component names from stokes components (I , Q, U, and V) + to visibilities constructed from the stokes components (`V_11`, `V_22`, `V_12`, `V_21`) +- Change indices for stokes components according to AIPS Memo 114 + - Indices will be set automatically depending on simulated polarisation +- Update comment strings in FITS files +- Update docstrings accordingly in `pyvisgen.simulation.visibility.vis_loop` and `pyvisgen.simulation.observation.Observation` From 2a5e07c0c54ee08bd1437145252994debb8ce362 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 1 Oct 2024 10:08:23 +0200 Subject: [PATCH 10/15] Fix test failing because of api change --- tests/test_simulation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 4227924..132df39 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -42,10 +42,10 @@ def test_vis_loop(): SI = torch.tensor(data[0])[None] vis_data = vis_loop(obs, SI, noisy=conf["noisy"], mode=conf["mode"]) - assert (vis_data[0].SI[0]).dtype == torch.complex128 - assert (vis_data[0].SQ[0]).dtype == torch.complex128 - assert (vis_data[0].SU[0]).dtype == torch.complex128 - assert (vis_data[0].SV[0]).dtype == torch.complex128 + assert (vis_data[0].V_11[0]).dtype == torch.complex128 + assert (vis_data[0].V_22[0]).dtype == torch.complex128 + assert (vis_data[0].V_12[0]).dtype == torch.complex128 + assert (vis_data[0].V_21[0]).dtype == torch.complex128 assert (vis_data[0].num).dtype == torch.float32 assert (vis_data[0].base_num).dtype == torch.float64 assert torch.is_tensor(vis_data[0].u) From 3ae5f722a2c60fe8093f18043a9ee20f745b9a1f Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Tue, 1 Oct 2024 10:09:12 +0200 Subject: [PATCH 11/15] Update bugfix changelog --- docs/changes/39.bugfix.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/changes/39.bugfix.rst b/docs/changes/39.bugfix.rst index 9a98182..67473db 100644 --- a/docs/changes/39.bugfix.rst +++ b/docs/changes/39.bugfix.rst @@ -1,3 +1,4 @@ - Fix gridding in `pyvisgen.simulation.observation.Observation` methods `create_rd_grid` and `create_lm_grid` methods resulting in rotated images - Fix `pyvisgen.simulation.observation.ValidBaselineSubset` dataclass field order +- Fix test failing because of api change From be9bc138015644aed959788a97dd94350b5203c0 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 14:08:16 +0200 Subject: [PATCH 12/15] Add tests for Polarisation class --- tests/test_simulation.py | 181 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 181 insertions(+) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 132df39..17ff26a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -1,6 +1,7 @@ from pathlib import Path import torch +from numpy.testing import assert_array_equal, assert_raises from pyvisgen.utils.config import read_data_set_conf @@ -64,3 +65,183 @@ def test_vis_loop(): out = out_path / Path("vis_0.fits") hdu_list = writer.create_hdu_list(vis_data, obs) hdu_list.writeto(out, overwrite=True) + + +class TestPolarisation: + """Unit test class for ``pyvisgen.simulation.visibility.Polarisation``.""" + + def setup_class(self): + """Set up common objects and variables for the following tests.""" + from pyvisgen.simulation.data_set import create_observation + from pyvisgen.simulation.visibility import Polarisation + + self.obs = create_observation(conf) + + self.SI = torch.zeros((100, 100)) + self.SI[25::25, 25::25] = 1 + self.SI = self.SI[None, ...] + + self.si_shape = self.SI.shape + self.im_shape = self.si_shape[1], self.si_shape[2] + + self.obs.img_size = self.im_shape[0] + + self.pol = Polarisation( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation=self.obs.polarisation, + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **self.obs.pol_kwargs, + ) + + def test_polarisation_circular(self): + """Test circular polarisation.""" + + self.pol.__init__( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation="circular", + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **self.obs.pol_kwargs, + ) + + assert self.pol.delta == 0 + assert self.pol.ax2.sum() == self.SI.sum() * 0.5 + assert self.pol.ay2.sum() == self.SI.sum() * 0.5 + + B, mask, lin_dop, circ_dop = self.pol.stokes_matrix() + + assert mask.sum() == 9 + assert B.shape == torch.Size([9, 2, 2]) + assert mask.shape == self.im_shape + assert lin_dop.shape == self.im_shape + assert lin_dop.shape == self.im_shape + + def test_polarisation_linear(self): + """Test linear polarisation.""" + + self.pol.__init__( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation="linear", + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **self.obs.pol_kwargs, + ) + + assert self.pol.delta == 0 + assert self.pol.ax2.sum() == self.SI.sum() * 0.5 + assert self.pol.ay2.sum() == self.SI.sum() * 0.5 + + B, mask, lin_dop, circ_dop = self.pol.stokes_matrix() + + assert mask.sum() == 9 + assert B.shape == torch.Size([9, 2, 2]) + assert mask.shape == self.im_shape + assert lin_dop.shape == self.im_shape + assert lin_dop.shape == self.im_shape + + def test_polarisation_field(self): + """Test Polarisation.rand_polarisation_field method.""" + pf = self.pol.rand_polarisation_field(shape=self.im_shape) + + assert pf.shape == torch.Size([100, 100]) + + def test_polarisation_field_random_state(self): + """Test polarisation field method for a given random_state""" + random_state = 42 + + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=random_state, + ) + + assert torch.random.initial_seed() == random_state + assert pf.shape == torch.Size([100, 100]) + + def test_polarisation_field_shape_int(self): + """Test polarisation field method for type(shape) = int.""" + pf = self.pol.rand_polarisation_field( + shape=self.im_shape[0], + ) + + assert pf.shape == torch.Size([100, 100]) + + def test_polarisation_field_order(self): + """Test polarisation field method for different orders.""" + + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + order=[1, 1], + ) + + assert pf.shape == torch.Size([100, 100]) + # assert order = 1 and order = [1, 1] yield same images + assert_array_equal(pf, pf_ref, strict=True) + + # assert different order creates different image + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, random_state=42, order=[10, 10] + ) + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) + + # assert len(order) > 2 raises ValueError + assert_raises( + ValueError, + self.pol.rand_polarisation_field, + shape=self.im_shape, + random_state=42, + order=[10, 10, 10], + ) + + def test_polarisation_field_scale(self): + """Test polarisation field method for different scales.""" + + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + + # scale = None + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + scale=None, + ) + + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) + + # scale = [0.25, 0.25] + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, random_state=42, scale=[0.25, 0.25] + ) + + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) + + def test_polarisation_field_threshold(self): + """Test polarisation field method for different threshold.""" + + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + threshold=0.5, + ) + + # expected to raise an AssertionError + assert_raises(AssertionError, assert_array_equal, pf, pf_ref) From 12c9d56628c670cd3bf95d3f48f10bc976e85f08 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 15:08:51 +0200 Subject: [PATCH 13/15] Fix bug in Polarisation class --- pyvisgen/simulation/visibility.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyvisgen/simulation/visibility.py b/pyvisgen/simulation/visibility.py index c4e6c6e..259454d 100644 --- a/pyvisgen/simulation/visibility.py +++ b/pyvisgen/simulation/visibility.py @@ -93,7 +93,7 @@ def __init__( if random_state: torch.manual_seed(random_state) - if self.polarisation: + if self.polarisation and self.polarisation in ["circular", "linear"]: self.polarisation_field = self.rand_polarisation_field( [self.SI.shape[0], self.SI.shape[1]], **field_kwargs, @@ -282,6 +282,9 @@ def rand_polarisation_field( if random_state: torch.random.manual_seed(random_state) + if isinstance(shape, int): + shape = [shape] + if not isinstance(shape, list): shape = list(shape) From d4522f97cdee690dfec1daf960b3d6582c0961d3 Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 15:46:49 +0200 Subject: [PATCH 14/15] Add tests for remaining edge cases --- tests/test_simulation.py | 41 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 17ff26a..fda7698 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -143,6 +143,22 @@ def test_polarisation_linear(self): assert lin_dop.shape == self.im_shape assert lin_dop.shape == self.im_shape + def test_polarisation_amplitude(self): + """Test random amplitude.""" + pol_kwargs = {"delta": 0, "amp_ratio": None, "random_state": 42} + + self.pol.__init__( + self.SI, + sensitivity_cut=self.obs.sensitivity_cut, + polarisation="linear", + device=self.obs.device, + field_kwargs=self.obs.field_kwargs, + **pol_kwargs, + ) + + assert self.pol.ax2.sum() <= 9 + assert self.pol.ay2.sum() <= 9 + def test_polarisation_field(self): """Test Polarisation.rand_polarisation_field method.""" pf = self.pol.rand_polarisation_field(shape=self.im_shape) @@ -161,13 +177,28 @@ def test_polarisation_field_random_state(self): assert torch.random.initial_seed() == random_state assert pf.shape == torch.Size([100, 100]) - def test_polarisation_field_shape_int(self): + def test_polarisation_field_shape(self): """Test polarisation field method for type(shape) = int.""" + pf_ref = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + ) + pf = self.pol.rand_polarisation_field( shape=self.im_shape[0], + random_state=42, ) assert pf.shape == torch.Size([100, 100]) + assert_array_equal(pf, pf_ref, strict=True) + + # assert len(shape) > 2 raises ValueError + assert_raises( + ValueError, + self.pol.rand_polarisation_field, + shape=[100, 100, 100], + random_state=42, + ) def test_polarisation_field_order(self): """Test polarisation field method for different orders.""" @@ -187,6 +218,14 @@ def test_polarisation_field_order(self): # assert order = 1 and order = [1, 1] yield same images assert_array_equal(pf, pf_ref, strict=True) + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + order=[1], + ) + # assert order = [1] and order = [1, 1] yield same images + assert_array_equal(pf, pf_ref, strict=True) + # assert different order creates different image pf = self.pol.rand_polarisation_field( shape=self.im_shape, random_state=42, order=[10, 10] From 4edfd60846b807fbc15951e1b4bef9d13e138cbd Mon Sep 17 00:00:00 2001 From: Anno Knierim Date: Fri, 25 Oct 2024 17:55:47 +0200 Subject: [PATCH 15/15] Add test for order as tuple --- tests/test_simulation.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_simulation.py b/tests/test_simulation.py index fda7698..585ba85 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -218,6 +218,14 @@ def test_polarisation_field_order(self): # assert order = 1 and order = [1, 1] yield same images assert_array_equal(pf, pf_ref, strict=True) + pf = self.pol.rand_polarisation_field( + shape=self.im_shape, + random_state=42, + order=(1, 1), + ) + # assert order = (1, 1) and order = [1, 1] yield same images + assert_array_equal(pf, pf_ref, strict=True) + pf = self.pol.rand_polarisation_field( shape=self.im_shape, random_state=42,