Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly specify DB connection when we need to, and don't when we don't need to #8224

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path

from django.core.management.base import BaseCommand
from django.db import connection
from polling_stations.db_routers import get_principal_db_connection


class Command(BaseCommand):
Expand All @@ -16,7 +16,7 @@ def add_arguments(self, parser):
)

def handle(self, *args, **kwargs):
self.cursor = connection.cursor()
self.cursor = get_principal_db_connection().cursor()
# Set where we'll write the join query to.
if kwargs["destination"]:
self.dst = Path(kwargs["destination"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from addressbase.models import Address, UprnToCouncil
from django.core.management.base import BaseCommand
from django.db import connection
from polling_stations.db_routers import get_principal_db_connection


class Command(BaseCommand):
Expand All @@ -22,7 +22,7 @@ def handle(self, *args, **kwargs):
if not self.path.exists():
raise FileNotFoundError(f"No csv found at {kwargs['path']}")

cursor = connection.cursor()
cursor = get_principal_db_connection().cursor()
self.stdout.write("clearing existing data..")
cursor.execute("TRUNCATE TABLE %s;" % (self.table_name))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def handle(self, *args, **options):
uprntocouncil_updater.build_temp_indexes()

# Perform the table swaps in a single transaction
with transaction.atomic():
with transaction.atomic(using=database_name):
self.stdout.write("Starting atomic transaction for table swaps...")

# Drop all foreign keys first
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from unittest.mock import patch

from django.core.management import call_command
from django.db import connection
from django.test import TestCase, TransactionTestCase
from uk_geo_utils.base_importer import BaseImporter

Expand All @@ -14,10 +13,11 @@
from councils.tests.factories import CouncilFactory
from pollingstations.models import PollingStation
from pollingstations.tests.factories import PollingStationFactory
from polling_stations.db_routers import get_principal_db_connection


def get_primary_key_name(table):
with connection.cursor() as cursor:
with get_principal_db_connection().cursor() as cursor:
cursor.execute(f"""
SELECT conname
FROM pg_constraint
Expand All @@ -32,7 +32,7 @@ def get_primary_key_name(table):


def get_foreign_key_names(table):
with connection.cursor() as cursor:
with get_principal_db_connection().cursor() as cursor:
cursor.execute(f"""
SELECT conname
FROM pg_constraint
Expand All @@ -44,7 +44,7 @@ def get_foreign_key_names(table):

class HelpersTest(TestCase):
def setUp(self):
with connection.cursor() as cursor:
with get_principal_db_connection().cursor() as cursor:
# Create table with named primary key
cursor.execute("""
CREATE TABLE foo (
Expand All @@ -53,7 +53,7 @@ def setUp(self):
CONSTRAINT foo_primary_key PRIMARY KEY (id)
);
""")
with connection.cursor() as cursor:
with get_principal_db_connection().cursor() as cursor:
# Create table with named foreign key
cursor.execute("""
CREATE TABLE bar (
Expand All @@ -64,7 +64,7 @@ def setUp(self):
""")

def tearDown(self):
with connection.cursor() as cursor:
with get_principal_db_connection().cursor() as cursor:
cursor.execute("""
DROP TABLE IF EXISTS foo CASCADE;
DROP TABLE IF EXISTS bar CASCADE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@
from django.db import DEFAULT_DB_ALIAS, transaction
from requests.exceptions import HTTPError
from retry import retry
from polling_stations.db_routers import get_principal_db_name


from polling_stations.settings.constants.councils import (
COUNCIL_ID_FIELD,
NIR_IDS,
WELSH_COUNCIL_NAMES,
)

DB_NAME = get_principal_db_name()


def union_areas(a1, a2):
if not a1:
Expand Down Expand Up @@ -243,7 +247,7 @@ def import_councils_from_ec(self):

council.save()

@transaction.atomic
@transaction.atomic(using=DB_NAME)
def handle(self, **options):
"""
Manually run system checks for the
Expand Down
7 changes: 5 additions & 2 deletions polling_stations/apps/councils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
from pollingstations.models import PollingStation

from polling_stations.i18n.cy import WelshNameMutationMixin
from polling_stations.db_routers import get_principal_db_name

DB_NAME = get_principal_db_name()


class UnsafeToDeleteCouncil(Exception):
pass


class CouncilQueryset(models.QuerySet):
@transaction.atomic
@transaction.atomic(using=DB_NAME)
def delete(self, force_cascade=False):
if force_cascade:
return super().delete()
Expand Down Expand Up @@ -113,7 +116,7 @@ class Meta:
def save(
self, force_insert=False, force_update=False, using=None, update_fields=None
):
with transaction.atomic():
with transaction.atomic(using=DB_NAME):
new = self._state.adding
super().save(
force_insert=force_insert,
Expand Down
7 changes: 5 additions & 2 deletions polling_stations/apps/data_importers/base_importers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from file_uploads.models import File, Upload
from pollingstations.models import PollingDistrict, PollingStation
from uk_geo_utils.helpers import Postcode
from polling_stations.db_routers import get_principal_db_name

DB_NAME = get_principal_db_name()


class CsvMixin:
Expand Down Expand Up @@ -116,7 +119,7 @@ def add_arguments(self, parser):
)

def teardown(self, council):
with transaction.atomic():
with transaction.atomic(using=DB_NAME):
super().teardown(council)
PollingStation.objects.filter(council=council).delete()
PollingDistrict.objects.filter(council=council).delete()
Expand Down Expand Up @@ -296,7 +299,7 @@ def handle(self, *args, **kwargs):

self.base_folder_path = self.get_base_folder_path()

with transaction.atomic():
with transaction.atomic(using=DB_NAME):
self.import_data()
self.record_import_event()
self.council.update_all_station_visibilities_from_events(self.election_dates)
Expand Down
151 changes: 66 additions & 85 deletions polling_stations/apps/data_importers/data_quality_report.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from addressbase.models import UprnToCouncil
from councils.models import Council
from django.db import connection
from django.db.models import Q
from pollingstations.models import PollingDistrict, PollingStation
from rich.console import Console
Expand Down Expand Up @@ -45,38 +44,32 @@ def get_stations_without_district_id(self):
).count()

def get_stations_with_valid_district_id_ref(self):
cursor = connection.cursor()
cursor.execute(
"""
SELECT COUNT(*) FROM pollingstations_pollingstation
WHERE polling_district_id IN
(SELECT internal_council_id FROM pollingstations_pollingdistrict
WHERE council_id IN %s)
AND council_id IN %s
AND polling_district_id != ''
AND polling_district_id IS NOT NULL;
""",
[tuple(self.councils), tuple(self.councils)],
return (
PollingStation.objects.filter(
polling_district_id__in=PollingDistrict.objects.filter(
council_id__in=self.councils
).values_list("internal_council_id", flat=True),
council_id__in=self.councils,
)
.exclude(polling_district_id__isnull=True)
.exclude(polling_district_id="")
.count()
)
results = cursor.fetchall()
return results[0][0]

def get_stations_with_invalid_district_id_ref(self):
cursor = connection.cursor()
cursor.execute(
"""
SELECT COUNT(*) FROM pollingstations_pollingstation
WHERE polling_district_id NOT IN
(SELECT internal_council_id FROM pollingstations_pollingdistrict
WHERE council_id IN %s)
AND council_id IN %s
AND polling_district_id != ''
AND polling_district_id IS NOT NULL;
""",
[tuple(self.councils), tuple(self.councils)],
return (
PollingStation.objects.filter(
~Q(
polling_district_id__in=PollingDistrict.objects.filter(
council_id__in=self.councils
).values_list("internal_council_id", flat=True)
),
council_id__in=self.councils,
)
.exclude(polling_district_id__isnull=True)
.exclude(polling_district_id="")
.count()
)
results = cursor.fetchall()
return results[0][0]

def get_stations_with_point(self):
return PollingStation.objects.filter(
Expand Down Expand Up @@ -166,38 +159,32 @@ def get_districts_without_station_id(self):
).count()

def get_districts_with_valid_station_id_ref(self):
cursor = connection.cursor()
cursor.execute(
"""
SELECT COUNT(*) FROM pollingstations_pollingdistrict
WHERE polling_station_id IN
(SELECT internal_council_id FROM pollingstations_pollingstation
WHERE council_id IN %s)
AND council_id IN %s
AND polling_station_id != ''
AND polling_station_id IS NOT NULL;
""",
[tuple(self.councils), tuple(self.councils)],
return (
PollingDistrict.objects.filter(
polling_station_id__in=PollingStation.objects.filter(
council_id__in=self.councils
).values_list("internal_council_id", flat=True),
council_id__in=self.councils,
)
.exclude(polling_station_id__isnull=True)
.exclude(polling_station_id="")
.count()
)
results = cursor.fetchall()
return results[0][0]

def get_districts_with_invalid_station_id_ref(self):
cursor = connection.cursor()
cursor.execute(
"""
SELECT COUNT(*) FROM pollingstations_pollingdistrict
WHERE polling_station_id NOT IN
(SELECT internal_council_id FROM pollingstations_pollingstation
WHERE council_id IN %s)
AND council_id IN %s
AND polling_station_id != ''
AND polling_station_id IS NOT NULL;
""",
[tuple(self.councils), tuple(self.councils)],
return (
PollingDistrict.objects.filter(
~Q(
polling_station_id__in=PollingStation.objects.filter(
council_id__in=self.councils
).values_list("internal_council_id", flat=True)
),
council_id__in=self.councils,
)
.exclude(polling_station_id__isnull=True)
.exclude(polling_station_id="")
.count()
)
results = cursor.fetchall()
return results[0][0]

def generate_counts(self):
districts = PollingDistrict.objects.filter(council_id__in=self.councils)
Expand Down Expand Up @@ -261,38 +248,32 @@ def get_addresses_without_station_id(self):
).count()

def get_addresses_with_valid_station_id_ref(self):
cursor = connection.cursor()
cursor.execute(
"""
SELECT COUNT(*) FROM addressbase_uprntocouncil
WHERE polling_station_id IN
(SELECT internal_council_id FROM pollingstations_pollingstation
WHERE council_id IN %s)
AND lad IN %s
AND polling_station_id != ''
AND polling_station_id IS NOT NULL;
""",
[tuple(self.councils), tuple(self.gss_codes)],
return (
UprnToCouncil.objects.filter(
polling_station_id__in=PollingStation.objects.filter(
council_id__in=self.councils
).values_list("internal_council_id", flat=True),
lad__in=self.gss_codes,
)
.exclude(polling_station_id__isnull=True)
.exclude(polling_station_id="")
.count()
)
results = cursor.fetchall()
return results[0][0]

def get_addresses_with_invalid_station_id_ref(self):
cursor = connection.cursor()
cursor.execute(
"""
SELECT COUNT(*) FROM addressbase_uprntocouncil
WHERE polling_station_id NOT IN
(SELECT internal_council_id FROM pollingstations_pollingstation
WHERE council_id IN %s)
AND lad IN %s
AND polling_station_id != ''
AND polling_station_id IS NOT NULL;
""",
[tuple(self.councils), tuple(self.gss_codes)],
return (
UprnToCouncil.objects.filter(
~Q(
polling_station_id__in=PollingStation.objects.filter(
council_id__in=self.councils
).values_list("internal_council_id", flat=True)
),
lad__in=self.gss_codes,
)
.exclude(polling_station_id__isnull=True)
.exclude(polling_station_id="")
.count()
)
results = cursor.fetchall()
return results[0][0]


# generate all the stats
Expand Down
Loading