Skip to content

Commit

Permalink
add file max size validation
Browse files Browse the repository at this point in the history
  • Loading branch information
henrinie-nc committed Jan 21, 2025
1 parent a4fd0f3 commit 6eb8c62
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 1 deletion.
114 changes: 113 additions & 1 deletion utils/tests/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.utils.datastructures import MultiValueDict
from rest_framework.exceptions import ValidationError

from utils.viewsets.mixins import FileExtensionFileMixin
from utils.viewsets.mixins import MAX_FILE_SIZE_BYTES, FileExtensionFileMixin


@pytest.fixture
Expand Down Expand Up @@ -199,3 +199,115 @@ def test_file_extension_mixin_update_with_invalid_files(mixin):
)
with pytest.raises(ValidationError):
mixin.update(request)


def test_validate_file_size_valid(mixin):
bytes_under_max = b"a" * (MAX_FILE_SIZE_BYTES - 1)
files = MultiValueDict(
{
"file": [
SimpleUploadedFile(
"test.pdf",
bytes_under_max,
content_type="application/pdf",
)
]
}
)
try:
mixin._validate_file_size(files)
except ValidationError:
pytest.fail("_validate_file_size() raised ValidationError unexpectedly!")


def test_validate_file_size_exceeds_limit(mixin):
bytes_over_max = b"a" * (MAX_FILE_SIZE_BYTES + 1)
files = MultiValueDict(
{
"file": [
SimpleUploadedFile(
"test.pdf",
bytes_over_max,
content_type="application/pdf",
)
]
}
)
with pytest.raises(ValidationError):
mixin._validate_file_size(files)


def test_file_extension_mixin_create_with_valid_file_size(mixin):
bytes_under_max = b"a" * (MAX_FILE_SIZE_BYTES - 1)
request = MagicMock()
request.FILES = MultiValueDict(
{
"file": [
SimpleUploadedFile(
"test.pdf",
bytes_under_max,
content_type="application/pdf",
)
]
}
)
mixin.create = MagicMock(return_value="created")
response = mixin.create(request)
mixin.create.assert_called_once_with(request)
assert response == "created"


def test_file_extension_mixin_create_with_too_large_file_size(mixin):
bytes_over_max = b"a" * (MAX_FILE_SIZE_BYTES + 1)
request = MagicMock()
request.FILES = MultiValueDict(
{
"file": [
SimpleUploadedFile(
"test.pdf",
bytes_over_max,
content_type="application/pdf",
)
]
}
)
with pytest.raises(ValidationError):
mixin.create(request)


def test_file_extension_mixin_update_with_valid_file_size(mixin):
bytes_under_max = b"a" * (MAX_FILE_SIZE_BYTES - 1)
request = MagicMock()
request.FILES = MultiValueDict(
{
"file": [
SimpleUploadedFile(
"test.pdf",
bytes_under_max,
content_type="application/pdf",
)
]
}
)
mixin.update = MagicMock(return_value="updated")
response = mixin.update(request)
mixin.update.assert_called_once_with(request)
assert response == "updated"


def test_file_extension_mixin_update_with_invalid_file_size(mixin):
bytes_over_max = b"a" * (MAX_FILE_SIZE_BYTES + 1)
request = MagicMock()
request.FILES = MultiValueDict(
{
"file": [
SimpleUploadedFile(
"test.pdf",
bytes_over_max,
content_type="application/pdf",
)
]
}
)
with pytest.raises(ValidationError):
mixin.update(request)
20 changes: 20 additions & 0 deletions utils/viewsets/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from rest_framework.exceptions import ValidationError
from rest_framework.response import Response

MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024


class FileDownloadMixin:
@action(methods=["get"], detail=True)
Expand Down Expand Up @@ -128,17 +131,34 @@ def _validate_file_extensions(self, files: MultiValueDict) -> NoReturn:
if ext not in allowed_extensions:
raise ValidationError(_(f"File extension '.{ext}' is not allowed."))

def _validate_file_size(self, files: MultiValueDict) -> NoReturn:
"""
Validate that the size of files do not exceed set maximum size.
Raises:
ValidationError: If a file is too large in size.
"""
for file in files.values():
if file.size > MAX_FILE_SIZE_BYTES:
raise ValidationError(
_(
f"File '{file.name}' exceeds maximum file size of {MAX_FILE_SIZE_MB} MB."
)
)

def create(self, request, *args, **kwargs):
files: MultiValueDict = getattr(request, "FILES")
if files and len(files) > 0:
self._validate_file_extensions(files)
self._validate_file_size(files)

return super().create(request, *args, **kwargs)

def update(self, request, *args, **kwargs):
files: MultiValueDict = getattr(request, "FILES")
if files and len(files) > 0:
self._validate_file_extensions(files)
self._validate_file_size(files)

return super().update(request, *args, **kwargs)

Expand Down

0 comments on commit 6eb8c62

Please sign in to comment.