Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(API): Imports: nouvelle class CSVImportApiView générique (qui check le format, la taille max, et si déjà uploadé) #4936

Open
wants to merge 5 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions api/filters/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from django.db.models import F
from django.db.models.constants import LOOKUP_SEP
from djangorestframework_camel_case.settings import api_settings
from djangorestframework_camel_case.util import camel_to_underscore
from rest_framework import filters


class MaCantineOrderingFilter(filters.OrderingFilter):
    """
    Ordering filter accepting camelCase query parameters (converted to
    snake_case before validation). More info:
    https://github.com/vbabiy/djangorestframework-camel-case/issues/87

    Null handling differs from the database default: rows with a null
    value in the ordering field are sorted as if null were the smallest
    value (nulls first when ascending, nulls last when descending).
    """

    def filter_queryset(self, request, queryset, view):
        # Resolve the requested ordering (camelCase already normalised
        # by get_ordering below).
        ordering = self.get_ordering(request, queryset, view)
        if not ordering:
            return queryset

        # Build F() expressions with explicit null placement instead of
        # plain "-field"/"field" strings.
        expressions = [
            F(name[1:]).desc(nulls_last=True)
            if name.startswith("-")
            else F(name).asc(nulls_first=True)
            for name in ordering
        ]
        return queryset.order_by(*expressions)

    def get_ordering(self, request, queryset, view):
        raw_params = request.query_params.get(self.ordering_param)
        if raw_params:
            # Convert each comma-separated camelCase field to snake_case.
            snake_fields = [
                camel_to_underscore(
                    part.strip(),
                    **api_settings.JSON_UNDERSCOREIZE,
                )
                for part in raw_params.split(",")
            ]
            valid_fields = self.remove_invalid_fields(queryset, snake_fields, view, request)
            if valid_fields:
                return valid_fields

        # No (valid) ordering requested: fall back to the view's default.
        return self.get_default_ordering(view)


class UnaccentSearchFilter(filters.SearchFilter):
    """
    Search filter that matches text regardless of accents by inserting
    Postgres's ``unaccent`` transform into every search lookup.
    """

    def construct_search(self, field_name, queryset):
        # A leading special character ('^', '=', '@', '$') selects the
        # lookup type; otherwise default to a case-insensitive contains.
        prefix_lookup = self.lookup_prefixes.get(field_name[0])
        if not prefix_lookup:
            lookup = "icontains"
        else:
            lookup = prefix_lookup
            field_name = field_name[1:]
        # e.g. "name" + "icontains" -> "name__unaccent__icontains"
        return LOOKUP_SEP.join([field_name, "unaccent", lookup])
3 changes: 1 addition & 2 deletions api/views/blog.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.response import Response

from api.filters.utils import UnaccentSearchFilter
from api.serializers import BlogPostSerializer
from data.models import BlogPost

from .utils import UnaccentSearchFilter

logger = logging.getLogger(__name__)


Expand Down
5 changes: 2 additions & 3 deletions api/views/canteen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from rest_framework.views import APIView

from api.exceptions import DuplicateException
from api.filters.utils import MaCantineOrderingFilter, UnaccentSearchFilter
from api.permissions import (
IsAuthenticated,
IsAuthenticatedOrTokenHasResourceScope,
Expand All @@ -63,7 +64,7 @@
PublicCanteenSerializer,
SatelliteCanteenSerializer,
)
from api.views.utils import update_change_reason_with_auth
from api.views.utils import camelize, update_change_reason_with_auth
from common.utils import get_token_sirene, send_mail
from data.department_choices import Department
from data.models import (
Expand All @@ -80,8 +81,6 @@
fetch_geo_data_from_api_insee_sirene_by_siret,
)

from .utils import MaCantineOrderingFilter, UnaccentSearchFilter, camelize

logger = logging.getLogger(__name__)
redis = r.from_url(settings.REDIS_URL, decode_responses=True)

