Skip to content

Commit

Permalink
Merge pull request #131 from fact-project/optional_altitude
Browse files Browse the repository at this point in the history
Optional altitude
  • Loading branch information
maxnoe authored May 28, 2020
2 parents 5a5b1af + 4a616ef commit d1aefbb
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 21 deletions.
29 changes: 27 additions & 2 deletions aict_tools/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,14 @@ class DispConfig:
'source_az_unit',
'source_zd_column',
'source_zd_unit',
'source_alt_column',
'source_alt_unit',
'pointing_az_column',
'pointing_az_unit',
'pointing_zd_column',
'pointing_zd_unit',
'pointing_alt_column',
'pointing_alt_unit',
'focal_length_column',
'focal_length_unit',
'cog_x_column',
Expand Down Expand Up @@ -205,14 +209,32 @@ def __init__(self, config):
self.features.sort()

self.source_az_column = model_config.get('source_az_column', 'source_position_az')
self.source_zd_column = model_config.get('source_zd_column', 'source_position_zd')
self.source_zd_column = model_config.get('source_zd_column', None)
self.source_alt_column = model_config.get('source_alt_column', None)
if (self.source_zd_column is None) is (self.source_alt_column is None):
raise ValueError(
'Need to specify exactly one of'
'source_zd_column or source_alt_column.'
'source_zd_column: {}, source_alt_column: {}'.format(
self.source_zd_column, self.source_alt_column)
)
self.source_az_unit = u.Unit(model_config.get('source_az_unit', 'deg'))
self.source_zd_unit = u.Unit(model_config.get('source_zd_unit', 'deg'))
self.source_alt_unit = u.Unit(model_config.get('source_alt_unit', 'deg'))

self.pointing_az_column = model_config.get('pointing_az_column', 'pointing_position_az')
self.pointing_zd_column = model_config.get('pointing_zd_column', 'pointing_position_zd')
self.pointing_zd_column = model_config.get('pointing_zd_column', None)
self.pointing_alt_column = model_config.get('pointing_alt_column', None)
if (self.pointing_zd_column is None) is (self.pointing_alt_column is None):
raise ValueError(
'Need to specify exactly one of'
'pointing_zd_column or pointing_alt_column.'
'pointing_zd_column: {}, pointing_alt_column: {}'.format(
self.pointing_zd_column, self.pointing_alt_column)
)
self.pointing_az_unit = u.Unit(model_config.get('pointing_zd_unit', 'deg'))
self.pointing_zd_unit = u.Unit(model_config.get('pointing_zd_unit', 'deg'))
self.pointing_alt_unit = u.Unit(model_config.get('pointing_alt_unit', 'deg'))

self.focal_length_column = model_config.get('focal_length_column', 'focal_length')
self.focal_length_unit = u.Unit(model_config.get('focal_length', 'm'))
Expand All @@ -235,9 +257,12 @@ def __init__(self, config):
cols.update({
self.pointing_az_column,
self.pointing_zd_column,
self.pointing_alt_column,
self.source_az_column,
self.source_zd_column,
self.source_alt_column,
})
cols.discard(None)
if self.coordinate_transformation == 'CTA':
cols.add(self.focal_length_column)

Expand Down
6 changes: 3 additions & 3 deletions aict_tools/cta_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@
raise ImportError('This functionality requires ctapipe to be installed')


def horizontal_to_camera_cta_simtel(zd, az, zd_pointing, az_pointing, focal_length):
def horizontal_to_camera_cta_simtel(alt, az, alt_pointing, az_pointing, focal_length):
with warnings.catch_warnings():

altaz = AltAz()
source_altaz = SkyCoord(
az=u.Quantity(az, u.deg, copy=False),
alt=u.Quantity(90 - zd, u.deg, copy=False),
alt=u.Quantity(alt, u.deg, copy=False),
frame=altaz,
)

tel_pointing = SkyCoord(
alt=u.Quantity(90 - zd_pointing, u.deg, copy=False),
alt=u.Quantity(alt_pointing, u.deg, copy=False),
az=u.Quantity(az_pointing, u.deg, copy=False),
frame=altaz,
)
Expand Down
17 changes: 12 additions & 5 deletions aict_tools/scripts/train_disp_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,25 @@ def main(configuration_path, signal_path, predictions_path, disp_model_path, sig
from ..cta_helpers import horizontal_to_camera_cta_simtel
source_x, source_y = horizontal_to_camera_cta_simtel(
az=df[model_config.source_az_column],
zd=df[model_config.source_zd_column],
alt=df[model_config.source_alt_column] if model_config.source_alt_column
else (90-df[model_config.source_zd_column]),
az_pointing=df[model_config.pointing_az_column],
zd_pointing=df[model_config.pointing_zd_column],
alt_pointing=df[model_config.pointing_alt_column]
if model_config.pointing_alt_column
else (90-df[model_config.pointing_zd_column]),
focal_length=df[model_config.focal_length_column],
)
elif model_config.coordinate_transformation == 'FACT':

source_x, source_y = horizontal_to_camera(
az=df[model_config.source_az_column],
zd=df[model_config.source_zd_column],
zd=df[model_config.source_zd_column] if model_config.source_zd_column
else (90-df[model_config.source_alt_column]),
az_pointing=df[model_config.pointing_az_column],
zd_pointing=df[model_config.pointing_zd_column],
)
zd_pointing=df[model_config.pointing_zd_column]
if model_config.pointing_zd_column
else (90-df[model_config.pointing_alt_column]),
)

