Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional altitude #131

Merged
merged 6 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions aict_tools/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,14 @@ class DispConfig:
'source_az_unit',
'source_zd_column',
'source_zd_unit',
'source_alt_column',
'source_alt_unit',
'pointing_az_column',
'pointing_az_unit',
'pointing_zd_column',
'pointing_zd_unit',
'pointing_alt_column',
'pointing_alt_unit',
'focal_length_column',
'focal_length_unit',
'cog_x_column',
Expand Down Expand Up @@ -205,14 +209,32 @@ def __init__(self, config):
self.features.sort()

self.source_az_column = model_config.get('source_az_column', 'source_position_az')
self.source_zd_column = model_config.get('source_zd_column', 'source_position_zd')
self.source_zd_column = model_config.get('source_zd_column', None)
self.source_alt_column = model_config.get('source_alt_column', None)
if (self.source_zd_column is None) is (self.source_alt_column is None):
raise ValueError(
'Need to specify exactly one of'
'source_zd_column or source_alt_column.'
'source_zd_column: {}, source_alt_column: {}'.format(
self.source_zd_column, self.source_alt_column)
)
self.source_az_unit = u.Unit(model_config.get('source_az_unit', 'deg'))
self.source_zd_unit = u.Unit(model_config.get('source_zd_unit', 'deg'))
self.source_alt_unit = u.Unit(model_config.get('source_alt_unit', 'deg'))

self.pointing_az_column = model_config.get('pointing_az_column', 'pointing_position_az')
self.pointing_zd_column = model_config.get('pointing_zd_column', 'pointing_position_zd')
self.pointing_zd_column = model_config.get('pointing_zd_column', None)
self.pointing_alt_column = model_config.get('pointing_alt_column', None)
if (self.pointing_zd_column is None) is (self.pointing_alt_column is None):
raise ValueError(
'Need to specify exactly one of'
'pointing_zd_column or pointing_alt_column.'
'pointing_zd_column: {}, pointing_alt_column: {}'.format(
self.pointing_zd_column, self.pointing_alt_column)
)
self.pointing_az_unit = u.Unit(model_config.get('pointing_zd_unit', 'deg'))
self.pointing_zd_unit = u.Unit(model_config.get('pointing_zd_unit', 'deg'))
self.pointing_alt_unit = u.Unit(model_config.get('pointing_alt_unit', 'deg'))

self.focal_length_column = model_config.get('focal_length_column', 'focal_length')
self.focal_length_unit = u.Unit(model_config.get('focal_length', 'm'))
Expand All @@ -235,9 +257,12 @@ def __init__(self, config):
cols.update({
self.pointing_az_column,
self.pointing_zd_column,
self.pointing_alt_column,
self.source_az_column,
self.source_zd_column,
self.source_alt_column,
})
cols.discard(None)
if self.coordinate_transformation == 'CTA':
cols.add(self.focal_length_column)

Expand Down
15 changes: 10 additions & 5 deletions aict_tools/scripts/train_disp_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,23 @@ def main(configuration_path, signal_path, predictions_path, disp_model_path, sig
from ..cta_helpers import horizontal_to_camera_cta_simtel
source_x, source_y = horizontal_to_camera_cta_simtel(
az=df[model_config.source_az_column],
zd=df[model_config.source_zd_column],
zd=df[model_config.source_zd_column] if model_config.source_zd_column
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could do this in the preprocessing.py, add a function that fills the zenith into the dataframe.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.
Was debating that. I really dont know which way I prefer especially with regards to #117.
Both ways feel hacky tbh.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or do it before the if so that the code is not duplicated.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean one if clause before the function call instead of two clauses in the function call?
Something like

if not config.zd_column:
    source_zd = 90 - alt
    pointing_zd = 90-pointing_alt
horizontal_to_camera(..., zd, pointing_zd)

Or the other way around for CTA

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my review comment. Now, I came to think more about it, we should just use altitude for CTA (so change the cta function to take altitude) and convert from zenith to altitude if zenith is given for cta (but there should be no real reason for that to happen).

else (90-df[model_config.source_alt_column]),
az_pointing=df[model_config.pointing_az_column],
zd_pointing=df[model_config.pointing_zd_column],
zd_pointing=df[model_config.pointing_zd_column] if model_config.pointing_zd_column
else (90-df[model_config.pointing_alt_column]),
focal_length=df[model_config.focal_length_column],
)
elif model_config.coordinate_transformation == 'FACT':

source_x, source_y = horizontal_to_camera(
az=df[model_config.source_az_column],
zd=df[model_config.source_zd_column],
zd=df[model_config.source_zd_column] if model_config.source_zd_column
else (90-df[model_config.source_alt_column]),
az_pointing=df[model_config.pointing_az_column],
zd_pointing=df[model_config.pointing_zd_column],
)
zd_pointing=df[model_config.pointing_zd_column] if model_config.pointing_zd_column
else (90-df[model_config.pointing_alt_column]),
)