Expand Down
17 changes: 6 additions & 11 deletions api/views/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from rest_framework.exceptions import NotFound, PermissionDenied
from rest_framework.generics import CreateAPIView, ListAPIView, UpdateAPIView
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.views import APIView

from api.exceptions import DuplicateException
from api.permissions import (
Expand All @@ -23,7 +22,7 @@
IsCanteenManager,
)
from api.serializers import DiagnosticAndCanteenSerializer, ManagerDiagnosticSerializer
from api.views.utils import update_change_reason_with_auth
from api.views.utils import CSVImportApiView, update_change_reason_with_auth
from common.utils import send_mail
from data.models import Canteen, Teledeclaration
from data.models.diagnostic import Diagnostic
Expand Down Expand Up @@ -93,13 +92,14 @@ def perform_update(self, serializer):
update_change_reason_with_auth(self, diagnostic)


class EmailDiagnosticImportFileView(APIView):
class EmailDiagnosticImportFileView(CSVImportApiView):
permission_classes = [IsAuthenticated]

def post(self, request):
try:
file = request.data["file"]
self._verify_file_size(file)
self.file = request.data["file"]
super()._verify_file_size()
super()._verify_file_format()
email = request.data.get("email", request.user.email).strip()
context = {
"from": email,
Expand All @@ -111,7 +111,7 @@ def post(self, request):
to=[settings.CONTACT_EMAIL],
reply_to=[email],
template="unusual_diagnostic_import_file",
attachments=[(file.name, file.read(), file.content_type)],
attachments=[(self.file.name, self.file.read(), self.file.content_type)],
context=context,
)
except ValidationError as e:
Expand All @@ -127,11 +127,6 @@ def post(self, request):

return HttpResponse()

@staticmethod
def _verify_file_size(file):
if file.size > settings.CSV_IMPORT_MAX_SIZE:
raise ValidationError("Ce fichier est trop grand, merci d'utiliser un fichier de moins de 10Mo")


class DiagnosticsToTeledeclarePagination(LimitOffsetPagination):
default_limit = 100
Expand Down
21 changes: 4 additions & 17 deletions api/views/diagnosticimport.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from decimal import Decimal, InvalidOperation

import requests
from django.conf import settings
from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError
from django.core.validators import validate_email
Expand All @@ -16,11 +15,11 @@
from django.http import JsonResponse
from rest_framework import status
from rest_framework.exceptions import PermissionDenied
from rest_framework.views import APIView
from simple_history.utils import update_change_reason

from api.permissions import IsAuthenticated
from api.serializers import FullCanteenSerializer
from api.views.utils import CSVImportApiView
from data.models import Canteen, ImportFailure, ImportType, Sector
from data.models.diagnostic import Diagnostic
from data.models.teledeclaration import Teledeclaration
Expand All @@ -31,7 +30,7 @@
logger = logging.getLogger(__name__)


class ImportDiagnosticsView(ABC, APIView):
class ImportDiagnosticsView(ABC, CSVImportApiView):
permission_classes = [IsAuthenticated]
value_error_regex = re.compile(r"Field '(.+)' expected .+? got '(.+)'.")
annotated_sectors = Sector.objects.annotate(name_lower=Lower("name"))
Expand Down Expand Up @@ -80,8 +79,8 @@ def post(self, request):
try:
with transaction.atomic():
self.file = request.data["file"]
ImportDiagnosticsView._verify_file_format(self.file)
ImportDiagnosticsView._verify_file_size(self.file)
super()._verify_file_size()
super()._verify_file_format()
self._process_file(self.file)

if self.errors:
Expand Down Expand Up @@ -113,18 +112,6 @@ def _log_error(self, message, level="warning"):
import_type=self.import_type,
)

@staticmethod
def _verify_file_format(file):
if file.content_type != "text/csv" and file.content_type != "text/tab-separated-values":
raise ValidationError(
f"Ce fichier est au format {file.content_type}, merci d'exporter votre fichier au format CSV et réessayer."
)