log.info('Using projected disp: {}'.format(model_config.project_disp))
df['true_disp'], df['true_sign'] = calc_true_disp(
Expand Down
47 changes: 47 additions & 0 deletions examples/config_source_altitude.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# seed for the random number generators, to make things reproducible
seed: 0

# define th number of cross validations to perform
n_cross_validations : 5


disp:
disp_regressor : |
ensemble.RandomForestRegressor(
n_estimators=30,
max_features='sqrt',
n_jobs=-1,
max_depth=20,
)
sign_classifier: |
ensemble.RandomForestClassifier(
n_estimators=30,
max_features='sqrt',
n_jobs=-1,
max_depth=20,
)
coordinate_transformation: FACT
# columns containing coordinates of the source and of the pointing
source_az_column: source_position_az
source_alt_column: source_position_alt
pointing_az_column: pointing_position_az
pointing_alt_column: pointing_position_alt


# randomly sample the data if you dont want to use the whole set
n_signal : 500

features:
- concentration_cog
- concentration_core
- delta
- leakage1
- leakage2
- length
- skewness_long
- kurtosis_long
- num_islands
- size
- width
17 changes: 9 additions & 8 deletions examples/full_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,15 @@ disp:
coordinate_transformation: FACT
# columns containing coordinates of the source and of the pointing
source_zenith_column: source_position_zd
source_zenith_unit: rad
source_azimuth_column: source_position_az
source_azimuth_unit: rad
pointing_azimuth_column: pointing_position_az
pointing_azimuth_unit: rad
pointing_zenith_column: pointing_position_zd
pointing_zenith_unit: rad
source_az_column: source_position_az
source_zd_column: source_position_zd
source_az_unit: deg
source_zd_unit: deg

pointing_az_column: pointing_position_az
pointing_zd_column: pointing_position_zd
pointing_az_unit: deg
pointing_zd_unit: deg

# randomly sample the data if you dont want to use the whole set
n_signal : 500
Expand Down
Binary file added examples/gamma_diffuse_altitude.hdf5
Binary file not shown.
11 changes: 11 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,14 @@ def test_source():

with raises(ValueError):
AICTConfig.from_yaml('tests/config_source.yaml')


def test_altitude():

from aict_tools.configuration import AICTConfig

zd_config = AICTConfig.from_yaml('examples/config_source.yaml')
assert 'source_position_zd' in zd_config.disp.columns_to_read_train

alt_config = AICTConfig.from_yaml('examples/config_source_altitude.yaml')
assert 'source_position_alt' in alt_config.disp.columns_to_read_train
4 changes: 2 additions & 2 deletions tests/test_cta_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ def test_horizontal_to_camera():
expected_x = df.x
expected_y = df.y
transformed_x, transformed_y = horizontal_to_camera_cta_simtel(
zd=df.zd,
alt=90-df.zd,
az=df.az,
zd_pointing=df.zd_pointing,
alt_pointing=90-df.zd_pointing,
az_pointing=df.az_pointing,
focal_length=df.focal_length,
)
Expand Down
24 changes: 23 additions & 1 deletion tests/test_executables.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,29 @@ def test_apply_separator(temp_dir, separator_model):
assert 'gammaness' in f['events']


def test_train_disp_altitude():
from aict_tools.scripts.train_disp_regressor import main as train

with tempfile.TemporaryDirectory(prefix='aict_tools_test_') as d:

with DateNotModified('examples/gamma_diffuse_altitude.hdf5'):
runner = CliRunner()
result = runner.invoke(
train,
[
'examples/config_source_altitude.yaml',
'examples/gamma_diffuse_altitude.hdf5',
os.path.join(d, 'test.hdf5'),
os.path.join(d, 'disp.pkl'),
os.path.join(d, 'sign.pkl'),
]
)
if result.exit_code != 0:
print(result.output)
print_exception(*result.exc_info)
assert result.exit_code == 0


def test_train_disp_cta():
from aict_tools.scripts.train_disp_regressor import main as train

Expand Down Expand Up @@ -292,7 +315,6 @@ def test_apply_disp_cta():
assert result.exit_code == 0



def test_to_dl3():
from aict_tools.scripts.train_disp_regressor import main as train_disp
from aict_tools.scripts.train_energy_regressor import main as train_energy
Expand Down

0 comments on commit d1aefbb

Please sign in to comment.