Skip to content

Commit

Permalink
feat: DIA-771: Add validate methods for glob patterns if it matches at least something (#5178)
Browse files Browse the repository at this point in the history

* feat: DIA-771: Add validate methods for glob patterns if it matches at least something

* Update utils.py

* Clean the code and add comments

* Update utils.py

* Implement changes for looking through all records

* rename new settings

* update error msg for consistency with success message

* run blue

---------

Co-authored-by: hakan458 <[email protected]>
  • Loading branch information
KonstantinKorotaev and hakan458 authored Jan 8, 2024
1 parent 7b628fa commit 24e02f6
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 0 deletions.
3 changes: 3 additions & 0 deletions label_studio/core/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,6 @@ def collect_versions_dummy(**kwargs):
CSP_INCLUDE_NONCE_IN = ['script-src', 'default-src']

MIDDLEWARE.append('core.middleware.HumanSignalCspMiddleware')

# Page size used when scanning a cloud-storage bucket/container to check whether
# an import pattern matches at least one object (see validate_pattern in the
# azure_blob/gcs/s3 utils modules).
# NOTE(review): get_env presumably returns a string when the env var is set and
# the int default otherwise — verify downstream SDK calls accept both.
CLOUD_STORAGE_CHECK_FOR_RECORDS_PAGE_SIZE = get_env('CLOUD_STORAGE_CHECK_FOR_RECORDS_PAGE_SIZE', 10000)
# Per-request timeout for the same record-existence scan (passed as `timeout`
# to the Azure SDK; units are presumably seconds — TODO confirm against the SDK).
CLOUD_STORAGE_CHECK_FOR_RECORDS_TIMEOUT = get_env('CLOUD_STORAGE_CHECK_FOR_RECORDS_TIMEOUT', 60)
40 changes: 40 additions & 0 deletions label_studio/io_storages/azure_blob/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import fnmatch
import logging
import re

from azure.storage.blob import BlobServiceClient
from core.utils.params import get_env
from django.conf import settings

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,3 +45,40 @@ def get_blob_metadata(cls, url: str, container: str, account_name: str = None, a
_, container = cls.get_client_and_container(container, account_name=account_name, account_key=account_key)
blob = container.get_blob_client(url)
return dict(blob.get_blob_properties())

@classmethod
def validate_pattern(cls, storage, pattern, glob_pattern=True):
    """
    Validate that `pattern` matches at least one blob in the Azure container.

    :param storage: AzureBlobStorage instance (provides client, container and optional prefix)
    :param pattern: Pattern to validate
    :param glob_pattern: If True, pattern is a glob pattern, otherwise it is a regex pattern
    :return: Error message if no object matches the pattern, empty string otherwise
    """
    logger.debug('Validating Azure Blob Storage pattern.')
    client, container = storage.get_client_and_container()
    # `name_starts_with=None` lists the whole container, so one call covers
    # both the prefixed and unprefixed cases (the SDK default is None).
    generator = container.list_blob_names(
        name_starts_with=storage.prefix or None,
        results_per_page=settings.CLOUD_STORAGE_CHECK_FOR_RECORDS_PAGE_SIZE,
        timeout=settings.CLOUD_STORAGE_CHECK_FOR_RECORDS_TIMEOUT,
    )
    # compile the (possibly glob) pattern to a regex
    if glob_pattern:
        pattern = fnmatch.translate(pattern)
    regex = re.compile(str(pattern))
    # match pattern against all keys in the container
    for key in generator:
        # skip folder placeholder keys
        if key.endswith('/'):
            logger.debug('%s is skipped because it is a folder', key)
            continue
        if regex.match(key):
            logger.debug('%s matches file pattern', key)
            return ''
    return 'No objects found matching the provided glob pattern'
29 changes: 29 additions & 0 deletions label_studio/io_storages/gcs/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
import fnmatch
import json
import logging
import re
Expand Down Expand Up @@ -304,3 +305,31 @@ def get_blob_metadata(
if not properties_name:
return blob._properties
return {key: value for key, value in blob._properties.items() if key in properties_name}

@classmethod
def validate_pattern(cls, storage, pattern, glob_pattern=True):
    """
    Validate that `pattern` matches at least one blob in the GCS bucket.

    :param storage: Google Cloud Storage instance
    :param pattern: Pattern to validate
    :param glob_pattern: If True, pattern is a glob pattern, otherwise it is a regex pattern
    :return: Error message if no object matches the pattern, empty string otherwise
    """
    logger.debug('Validating Google Cloud Storage pattern.')
    client = storage.get_client()
    blob_iter = client.list_blobs(
        storage.bucket, prefix=storage.prefix, page_size=settings.CLOUD_STORAGE_CHECK_FOR_RECORDS_PAGE_SIZE
    )
    # compile the (possibly glob) pattern to a regex
    if glob_pattern:
        pattern = fnmatch.translate(pattern)
    regex = re.compile(str(pattern))
    for blob in blob_iter:
        # skip folder placeholders — any name ending in '/', not only the
        # prefix folder itself (consistent with the Azure and S3 backends)
        if blob.name.endswith('/'):
            continue
        # check regex pattern filter
        if regex.match(blob.name):
            logger.debug('%s matches file pattern', blob.name)
            return ''
    return 'No objects found matching the provided glob pattern'
38 changes: 38 additions & 0 deletions label_studio/io_storages/s3/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import base64
import fnmatch
import logging
import re
from urllib.parse import urlparse

import boto3
from botocore.exceptions import ClientError
from core.utils.params import get_env
from django.conf import settings

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -97,3 +100,38 @@ def get_blob_metadata(
metadata.pop('Body', None)
metadata.pop('ResponseMetadata', None)
return metadata

@classmethod
def validate_pattern(cls, storage, pattern, glob_pattern=True):
    """
    Validate that `pattern` matches at least one object in the S3 bucket.

    :param storage: S3 Storage instance
    :param pattern: Pattern to validate
    :param glob_pattern: If True, pattern is a glob pattern, otherwise it is a regex pattern
    :return: Error message if no object matches the pattern, empty string otherwise
    """
    logger.debug('Validating S3 Storage pattern.')
    client, bucket = storage.get_client_and_bucket()
    # compile the (possibly glob) pattern to a regex
    if glob_pattern:
        pattern = fnmatch.translate(pattern)
    regex = re.compile(pattern)

    if storage.prefix:
        list_kwargs = {'Prefix': storage.prefix.rstrip('/') + '/'}
        if not storage.recursive_scan:
            # without a Delimiter, S3 lists recursively; '/' limits to one level
            list_kwargs['Delimiter'] = '/'
        bucket_iter = bucket.objects.filter(**list_kwargs)
    else:
        bucket_iter = bucket.objects

    bucket_iter = bucket_iter.page_size(settings.CLOUD_STORAGE_CHECK_FOR_RECORDS_PAGE_SIZE).all()

    for obj in bucket_iter:
        key = obj.key
        # skip folder placeholder keys
        if key.endswith('/'):
            logger.debug('%s is skipped because it is a folder', key)
            continue
        if regex.match(key):
            logger.debug('%s matches file pattern', key)
            return ''
    return 'No objects found matching the provided glob pattern'

0 comments on commit 24e02f6

Please sign in to comment.