Skip to content

Commit

Permalink
Merge pull request #9 from radionets-project/fits
Browse files Browse the repository at this point in the history
Minor speed up changes
  • Loading branch information
StFroese authored May 24, 2021
2 parents 16c5c2b + 57e9d6c commit 792b43b
Showing 1 changed file with 100 additions and 8 deletions.
108 changes: 100 additions & 8 deletions vipy/simulation/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from vipy.layouts import layouts
import torch
import itertools

import time as t
import numexpr as ne # fast exponential
from einsumt import einsumt as einsum

@dataclass
class Baselines:
Expand Down Expand Up @@ -312,6 +314,7 @@ def corrupted(lm, baselines, wave, time, src_crd, array_layout, I, rd):
4d array
Returns visibility for every lm and baseline
"""

stat_num = array_layout.st_num.shape[0]
base_num = int(stat_num * (stat_num - 1) / 2)

Expand All @@ -323,8 +326,11 @@ def corrupted(lm, baselines, wave, time, src_crd, array_layout, I, rd):
if st1_num.shape[0] == 0:
return torch.zeros(1)


K = getK(baselines, lm, wave, base_num)



B = np.zeros((lm.shape[0], lm.shape[1], 2, 2), dtype=complex)

B[:,:,0,0] = I[:,:,0]+I[:,:,1]
Expand All @@ -334,6 +340,10 @@ def corrupted(lm, baselines, wave, time, src_crd, array_layout, I, rd):

# coherency
X = torch.einsum('lmij,lmb->lmbij', torch.tensor(B), K)
# X = np.einsum('lmij,lmb->lmbij', B, K, optimize=True)
# X = torch.tensor(B)[:,:,None,:,:] * K[:,:,:,None,None]


del K

# telescope response
Expand All @@ -343,14 +353,17 @@ def corrupted(lm, baselines, wave, time, src_crd, array_layout, I, rd):
E1 = torch.tensor(E_st[:, :, st1_num], dtype=torch.cdouble)
E2 = torch.tensor(E_st[:, :, st2_num], dtype=torch.cdouble)

# EX = torch.einsum('lmbij,lmbjk->lmbik',E1,X)

EX = torch.einsum('lmb,lmbij->lmbij',E1,X)

del E1, X
# EXE = torch.einsum('lmbij,lmbjk->lmbik',EX,torch.transpose(torch.conj(E2),3,4))
EXE = torch.einsum('lmbij,lmb->lmbij',EX,torch.conj(E2))
EXE = torch.einsum('lmbij,lmb->lmbij',EX,E2)
del EX, E2

# P matrix
# parallactic angle

beta = np.array(
[
Observer(
Expand All @@ -365,12 +378,88 @@ def corrupted(lm, baselines, wave, time, src_crd, array_layout, I, rd):
P1 = torch.tensor(getP(b1),dtype=torch.cdouble)
P2 = torch.tensor(getP(b2),dtype=torch.cdouble)



PEXE = torch.einsum('bij,lmbjk->lmbik',P1,EXE)
del EXE
PEXEP = torch.einsum('lmbij,bjk->lmbik',PEXE,torch.transpose(torch.conj(P2),1,2))
del PEXE

return PEXEP

def direction_independent(lm, baselines, wave, time, src_crd, array_layout, I, rd):
"""Calculates direction independet visibility
Parameters
----------
lm : 3d array
every pixel containing a l and m value
baselines : dataclass
baseline information
wave : float
wavelength of observation
time : astropy Time
Time steps of observation
src_crd : astropy SkyCoord
source position
array_layout : dataclass
station information
I : 2d array
source brightness distribution / input img
rd : 3d array
RA and dec values for every pixel
Returns
-------
4d array
Returns visibility for every lm and baseline
"""

stat_num = array_layout.st_num.shape[0]
base_num = int(stat_num * (stat_num - 1) / 2)


vectorized_num = np.vectorize(lambda st: st.st_num, otypes=[int])
st1, st2 = get_valid_baselines(baselines, base_num)
st1_num = vectorized_num(st1)
st2_num = vectorized_num(st2)
if st1_num.shape[0] == 0:
return torch.zeros(1)


K = getK(baselines, lm, wave, base_num)



B = np.zeros((lm.shape[0], lm.shape[1], 2, 2), dtype=complex)

B[:,:,0,0] = I[:,:,0]+I[:,:,1]
B[:,:,0,1] = I[:,:,2]+1j*I[:,:,3]
B[:,:,1,0] = I[:,:,2]-1j*I[:,:,3]
B[:,:,1,1] = I[:,:,0]-I[:,:,1]

# coherency
X = torch.einsum('lmij,lmb->lmbij', torch.tensor(B), K)



del K

# telescope response
E_st = getE(rd, array_layout, wave, src_crd)

E1 = torch.tensor(E_st[:, :, st1_num], dtype=torch.cdouble)
E2 = torch.tensor(E_st[:, :, st2_num], dtype=torch.cdouble)


EX = torch.einsum('lmb,lmbij->lmbij',E1,X)

del E1, X

EXE = torch.einsum('lmbij,lmb->lmbij',EX,E2)
del EX, E2

return EXE

def integrate(X1, X2):
"""Summation over l and m and avering over time and freq
Expand Down Expand Up @@ -506,6 +595,7 @@ def getK(baselines, lm, wave, base_num):
Shape is given by lm axes and baseline axis
"""
# new valid baseline calculus. for details see function get_valid_baselines()

valid = baselines.valid.reshape(-1, base_num)
mask = np.array(valid[:-1]).astype(bool) & np.array(valid[1:]).astype(bool)

Expand All @@ -527,14 +617,16 @@ def getK(baselines, lm, wave, base_num):
l = torch.tensor(lm[:, :, 0])
m = torch.tensor(lm[:, :, 1])
n = torch.sqrt(1-l**2-m**2)

ul = torch.einsum("b,ij->ijb", torch.tensor(u_cmplt), l)
vm = torch.einsum("b,ij->ijb", torch.tensor(v_cmplt), m)
wn = torch.einsum('b,ij->ijb', torch.tensor(w_cmplt), (n-1))

K = torch.exp(-2 * np.pi * 1j * (ul + vm + wn))

return K


pi = np.pi
test = ul + vm + wn
K = ne.evaluate('exp(-2 * pi * 1j * (ul + vm + wn))') #-0.4 secs for vlba
return torch.tensor(K)


def jinc(x):
Expand Down

0 comments on commit 792b43b

Please sign in to comment.