-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathtasks_test.py
310 lines (258 loc) · 12.7 KB
/
tasks_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
# pylint: disable=redefined-outer-name
"""Tests for search tasks"""
from types import SimpleNamespace
from ddt import (
data,
ddt,
unpack,
)
from django.conf import settings
from django.test import override_settings
import pytest
from dashboard.factories import ProgramEnrollmentFactory
from search.base import MockedESTestCase
from search.exceptions import ReindexException
from search.factories import PercolateQueryFactory
from search.indexing_api import create_backing_indices
from search.models import PercolateQuery
from search.tasks import (
index_users,
index_program_enrolled_users, start_recreate_index, bulk_index_percolate_queries, bulk_index_program_enrollments,
finish_recreate_index,
)
# Placeholder index name; the SearchTasksTests class installs it via override_settings(OPENSEARCH_INDEX=...)
FAKE_INDEX = "fake"
# Grant database access to every test in this module (pytest-django marker)
pytestmark = pytest.mark.django_db
@pytest.fixture
def mocked_celery(mocker):
    """
    Patch the celery primitives used by the reindexing tasks and expose the mocks.

    Task.replace is autospecced and made to raise a sentinel exception (TabError), so a test can wrap
    the task call in pytest.raises(...) and then inspect what replace() was called with. celery.group
    and celery.chain are autospecced so their calls can be inspected too.

    Yields:
        SimpleNamespace: ``replace``, ``group`` and ``chain`` mocks, plus ``replace_exception_class``
    """
    sentinel_error = TabError
    namespace = SimpleNamespace(
        replace=mocker.patch(
            "celery.app.task.Task.replace", autospec=True, side_effect=sentinel_error
        ),
        group=mocker.patch("celery.group", autospec=True),
        chain=mocker.patch("celery.chain", autospec=True),
        replace_exception_class=sentinel_error,
    )
    yield namespace
def fail_first():
    """
    Build a callable that raises KeyError on its first invocation and returns None afterwards.

    Intended as a mock ``side_effect`` to simulate a one-off failure. Each call to fail_first()
    returns an independent callable with its own state.
    """
    already_failed = False

    def func(*args, **kwargs):  # pylint: disable=unused-argument
        """Raise KeyError the first time this is called; do nothing on later calls"""
        nonlocal already_failed
        if already_failed:
            return None
        already_failed = True
        raise KeyError()

    return func