log.info('Using projected disp: {}'.format(model_config.project_disp))
df['true_disp'], df['true_sign'] = calc_true_disp(
Expand Down
47 changes: 47 additions & 0 deletions examples/config_source_altitude.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# seed for the random number generators, to make things reproducible
seed: 0

# define th number of cross validations to perform
n_cross_validations : 5


disp:
disp_regressor : |
ensemble.RandomForestRegressor(
n_estimators=30,
max_features='sqrt',
n_jobs=-1,
max_depth=20,
)

sign_classifier: |
ensemble.RandomForestClassifier(
n_estimators=30,
max_features='sqrt',
n_jobs=-1,
max_depth=20,
)

coordinate_transformation: FACT
# columns containing coordinates of the source and of the pointing
source_az_column: source_position_az
source_alt_column: source_position_alt
pointing_az_column: pointing_position_az
pointing_alt_column: pointing_position_alt


# randomly sample the data if you dont want to use the whole set
n_signal : 500

features:
- concentration_cog
- concentration_core
- delta
- leakage1
- leakage2
- length
- skewness_long
- kurtosis_long
- num_islands
- size
- width
17 changes: 9 additions & 8 deletions examples/full_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,15 @@ disp:

coordinate_transformation: FACT
# columns containing coordinates of the source and of the pointing
source_zenith_column: source_position_zd
source_zenith_unit: rad
source_azimuth_column: source_position_az
source_azimuth_unit: rad
pointing_azimuth_column: pointing_position_az
pointing_azimuth_unit: rad
pointing_zenith_column: pointing_position_zd
pointing_zenith_unit: rad
source_az_column: source_position_az
source_zd_column: source_position_zd
source_az_unit: deg
source_zd_unit: deg

pointing_az_column: pointing_position_az
pointing_zd_column: pointing_position_zd
pointing_az_unit: deg
pointing_zd_unit: deg

# randomly sample the data if you dont want to use the whole set
n_signal : 500
Expand Down
Binary file added examples/gamma_diffuse_altitude.hdf5
Binary file not shown.
11 changes: 11 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,14 @@ def test_source():

with raises(ValueError):
AICTConfig.from_yaml('tests/config_source.yaml')


def test_altitude():

from aict_tools.configuration import AICTConfig

zd_config = AICTConfig.from_yaml('examples/config_source.yaml')
assert 'source_position_zd' in zd_config.disp.columns_to_read_train

alt_config = AICTConfig.from_yaml('examples/config_source_altitude.yaml')
assert 'source_position_alt' in alt_config.disp.columns_to_read_train
24 changes: 23 additions & 1 deletion tests/test_executables.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,29 @@ def test_apply_separator(temp_dir, separator_model):
assert 'gammaness' in f['events']


def test_train_disp_altitude():
from aict_tools.scripts.train_disp_regressor import main as train

with tempfile.TemporaryDirectory(prefix='aict_tools_test_') as d:

with DateNotModified('examples/gamma_diffuse_altitude.hdf5'):
runner = CliRunner()
result = runner.invoke(
train,
[
'examples/config_source_altitude.yaml',
'examples/gamma_diffuse_altitude.hdf5',
os.path.join(d, 'test.hdf5'),
os.path.join(d, 'disp.pkl'),
os.path.join(d, 'sign.pkl'),
]
)
if result.exit_code != 0:
print(result.output)
print_exception(*result.exc_info)
assert result.exit_code == 0


def test_train_disp_cta():
from aict_tools.scripts.train_disp_regressor import main as train

Expand Down Expand Up @@ -292,7 +315,6 @@ def test_apply_disp_cta():
assert result.exit_code == 0



def test_to_dl3():
from aict_tools.scripts.train_disp_regressor import main as train_disp
from aict_tools.scripts.train_energy_regressor import main as train_energy
Expand Down