add compression error calculator
drhlxiao committed Feb 13, 2024
1 parent 86e7d35 commit d71b071
Showing 3 changed files with 240 additions and 11 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -3,11 +3,11 @@
setup(
    name='stixdcpy',
    description='STIX data center APIs and data analysis tools',
-    version='2.2',
+    version='3.0',
    author='Hualin Xiao',
    author_email='[email protected]',
    long_description=open('README.md').read(),
-    install_requires=['numpy', 'requests', 'python-dateutil',
+    install_requires=['numpy', 'requests', 'python-dateutil', 'ipython',
                      'astropy', 'matplotlib', 'tqdm', 'pandas', 'joblib', 'roentgen', 'simplejson', 'sunpy', 'wget'],
    long_description_content_type='text/markdown',
    #packages=find_packages(where='stixdcpy'),
88 changes: 88 additions & 0 deletions stixdcpy/compression_error.py
@@ -0,0 +1,88 @@
"""
This script provides functions for decompressing integer values and creating error lookup tables.
It includes the following functions:
- decompress(x, K, M): Decompresses a compressed integer value.
- make_lut(k, m): Creates a lookup table for error calculation based on k and m values.
Additionally, it defines error lookup tables for specific combinations of k and m values.
Author: Hualin Xiao([email protected])
"""

import numpy as np

MAX_STORED_INTEGER = 1e8

def decompress(x, K, M):
"""
Decompresses a compressed integer value.
Parameters:
x (int): The compressed integer value to decompress.
K (int): The number of bits reserved for the exponent.
M (int): The number of bits reserved for the mantissa.
Returns:
tuple: A tuple containing the error and the decompressed value.
The error represents the uncertainty in the decompressed value.
The decompressed value is the original integer value.
"""
S = 0
if S + K + M > 8 or S not in (0, 1) or K > 7 or M > 7:
return None, None
if K == 0 or M == 0:
return None, None

sign = 1
if S == 1: # signed
MSB = x & (1 << 7)
if MSB != 0:
sign = -1
x = x & ((1 << 7) - 1)

x0 = 1 << (M + 1)
if x < x0:
return None, x
mask1 = (1 << M) - 1
mask2 = (1 << M)
mantissa1 = x & mask1
exponent = (x >> M) - 1
# number of shifted bits
mantissa2 = mask2 | mantissa1 # add 1 before mantissa
low = mantissa2 << exponent # minimal possible value
high = low | ((1 << exponent) - 1) # maximal possible value
mean = (low + high) >> 1 # mean value
error = np.sqrt((high - low) ** 2 / 12)

if mean > MAX_STORED_INTEGER:
return error, float(mean)

return error, sign * mean

def make_lut(k, m):
"""
Creates a lookup table for error calculation based on k and m values.
Parameters:
k (int): The number of bits reserved for the exponent.
m (int): The number of bits reserved for the mantissa.
Returns:
dict: A dictionary mapping decompressed values to their respective errors.
"""
res = {}
for i in range(256):
err, val = decompress(i, k, m)
if err is not None:
res[val] = np.sqrt(err ** 2 + val)
return res

error_luts = {'035': make_lut(3, 5), '044': make_lut(4, 4), '053': make_lut(5, 3)}

if __name__ == '__main__':
    from pprint import pprint
    pprint(error_luts)
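
For orientation, a minimal usage sketch of the new module, assuming the file is importable as stixdcpy.compression_error; the numbers in the comments follow from the bit arithmetic above:

from stixdcpy.compression_error import decompress, error_luts, make_lut

# Decompress raw byte 100 under the k=4, m=4 scheme:
# mantissa = 4, exponent = 5, so the true count lies in [640, 671].
err, value = decompress(100, 4, 4)
print(value, err)              # 655, ~8.9 (uniform spread over the 32-count bin)

# Total uncertainty (compression + Poisson) of that decompressed count
print(error_luts['044'][655])  # ~27.1, i.e. sqrt(8.9**2 + 655)

# Build a table for another (k, m) pair on the fly
lut = make_lut(3, 5)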

159 changes: 150 additions & 9 deletions stixdcpy/science.py
@@ -281,9 +281,15 @@ def make_spectra(self, pixel_counts=None):

    @property
    def pixel_counts_error(self):
-        return error_computation(self.hdul['data'].data['counts_err'],
+        try:
+            counts_err = self.hdul['data'].data['counts_err']
+        except KeyError:
+            counts_err = self.hdul['data'].data['counts_comp_err']
+        return error_computation(counts_err,
                                 self.pixel_counts)

    def correct_dead_time(self):
        """ dead time correction
        Returns:
@@ -295,22 +301,30 @@ def correct_dead_time(self):
            live_ratio: np.array
        """

-        def correct(triggers, counts_arr, time_bins):
+        def correct(triggers, counts_arr, counts_err_arr, time_bins):
""" Live time correction
Args
triggers: ndarray
triggers in the spectrogram
counts_arr:ndarray
counts in the spectrogram
counts_err_arr:ndarray
counts error in the spectrogram
time_bins: ndarray
time_bins in the spectrogram
Returns
live_time_ratio: ndarray
live time ratio of detectors
count_rate:
corrected_rates:
corrected count rate
count_rate:
count rate before dead time correction
time:
timestamp
photons_in:
rate of photons illuminating the detector group
corrected_counts:
dead-time corrected counts
"""

            time_bins = time_bins[:, None]
@@ -321,31 +335,154 @@ def correct(triggers, counts_arr, time_bins):
            time_bins = time_bins[:, :, None, None]

            count_rate = counts_arr / time_bins
            count_rate_err = counts_err_arr / time_bins
            for det in range(32):
                trig_idx = inst.detector_id_to_trigger_index(det)
                nin = photons_in[:, trig_idx]
                live_ratio[:, det] = np.exp(
                    -BETA * nin * ASIC_TAU) / (1 + nin * TRIG_TAU)

            # live_ratio = np.zeros_like(live_ratio) + 1  # disable live time correction
            corrected_rate = count_rate / live_ratio[:, :, None, None]
            corrected_rate_err = count_rate_err / live_ratio[:, :, None, None]
            corrected_counts = corrected_rate * time_bins
            corrected_counts_err = corrected_rate_err * time_bins

            return {
                'corrected_rates': corrected_rate,
                'corrected_rate_err': corrected_rate_err,
                'count_rate': count_rate,
                'photons_in': photons_in,
                'corrected_counts': corrected_counts,
                'corrected_counts_err': corrected_counts_err,
                'time': self.datetime,
                'time_bins': time_bins.flatten(),
                'live_ratio': live_ratio
            }

-        self.corrected = correct(self.triggers, self.pixel_counts,
+        self.corrected = correct(self.triggers, self.pixel_counts, self.pixel_counts_error,
                                 self.timedel)

        # approximate the live ratio error like Ewan does
-        above = correct(self.triggers + self.trigger_error, self.pixel_counts,
+        above = correct(self.triggers + self.trigger_error, self.pixel_counts, self.pixel_counts_error,
                        self.timedel)
-        below = correct(self.triggers - self.trigger_error, self.pixel_counts,
+        below = correct(self.triggers - self.trigger_error, self.pixel_counts, self.pixel_counts_error,
                        self.timedel)
        self.corrected['live_error'] = np.abs(above['live_ratio'] -
                                              below['live_ratio']) / 2
        return self.corrected
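
A sketch of how the corrected products might be consumed downstream; pd stands for a PixelData instance loaded elsewhere (e.g. from a STIX L1 pixel-data file) and is illustrative only:

corrected = pd.correct_dead_time()

live_ratio = corrected['live_ratio']        # live-time ratio per time bin and detector
rates = corrected['corrected_rates']        # dead-time corrected count rates
rate_err = corrected['corrected_rate_err']  # propagated count-rate errors

# fractional dead time of detector 0 in the first time bin
print(1 - live_ratio[0, 0])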
    def get_sum_counts(self, start_utc=None, end_utc=None):
        """
        Calculate the total counts in different regions of a pixel data file within a specified time range.
        Parameters:
        - start_utc (str or None, optional): The start time in UTC format. If not provided, the beginning of the observation is used.
        - end_utc (str or None, optional): The end time in UTC format. If not provided, the end of the observation is used.
        Returns:
        - dict: A dictionary containing the total counts in different regions:
            - 'top': Total counts in the top pixel row (channels 1-4).
            - 'bottom': Total counts in the bottom pixel row (channels 5-8).
            - 'small': Total counts in the small pixels.
            - 'big': Total counts in the big pixels, i.e. 'top' + 'bottom'.
            - 'total': Total counts in all pixels, i.e. 'top' + 'bottom' + 'small'.
            Each entry has a matching '_err' key; 'duration' gives the summed length of the selected time bins.
            Returns None if the requested time range does not overlap the observation.
        Note:
        - Time conversions rely on the 'sdt' date/time helper module.
        - The time range is determined by start_utc and end_utc. If these are not provided, the entire observation duration is considered.
        """
        cl1 = self.correct_dead_time()
        data_start = self.T0_unix
        data_end = self.T0_unix + self.duration
        pixel_counts = cl1['corrected_counts']
        pixel_counts_err = cl1['corrected_counts_err']

        if start_utc is None and end_utc is None:
            start_i_tbin, end_i_tbin = 0, pixel_counts.shape[0]
            duration = self.duration
        else:
            start_unix = sdt.utc2unix(start_utc)
            end_unix = sdt.utc2unix(end_utc)
            if start_unix > data_end or end_unix < data_start:
                return None

            start_unix = max(data_start, start_unix)
            end_unix = min(data_end, end_unix)

            start_time = start_unix - self.T0_unix  # relative start time
            end_time = end_unix - self.T0_unix

            start_i_tbin = np.argmax(
                self.time - 0.5 * self.timedel >= start_time) if (
                    0 <= start_time <= self.duration) else 0
            end_i_tbin = (np.argmin(
                self.time + 0.5 * self.timedel <= end_time) + 1) if (
                    start_time <= end_time <= self.duration) else len(self.time)

            duration = np.sum(self.timedel[start_i_tbin:end_i_tbin])

        sum_counts = {
            'top': np.sum(pixel_counts[start_i_tbin:end_i_tbin, :, 0:4, :], axis=(0, 1, 2)),
            'bottom': np.sum(pixel_counts[start_i_tbin:end_i_tbin, :, 4:8, :], axis=(0, 1, 2)),
            'small': np.sum(pixel_counts[start_i_tbin:end_i_tbin, :, 8:, :], axis=(0, 1, 2)),
            'duration': duration,
            # errors of independent bins add in quadrature
            'top_err': np.sqrt(np.sum(pixel_counts_err[start_i_tbin:end_i_tbin, :, 0:4, :]**2, axis=(0, 1, 2))),
            'bottom_err': np.sqrt(np.sum(pixel_counts_err[start_i_tbin:end_i_tbin, :, 4:8, :]**2, axis=(0, 1, 2))),
            'small_err': np.sqrt(np.sum(pixel_counts_err[start_i_tbin:end_i_tbin, :, 8:, :]**2, axis=(0, 1, 2)))
        }

        sum_counts['big'] = sum_counts['top'] + sum_counts['bottom']
        sum_counts['big_err'] = np.sqrt(sum_counts['top_err']**2 + sum_counts['bottom_err']**2)

        sum_counts['total'] = sum_counts['top'] + sum_counts['bottom'] + sum_counts['small']
        sum_counts['total_err'] = np.sqrt(sum_counts['top_err']**2 + sum_counts['bottom_err']**2
                                          + sum_counts['small_err']**2)

        return sum_counts
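
A short usage sketch, continuing with the illustrative pd instance from above (the time range is made up):

totals = pd.get_sum_counts(start_utc='2024-02-10T06:00:00',
                           end_utc='2024-02-10T06:05:00')
if totals is not None:
    print(totals['total'], '+/-', totals['total_err'],
          'counts in', totals['duration'], 's')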


    def get_mean_rate(self, start_utc=None, end_utc=None):
        """
        Calculate the mean count rate in different regions of a pixel data file within a specified time range.
        Parameters:
        - start_utc (str or None, optional): The start time in UTC format. If not provided, the beginning of the observation is used.
        - end_utc (str or None, optional): The end time in UTC format. If not provided, the end of the observation is used.
        Returns:
        - dict: A dictionary containing the mean count rates in different regions:
            - 'top': Mean count rate in the top pixel row (channels 1-4).
            - 'bottom': Mean count rate in the bottom pixel row (channels 5-8).
            - 'small': Mean count rate in the small pixels.
            - 'total': Mean count rate of all pixels.
            Each entry has a matching '_err' key.
        Note:
        - The time range is determined by start_utc and end_utc. If these are not provided, the entire observation duration is considered.
        """
        cnts = self.get_sum_counts(start_utc, end_utc)
        if cnts is None:
            return None
        result = {}
        for key, val in cnts.items():
            # divide everything except the duration itself by the duration
            norm = cnts['duration'] if key != 'duration' else 1
            result[key] = val / norm
        return result
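
And the matching mean-rate sketch; every entry of the get_sum_counts() dictionary, errors included, is divided by the accumulated duration:

rates = pd.get_mean_rate('2024-02-10T06:00:00', '2024-02-10T06:05:00')
if rates is not None:
    print(rates['big'], '+/-', rates['big_err'], 'counts/s in the big pixels')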



    def peek(self,
             plots=['spg', 'lc', 'spec', 'tbin', 'qllc'],
@@ -450,7 +587,7 @@ def __init__(self, l1sig: PixelData, l1bkg: PixelData):
"""
self.l1sig = l1sig
self.l1bkg = l1bkg
print(self.l1sig.energy_bin_mask)
#print(self.l1sig.energy_bin_mask)

dmask = self.l1bkg.energy_bin_mask - self.l1sig.energy_bin_mask
if np.any(dmask < 0):
@@ -465,7 +602,7 @@ def __init__(self, l1sig: PixelData, l1bkg: PixelData):
        ])
        # set counts beyond the signal energy range to 0
        self.subtracted_counts = (self.l1sig.counts - self.pixel_bkg_counts)
-        print(self.l1sig.inversed_energy_bin_mask)
+        #print(self.l1sig.inversed_energy_bin_mask)
        self.subtracted_counts *= self.l1sig.energy_bin_mask

        # Dead time correction needs to be included in the future
@@ -660,7 +797,7 @@ def correct(triggers, counts_arr, time_bins, num_detectors):
        except KeyError:
            num_detectors = self.hdul[1].data['detector_mask'].sum()

        self.corrected = correct(self.triggers, self.counts, self.timedel,
                                 num_detectors)

        # approximate the live ratio error like Ewan does
@@ -1005,3 +1142,7 @@ def spec_fits_concatenate(fitsfile1,
    hdul.writeto(outfilename, overwrite=True)

    return outfilename



