Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

identify and prefetch N+1 queries in search/all for learner pathways #4488

Merged
merged 5 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions course_discovery/apps/api/v1/tests/test_views/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,28 +763,28 @@ def test_results_include_aggregation_key(self):
)
assert expected == actual

@ddt.data(True, False)
def test_learner_pathway_feature_flag(self, include_learner_pathways):
@ddt.data((True, 10, 8), (False, 0, 4))
@ddt.unpack
def test_learner_pathway_feature_flag(self, include_learner_pathways, expected_result_count, expected_query_count):
""" Verify the include_learner_pathways feature flag works as expected."""
LearnerPathwayStepFactory(pathway__partner=self.partner)
LearnerPathwayStepFactory.create_batch(10, pathway__partner=self.partner)
pathways = LearnerPathway.objects.all()
assert pathways.count() == 1
assert pathways.count() == 10
query = {
'include_learner_pathways': include_learner_pathways,
}

response = self.get_response(
query,
self.list_path
)
with self.assertNumQueries(expected_query_count):
response = self.get_response(query, self.list_path)

assert response.status_code == 200
response_data = response.json()

assert response_data['count'] == expected_result_count

if include_learner_pathways:
assert response_data['count'] == 1
assert response_data['results'][0] == self.serialize_learner_pathway_search(pathways[0])
else:
assert response_data['count'] == 0
for pathway in pathways:
assert self.serialize_learner_pathway_search(pathway) in response.data['results']


class LimitedAggregateSearchViewSetTests(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from django.conf import settings
from django.db.models import Prefetch
from django_elasticsearch_dsl import Index, fields

from course_discovery.apps.course_metadata.choices import CourseRunStatus
from course_discovery.apps.course_metadata.models import CourseRun
from course_discovery.apps.learner_pathway.choices import PathwayStatus
from course_discovery.apps.learner_pathway.models import LearnerPathway

Expand Down Expand Up @@ -50,10 +53,26 @@ def prepare_partner(self, obj):
def prepare_published(self, obj):
return obj.status == PathwayStatus.Active

def get_queryset(self, excluded_restriction_types=None): # pylint: disable=unused-argument
def get_queryset(self, excluded_restriction_types=None):
if excluded_restriction_types is None:
excluded_restriction_types = []

course_runs = CourseRun.objects.filter(
status=CourseRunStatus.Published
).exclude(
restricted_run__restriction_type__in=excluded_restriction_types
)

return super().get_queryset().prefetch_related(
'steps', 'steps__learnerpathwaycourse_set', 'steps__learnerpathwayprogram_set',
'steps__learnerpathwayblock_set',
'steps',
Prefetch(
'steps__learnerpathwaycourse_set__course__course_runs',
queryset=course_runs
),
Prefetch(
'steps__learnerpathwayprogram_set__program__courses__course_runs',
queryset=course_runs
)
)

def prepare_skill_names(self, obj):
Expand Down
12 changes: 2 additions & 10 deletions course_discovery/apps/learner_pathway/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
"""
from rest_framework import serializers

from course_discovery.apps.api.utils import get_excluded_restriction_types
from course_discovery.apps.course_metadata.choices import CourseRunStatus
from course_discovery.apps.learner_pathway import models


Expand All @@ -20,12 +18,7 @@ class Meta:
fields = ('key', 'course_runs')

def get_course_runs(self, obj):
excluded_restriction_types = get_excluded_restriction_types(self.context['request'])
return list(obj.course.course_runs.filter(
status=CourseRunStatus.Published
).exclude(
restricted_run__restriction_type__in=excluded_restriction_types
).values('key'))
return [{'key': course_run.key} for course_run in obj.course.course_runs.all()]


class LearnerPathwayCourseSerializer(LearnerPathwayCourseMinimalSerializer):
Expand Down Expand Up @@ -87,8 +80,7 @@ def get_card_image_url(self, step):
return program.card_image_url

def get_courses(self, obj):
excluded_restriction_types = get_excluded_restriction_types(self.context['request'])
return obj.get_linked_courses_and_course_runs(excluded_restriction_types=excluded_restriction_types)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are these removed from serializer? This should be independent from document changes.

return obj.get_linked_courses_and_course_runs()


class LearnerPathwayBlockSerializer(serializers.ModelSerializer):
Expand Down
Loading
Loading