@ddt
@override_settings(
    OPENSEARCH_INDEX=FAKE_INDEX,
    OPEN_DISCUSSIONS_JWT_SECRET='secret',
    OPEN_DISCUSSIONS_BASE_URL='http://fake',
    OPEN_DISCUSSIONS_API_USERNAME='mitodl',
)
class SearchTasksTests(MockedESTestCase):
    """
    Tests for search tasks
    """

    def setUp(self):
        """
        Pick out, by name, the patched helpers that MockedESTestCase installs so tests can assert on them
        """
        super().setUp()
        for mock in self.patcher_mocks:
            if mock.name == "_index_program_enrolled_users":
                self.index_program_enrolled_users_mock = mock
            elif mock.name == "_document_needs_updating":
                self.document_needs_updating_mock = mock
            elif mock.name == "_send_automatic_emails":
                self.send_automatic_emails_mock = mock
            elif mock.name == "_refresh_all_default_indices":
                self.refresh_index_mock = mock
            elif mock.name == "_update_percolate_memberships":
                self.update_percolate_memberships_mock = mock

    def _assert_enrollments_processed(self, enrollments):
        """
        Assert each enrollment was passed to the automatic-email helper and that its user's
        discussion-channel percolate memberships were updated

        Args:
            enrollments (iterable of ProgramEnrollment): enrollments expected to have been processed
        """
        for enrollment in enrollments:
            self.send_automatic_emails_mock.assert_any_call(enrollment)
            self.update_percolate_memberships_mock.assert_any_call(
                enrollment.user, PercolateQuery.DISCUSSION_CHANNEL_TYPE
            )

    def _assert_indexed_enrollment_ids(self, enrollment_ids):
        """
        Assert the most recent call to the indexing helper received a queryset-like object whose
        ids, in order, equal enrollment_ids

        Args:
            enrollment_ids (list of int): expected ProgramEnrollment ids
        """
        indexed = self.index_program_enrolled_users_mock.call_args[0][0]
        assert list(indexed.values_list('id', flat=True)) == enrollment_ids

    def test_index_users(self):
        """
        When we run the index_users task we should index the user's program enrollments and send them automatic emails
        """
        enrollment1 = ProgramEnrollmentFactory.create()
        enrollment2 = ProgramEnrollmentFactory.create(user=enrollment1.user)
        index_users([enrollment1.user.id])
        assert self.index_program_enrolled_users_mock.call_count == 1
        # Sort both sides by id so the comparison ignores the order the enrollments were passed in
        assert sorted(
            self.index_program_enrolled_users_mock.call_args[0][0],
            key=lambda _enrollment: _enrollment.id
        ) == sorted(
            [enrollment1, enrollment2],
            key=lambda _enrollment: _enrollment.id
        )
        self._assert_enrollments_processed([enrollment1, enrollment2])
        self.refresh_index_mock.assert_called_with()

    @data(*[
        [True, True],
        [True, False],
        [False, True],
        [False, False],
    ])
    @unpack
    def test_index_users_check_if_changed(self, enrollment1_needs_update, enrollment2_needs_update):
        """
        If check_if_changed is true we should only update documents which need updating
        """
        enrollment1 = ProgramEnrollmentFactory.create()
        enrollment2 = ProgramEnrollmentFactory.create()
        # Enrollments flagged as needing an update, kept in creation order
        needs_update_list = [
            enrollment
            for enrollment, needs_update in [
                (enrollment1, enrollment1_needs_update),
                (enrollment2, enrollment2_needs_update),
            ]
            if needs_update
        ]

        def fake_needs_updating(_enrollment):
            """Fake document_needs_update to conform to test data"""
            return _enrollment in needs_update_list

        self.document_needs_updating_mock.side_effect = fake_needs_updating

        index_users([enrollment1.user.id, enrollment2.user.id], check_if_changed=True)

        self.document_needs_updating_mock.assert_any_call(enrollment1)
        self.document_needs_updating_mock.assert_any_call(enrollment2)
        if needs_update_list:
            self.index_program_enrolled_users_mock.assert_called_once_with(needs_update_list)
            self._assert_enrollments_processed(needs_update_list)
        else:
            assert self.index_program_enrolled_users_mock.called is False
            assert self.send_automatic_emails_mock.called is False

    def test_index_program_enrolled_users(self):
        """
        When we run the index_program_enrolled_users task we should index them and send them automatic emails
        """
        enrollments = [ProgramEnrollmentFactory.create() for _ in range(2)]
        enrollment_ids = [enrollment.id for enrollment in enrollments]
        index_program_enrolled_users(enrollment_ids)
        self._assert_indexed_enrollment_ids(enrollment_ids)
        self._assert_enrollments_processed(enrollments)
        self.refresh_index_mock.assert_called_with()

    def test_failed_automatic_email(self):
        """
        If we fail to send automatic email for one enrollment we should still send them for other enrollments
        """
        enrollments = [ProgramEnrollmentFactory.create() for _ in range(2)]
        enrollment_ids = [enrollment.id for enrollment in enrollments]
        self.send_automatic_emails_mock.side_effect = fail_first()
        index_program_enrolled_users(enrollment_ids)
        self._assert_indexed_enrollment_ids(enrollment_ids)
        self._assert_enrollments_processed(enrollments)
        # The first email call raises; every enrollment must still be attempted exactly once
        assert self.send_automatic_emails_mock.call_count == len(enrollments)
        assert self.update_percolate_memberships_mock.call_count == len(enrollments)
        self.refresh_index_mock.assert_called_with()

    def test_failed_update_percolate_memberships(self):
        """
        If we fail to update percolate memberships for one enrollment we should still update it for other enrollments
        """
        enrollments = [ProgramEnrollmentFactory.create() for _ in range(2)]
        enrollment_ids = [enrollment.id for enrollment in enrollments]
        self.update_percolate_memberships_mock.side_effect = fail_first()
        index_program_enrolled_users(enrollment_ids)
        self._assert_indexed_enrollment_ids(enrollment_ids)
        self._assert_enrollments_processed(enrollments)
        # The first percolate update raises; every enrollment must still be attempted exactly once
        assert self.send_automatic_emails_mock.call_count == len(enrollments)
        assert self.update_percolate_memberships_mock.call_count == len(enrollments)
        self.refresh_index_mock.assert_called_with()
