forked from joleroi/nddata
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnddata_test.py
107 lines (77 loc) · 2.57 KB
/
nddata_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import pytest
from nddata import NDDataArray, BinnedDataAxis, DataAxis
from astropy.units import Unit as u
import random
import numpy as np
import IPython
# Example: 1D Histogram
# ----------------------
# Module-level walkthrough of NDDataArray with a single binned axis:
# axis registration, data assignment, node lookup and evaluation.
hist = NDDataArray()
assert hist.dim == 0

x_axis = BinnedDataAxis.linspace(0, 100, 10, 'm')
hist.add_axis(x_axis)
assert hist.axes[0].name == 'x'
assert hist.axis_names == ['x']

# Asking for an axis that was never added is expected to raise.
with pytest.raises(ValueError):
    hist.get_axis('y')

# .all must be *called*; the bare bound method is always truthy and would
# make this assertion pass regardless of the comparison result.
assert (x_axis == hist.get_axis('x')).all()
assert hist.dim == 1

# 15 values are expected to be rejected for the axis defined above.
data = np.arange(15)
with pytest.raises(ValueError):
    hist.data = data

# Ten values, one per entry of the x axis, are accepted.
data = [random.expovariate(hist.axes[0][i].value + 1) for i in range(10)]
hist.data = data

# Find nodes on x-axis
# A value given in seconds is expected to be rejected on an axis in metres.
with pytest.raises(ValueError):
    hist.find_node(x=[14 * u('s')])

idx = hist.get_axis('x').find_node(12 * u('m'))
assert idx[0] == 1

# Same position expressed in centimetres must resolve to the same node.
idx = hist.get_axis('x').find_node(1200 * u('cm'))
assert idx[0] == 1

vals = [13 * u('m'), 2500 * u('cm'), 600 * u('dm')]
idx = hist.get_axis('x').find_node(vals)
# Compare element-wise first, then reduce. The previous form reduced the
# expected array to True and compared only idx[0] against it.
assert (np.asarray(idx) == np.array([1, 2, 6])).all()

# Find nodes using array
# An axis name the array does not have is expected to raise.
with pytest.raises(ValueError):
    hist.find_node(energy=5)

idx = hist.find_node(x=[12 * u('m'), 67 * u('m')])
assert idx[0][0] == 1

eval_data = hist.evaluate(x=[32.52 * u('m')])
assert eval_data == data[3]

eval_data = hist.evaluate(
    x=[32.52 * u('m'), 12 * u('m'), 61.1512 * u('m')]
)
assert (eval_data == np.asarray(data)[np.array([3, 1, 6])]).all()
# Example: 2D Histogram
# ----------------------
# Continues from the 1D example above: reuses `hist` and `data`, adds a
# second axis and repeats the lookup / evaluation checks in two dimensions.
y_axis = DataAxis(np.arange(1, 6), 'kg')
y_axis.name = 'weight'
hist.add_axis(y_axis)
assert hist.axis_names == ['weight', 'x']
assert (hist.get_axis('weight') == y_axis).all()
# After adding the axis the previously assigned data is expected to be gone.
assert hist.data is None

# Data in wrong axis order
val = np.arange(1, 6)
d = np.array(data)
data_2d = np.tensordot(d, val, axes=0)
assert data_2d.shape == (10, 5)
with pytest.raises(ValueError):
    hist.data = data_2d

# Transposed to (5, 10) the same values are accepted.
data_2d = data_2d.transpose()
hist.data = data_2d

# One index collection per axis, in the order given by axis_names.
nodes = hist.find_node(
    x=[12 * u('m'), 23 * u('m')],
    weight=[1.2, 4.3, 3.5] * u('kg'),
)
assert len(nodes) == 2
assert len(nodes[0]) == 3
assert len(nodes[1]) == 2
assert nodes[1][1] == 2
assert nodes[0][2] == 2

# Only x is constrained here; the weight entry is still returned.
nodes = hist.find_node(x=[16 * u('m')])
assert len(nodes) == 2
assert nodes[0][4] == 4

eval_data = hist.evaluate(x=12 * u('m'), weight=3.2 * u('kg'))
assert eval_data == data_2d[2, 1]

eval_data = hist.evaluate(
    x=[12, 34] * u('m'),
    weight=[3.2, 2, 2.4] * u('kg'),
)
assert eval_data.shape == (3, 2)

# Leaving x out yields all 10 x entries for each requested weight.
eval_data = hist.evaluate(weight=[3.2, 2, 2.4] * u('kg'))
assert eval_data.shape == (3, 10)

# plot_image with an axis selection is expected to raise here; the call
# without arguments must succeed.
with pytest.raises(ValueError):
    hist.plot_image(x=12 * u('m'))
hist.plot_image()