@staticmethod
def _verify_file_size(file):
if file.size > settings.CSV_IMPORT_MAX_SIZE:
raise ValidationError("Ce fichier est trop grand, merci d'utiliser un fichier de moins de 10Mo")

def check_admin_values(self, header):
is_admin_import = any("admin_" in column for column in header)
if is_admin_import and not self.request.user.is_staff:
Expand Down
3 changes: 1 addition & 2 deletions api/views/partner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.response import Response

from api.filters.utils import UnaccentSearchFilter
from api.serializers import (
PartnerContactSerializer,
PartnerSerializer,
PartnerShortSerializer,
)
from data.models import Partner

from .utils import UnaccentSearchFilter

logger = logging.getLogger(__name__)


Expand Down
3 changes: 1 addition & 2 deletions api/views/purchase.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from rest_framework.response import Response
from rest_framework.views import APIView

from api.filters.utils import MaCantineOrderingFilter, UnaccentSearchFilter
from api.permissions import IsAuthenticated, IsCanteenManager, IsLinkedCanteenManager
from api.serializers import (
PurchaseExportSerializer,
Expand All @@ -24,8 +25,6 @@
)
from data.models import Canteen, Diagnostic, Purchase

from .utils import MaCantineOrderingFilter, UnaccentSearchFilter

logger = logging.getLogger(__name__)


Expand Down
27 changes: 8 additions & 19 deletions api/views/purchaseimport.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import csv
import hashlib
import io
import json
import logging
Expand All @@ -13,19 +12,18 @@
from django.http import JsonResponse
from rest_framework import status
from rest_framework.exceptions import PermissionDenied
from rest_framework.views import APIView

from api.permissions import IsAuthenticated
from api.serializers import PurchaseSerializer
from api.views.utils import CSVImportApiView
from data.models import Canteen, ImportFailure, ImportType, Purchase

from .diagnosticimport import ImportDiagnosticsView
from .utils import camelize, decode_bytes, normalise_siret

logger = logging.getLogger(__name__)


class ImportPurchasesView(APIView):
class ImportPurchasesView(CSVImportApiView):
permission_classes = [IsAuthenticated]
max_error_items = 30

Expand All @@ -50,8 +48,12 @@ def post(self, request):
logger.info("Purchase bulk import started")
try:
self.file = request.data["file"]
self._verify_file_size()
ImportDiagnosticsView._verify_file_format(self.file)
super()._verify_file_size()
super()._verify_file_format()

self.file_digest = super()._get_file_digest()
self._check_duplication()

with transaction.atomic():
self._process_file()

Expand All @@ -60,10 +62,6 @@ def post(self, request):
if self.errors:
raise IntegrityError()

# The duplication check is called after the processing. The cost of eventually processing
# the file for nothing appears to be smaller than read the file twice.
self._check_duplication()

# Update all purchases's import source with file digest
Purchase.objects.filter(import_source=self.tmp_id).update(import_source=self.file_digest)

Expand Down Expand Up @@ -100,13 +98,10 @@ def _log_error(self, message, level="warning"):
)

def _process_file(self):
file_hash = hashlib.md5()
chunk = []
read_header = True
row_count = 1
for row in self.file:
file_hash.update(row)

# Sniffing 1st line
if read_header:
# decode header, discarding encoding result that might not be accurate without more data
Expand All @@ -133,12 +128,6 @@ def _process_file(self):
if len(chunk) > 0:
self._process_chunk(chunk)

self.file_digest = file_hash.hexdigest()

def _verify_file_size(self):
if self.file.size > settings.CSV_IMPORT_MAX_SIZE:
raise ValidationError("Ce fichier est trop grand, merci d'utiliser un fichier de moins de 10Mo")

def _decode_chunk(self, chunk_list):
if self.encoding_detected is None:
chunk = b"".join(chunk_list)
Expand Down
Loading
Loading