def test_start_recreate_index(mocker, mocked_celery):
    """
    recreate_index should recreate the opensearch index and reindex all data with it.

    Enrollments and percolate queries are bulk-indexed into the new backing indices in chunks of
    OPENSEARCH_INDEXING_CHUNK_SIZE, and the task replaces itself with a celery chain.
    """
    # Patch instead of assigning: pytest-mock restores the original value at teardown, whereas a
    # plain assignment to django.conf.settings would leak this chunk size into every later test.
    mocker.patch.object(settings, "OPENSEARCH_INDEXING_CHUNK_SIZE", 2)
    # 4 objects of each kind with chunk size 2 -> 2 bulk-index subtasks of each kind
    enrollments = sorted(ProgramEnrollmentFactory.create_batch(4), key=lambda enrollment: enrollment.id)
    percolates = sorted(PercolateQueryFactory.create_batch(4), key=lambda percolate: percolate.id)
    index_enrollments_mock = mocker.patch("search.tasks.bulk_index_program_enrollments", autospec=True)
    index_percolates_mock = mocker.patch("search.tasks.bulk_index_percolate_queries", autospec=True)
    test_backing_indices = create_backing_indices()
    enrollment_public_index = test_backing_indices[0][0]
    enrollment_private_index = test_backing_indices[1][0]
    percolate_index = test_backing_indices[2][0]
    finish_recreate_index_mock = mocker.patch(
        "search.tasks.finish_recreate_index", autospec=True
    )

    # The mocked Task.replace raises, which stops the task right where it hands off to the chain
    with pytest.raises(mocked_celery.replace_exception_class):
        start_recreate_index(test_backing_indices)

    # Celery's 'group' function takes a generator as an argument. In order to make assertions about the items
    # in that generator, 'list' is being called to force iteration through all of those items.
    list(mocked_celery.group.call_args[0][0])
    assert mocked_celery.group.call_count == 1
    finish_recreate_index_mock.s.assert_called_once_with(test_backing_indices)

    assert index_enrollments_mock.si.call_count == 2
    index_enrollments_mock.si.assert_any_call([enrollments[0].id, enrollments[1].id], enrollment_public_index,
                                              enrollment_private_index)
    index_enrollments_mock.si.assert_any_call([enrollments[2].id, enrollments[3].id], enrollment_public_index,
                                              enrollment_private_index)

    assert index_percolates_mock.si.call_count == 2
    index_percolates_mock.si.assert_any_call([percolates[0].id, percolates[1].id], percolate_index)
    index_percolates_mock.si.assert_any_call([percolates[2].id, percolates[3].id], percolate_index)

    # Task.replace is autospecced, so positional arg 0 is the task instance and arg 1 is the new signature
    assert mocked_celery.replace.call_count == 1
    assert mocked_celery.replace.call_args[0][1] == mocked_celery.chain.return_value
def test_bulk_index_program_enrollments(mocker):
    """
    bulk_index_program_enrollments should index the given enrollments with a single call to
    _index_program_enrolled_users
    """
    # Create the enrollments before patching so any indexing done on creation is not counted
    ids_to_index = [enrollment.id for enrollment in ProgramEnrollmentFactory.create_batch(2)]
    indexer_mock = mocker.patch("search.tasks._index_program_enrolled_users", autospec=True)
    backing_indices = create_backing_indices()
    public_index, private_index = backing_indices[0][0], backing_indices[1][0]

    bulk_index_program_enrollments(ids_to_index, public_index, private_index)

    assert indexer_mock.call_count == 1
def test_bulk_index_percolate_queries(mocker):
    """
    bulk_index_percolate_queries should index the given percolate queries with a single call to _index_chunks
    """
    query_ids = [query.id for query in PercolateQueryFactory.create_batch(2)]
    chunk_indexer_mock = mocker.patch("search.tasks._index_chunks", autospec=True)
    target_index = create_backing_indices()[2][0]

    bulk_index_percolate_queries(query_ids, target_index)

    assert chunk_indexer_mock.call_count == 1
@pytest.mark.parametrize("with_error", [True, False])
def test_finish_recreate_index(mocker, with_error):
    """
    finish_recreate_index should delete the backing indices exactly once in every case.

    When the results contain an error it raises ReindexException; otherwise it calls refresh_index
    once per backing index.
    """
    refresh_mock = mocker.patch("search.tasks.refresh_index", autospec=True)
    delete_mock = mocker.patch("search.tasks.delete_backing_indices", autospec=True)
    task_results = ["error"] if with_error else []
    backing_indices = create_backing_indices()

    if not with_error:
        finish_recreate_index(task_results, backing_indices)
        assert refresh_mock.call_count == len(backing_indices)
    else:
        with pytest.raises(ReindexException):
            finish_recreate_index(task_results, backing_indices)

    assert delete_mock.call_count == 1