Skip to content

Commit

Permalink
add integer compression error estimation script
Browse files Browse the repository at this point in the history
  • Loading branch information
drhlxiao committed Feb 20, 2024
1 parent 7b6715a commit 27b000f
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 40 deletions.
45 changes: 32 additions & 13 deletions stixdcpy/integer_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,35 @@
"""

import numpy as np
from pprint import pprint

MAX_STORED_INTEGER = 1e8



class Compression(object):
def __init__(self, k,m, with_stat_error=True):
self.lut=Compression.get_error_lut(k,m, with_stat_error)
def __init__(self,s, k,m, include_stat_error=True):
self.skm=(s,k,m)
self.lut=Compression.get_error_lut(s,k,m, include_stat_error)

def get_errors(self,counts:np.array):
"""
calculate errors for counts
"""
fv=np.vectorize(self.get_error, otypes=[np.float])
#otypes must be specified
return fv(counts)



def get_error(self,counts):
return self.lut[int(counts)]
if counts==0:
return 0
try:
return self.lut[counts]
except KeyError:
s,k,m=self.skm
raise Exception(f'Failed to error of {counts}! Could the compression scheme ({s=},{k=},{m=}) wrong?')

@staticmethod
def decompress(x,S, K, M):
Expand Down Expand Up @@ -54,7 +72,7 @@ def decompress(x,S, K, M):

x0 = 1 << (M + 1)
if x < x0:
return None, x
return x,0
mask1 = (1 << M) - 1
mask2 = (1 << M)
mantissa1 = x & mask1
Expand All @@ -68,12 +86,12 @@ def decompress(x,S, K, M):
#error of a flat distribution

if mean > MAX_STORED_INTEGER:
return error, float(mean)
return float(mean),error

return error, sign * mean
return sign * mean, error

@staticmethod
def get_error_lut(k, m, with_stat_error=True, s=0):
def get_error_lut(s, k, m, include_stat_error=True):
"""
Creates a lookup table for error calculation based on k and m values.
Expand All @@ -86,14 +104,15 @@ def get_error_lut(k, m, with_stat_error=True, s=0):
"""
res = {}
for i in range(256):
err, val = Compression.decompress(i,s, k, m)
val,err = Compression.decompress(i,s, k, m)
if err is not None:
res[val] = np.sqrt(err ** 2 + val) if with_stat_error else err
res[val] = np.sqrt(err ** 2 + np.abs(val)) if include_stat_error else err
return res

#error_luts = {'035': make_lut(3, 5), '044': make_lut(4, 4), '053': make_lut(5, 3)}

#if __name__ == '__main__':
# from pprint import pprint
# pprint(error_luts)
if __name__ == '__main__':
    # Quick manual check: print the error lookup table for scheme
    # s=0, k=4, m=4.
    from pprint import pprint
    c = Compression(0, 4, 4)
    # The constructor already built the table; reuse it instead of
    # recomputing it via get_error_lut.
    pprint(c.lut)

82 changes: 55 additions & 27 deletions stixdcpy/science.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
from stixdcpy.logger import logger
from stixdcpy import io as sio
from stixdcpy.net import FitsQuery as freq
from stixdcpy import integer_compression as sic

from stixdcpy import net as net
from stixdcpy import instrument as inst

from pathlib import PurePath
from datetime import datetime as dt
from datetime import timedelta as td
Expand Down Expand Up @@ -44,7 +47,6 @@ def __init__(self, fname=None, request_id=None):
self.energies = []
self.corrected = None
# self.read_data()
self.skm = {}

@property
def url(self):
Expand All @@ -69,13 +71,7 @@ def filename(self):
return self.fname


#def init_skm(self):
# try:
# self.skm={'counts':self.hdul['CONTROL'].data['compression_scheme_counts_skm'],
# 'triggers':self.hdul['CONTROL'].data['compression_scheme_triggers_skm']
# }
# except (KeyError,TypeError, ValueError):
# pass




Expand Down Expand Up @@ -108,10 +104,9 @@ def read_fits(self, light_time_correction=True):
self.timedel = self.data['timedel']
self.time = self.data['time']

if self.is_time_bin_shifted(self.T0_unix) and len(self.timedel) > 1:
if self.is_time_bin_shifted() and len(self.timedel) > 1:
self.timedel = self.timedel[:-1]
self.time = self.time[1:]
logger.info('Shifted time bins have been corrected automatically!')
if self.data_type == 'PixelData':
self.counts = self.counts[1:, :, :, :]
self.triggers = self.triggers[1:, :]
Expand All @@ -120,6 +115,7 @@ def read_fits(self, light_time_correction=True):
self.counts = self.counts[1:, :]
self.triggers = self.triggers[1:]
# self.rcr = self.rcr[1:]
logger.info('Shifted time bins have been corrected automatically!')

self.request_id = self.hdul['CONTROL'].data['request_id']

Expand Down Expand Up @@ -167,7 +163,7 @@ def read_fits(self, light_time_correction=True):
self.count_rates = self.counts / self.timedel[:, None]
self.trigger_rates = self.triggers / self.timedel

def is_time_bin_shifted(self, unix_time):
def is_time_bin_shifted(self):
"""
Time bins are shifted in the data collected before 2021-12-09 due a bug in the flight software
Expand All @@ -178,8 +174,7 @@ def is_time_bin_shifted(self, unix_time):
is_shifted: bool
True if time bin is shifted else False
"""

return (unix_time < sdt.utc2unix('2021-12-09T14:00:00'))
return (self.T0_unix< sdt.utc2unix('2021-12-09T14:00:00'))

@classmethod
def from_sdc(cls, request_id, level='L1'):
Expand Down Expand Up @@ -274,6 +269,7 @@ def __init__(self, fname, request_id, ltc=False):
self.pixel_count_rates = None
self.correct_pixel_count_rates = None
self.read_fits(light_time_correction=ltc)
self.pixel_counts_comp_stat_err= None
self.make_spectra()


Expand All @@ -294,15 +290,46 @@ def make_spectra(self, pixel_counts=None):

self.pixel_total_counts = np.sum(self.pixel_counts, axis=(0, 3))

def compute_errors_from_counts(self, skm=None):
"""
compute errors using counts
Args
include_stat_error: include statistical errors in the final results
"""

if skm is None:
try:
skm=self.hdul['CONTROL'].data['compression_scheme_counts_skm']
s,k,m=skm.flatten()
except (KeyError,TypeError, ValueError):
raise Exception("Couldn't find SKM found in the FITS file!")
else:
s,k,m=skm
comp = sic.Compression(s,k,m, include_stat_error=True)
self.pixel_counts_comp_stat_err=comp.get_errors(self.counts)


@property
def pixel_counts_error(self):
try:
counts_err=self.hdul['data'].data['counts_err']
except KeyError:
counts_err=self.hdul['data'].data['counts_comp_err']
return error_computation(counts_err,
"""
compute or read counts errors from FITS file
"""

if self.pixel_counts_comp_stat_err is not None:
#if already computed
return self.pixel_counts_comp_stat_err
else:
try:
counts_err=self.hdul['data'].data['counts_err']
except KeyError:
counts_err=self.hdul['data'].data['counts_comp_err']
#in some FITS files compression errors are zeros
return error_computation(counts_err,
self.pixel_counts)





def correct_dead_time(self):
Expand Down Expand Up @@ -351,7 +378,6 @@ def correct(triggers, counts_arr, counts_err_arr, time_bins):

count_rate = counts_arr / time_bins
count_rate_err = counts_err_arr / time_bins
# print(counts_arr.shape)
for det in range(32):
trig_idx = inst.detector_id_to_trigger_index(det)
nin = photons_in[:, trig_idx]
Expand All @@ -362,13 +388,16 @@ def correct(triggers, counts_arr, counts_err_arr, time_bins):

#live_ratio=np.zeros_like(live_ratio)+1
#disable live time correction
live_ratio=live_ratio[:, :, None, None]

corrected_rate = count_rate / live_ratio[:, :, None, None]
corrected_rate = count_rate / live_ratio

corrected_rate_err = count_rate_err / live_ratio[:, :, None, None]
corrected_rate_err =count_rate_err/ live_ratio

corrected_counts = corrected_rate * time_bins
corrected_counts_err = corrected_rate_err * time_bins
#errors of live ratio not taken into account yet


return {
'corrected_rates': corrected_rate,
Expand Down Expand Up @@ -465,8 +494,7 @@ def get_sum_counts(self, start_utc=None, end_utc=None) :
+ sum_counts['bottom_err']**2 )

sum_counts['total'] = sum_counts['top']+sum_counts['bottom']+sum_counts['small']
sum_counts['total_err'] = np.sqrt(sum_counts['top_err']**2
+ sum_counts['bottom_err']**2 + sum_counts['small_err']**2)
sum_counts['total_err'] = np.sqrt(sum_counts['big_err']**2 + sum_counts['small_err']**2)

return sum_counts

Expand Down Expand Up @@ -617,7 +645,6 @@ def __init__(self, l1sig: PixelData, l1bkg: PixelData):
])
# set counts beyond the signal energy range to 0
self.subtracted_counts = (self.l1sig.counts - self.pixel_bkg_counts)
#print(self.l1sig.inversed_energy_bin_mask)
self.subtracted_counts *= self.l1sig.energy_bin_mask

# Dead time correction needs to be included in the future
Expand Down Expand Up @@ -789,7 +816,6 @@ def correct_dead_time(self) -> dict:
def correct(triggers, counts_arr, time_bins, num_detectors):
''' correct dead time using triggers '''
tau_conv_const = 1e-6

photons_in = triggers / (time_bins * num_detectors -
TRIG_TAU * triggers)
# photon rate approximated using triggers
Expand Down Expand Up @@ -880,7 +906,8 @@ def fits_time_corrections(primary_header, tstart, tend):
primary_header.set('DATE_EAR', date_ear)
primary_header.set('DATE_SUN', date_sun)
primary_header.set('MJDREF', Time(tstart).mjd)
# Note that MJDREF in these files is NOT 1979-01-01 but rather the observation start time! Affects time calculation later, but further processing of the file corrects this
# Note that MJDREF in these files is NOT 1979-01-01 but rather the
# observation start time! Affects time calculation later, but further processing of the file corrects this

return primary_header

Expand Down Expand Up @@ -982,7 +1009,8 @@ def time_select_indices(tstart, tend, primary_header, data_table, factor=1.):


def spec_fits_crop(fitsfile, tstart, tend, outfilename=None):
""" Crop a STIX science data product (L1A, L1, or L4) to contain only the data within a given time interval. Create a new FITS file containing this data. The new file will be of the same processing level as the input file.
""" Crop a STIX science data product (L1A, L1, or L4) to contain only the data within a given time interval.
Create a new FITS file containing this data. The new file will be of the same processing level as the input file.
Inputs:
fitsfile : str
Expand Down

0 comments on commit 27b000f

Please sign in to comment.