diff --git a/.github/workflows/test_code_format.yml b/.github/workflows/test_code_format.yml new file mode 100644 index 000000000..504756df3 --- /dev/null +++ b/.github/workflows/test_code_format.yml @@ -0,0 +1,34 @@ +name: Code Format Check + +on: + push: + pull_request: + workflow_dispatch: + +jobs: + test: + name: Check + runs-on: ubuntu-latest + + steps: + - name: Checkout Code + uses: actions/checkout@v3 + - name: Hack to get setup-python to work on nektos/act + run: | + if [ ! -f "/etc/lsb-release" ] ; then + echo "DISTRIB_RELEASE=18.04" > /etc/lsb-release + fi + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - uses: actions/cache@v2 + with: + path: ${{ env.pythonLocation }} + key: cache_v2_${{ env.pythonLocation }}-${{ hashFiles('requirements/**') }} + - name: Install Dependencies and lightly + run: pip install -e '.[all]' + - name: Run Format Check + run: | + export LIGHTLY_SERVER_LOCATION="localhost:-1" + make format-check diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c486abf2b..14701028d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -89,13 +89,17 @@ Follow these steps to start contributing: 5. Develop the features on your branch. - As you work on the features, you should make sure that the test suite - passes: + As you work on the features, you should make sure that the code is formatted and the + test suite passes: ```bash - $ make test + $ make format + $ make all-checks ``` + If you get an error from isort or black, please run `make format` again before + running `make all-checks`. + If you're modifying documents under `docs/source`, make sure to validate that they can still be built. This check also runs in CI. diff --git a/Makefile b/Makefile index f81534b87..08245f02e 100644 --- a/Makefile +++ b/Makefile @@ -34,6 +34,17 @@ clean-out: clean-tox: rm -fr .tox +# format code with isort and black +format: + isort . + black . + +# check if code is formatted with isort and black +format-check: + @echo "⚫ Checking code format..." + isort --check-only --diff . + black --check . 
+ # check style with flake8 lint: lint-lightly lint-tests @@ -49,6 +60,9 @@ lint-tests: test: pytest tests --runslow +# run format checks and tests +all-checks: format-check test + ## build source and wheel package dist: clean python setup.py sdist bdist_wheel diff --git a/docs/source/conf.py b/docs/source/conf.py index 2ca49beec..e3951afa4 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,23 +12,24 @@ # import os import sys -sys.path.insert(0, os.path.abspath('../..')) + +sys.path.insert(0, os.path.abspath("../..")) import sphinx_rtd_theme -import lightly +import lightly # -- Project information ----------------------------------------------------- -project = 'lightly' -copyright_year = '2020' +project = "lightly" +copyright_year = "2020" copyright = "Lightly AG" -website_url = 'https://www.lightly.ai/' -author = 'Philipp Wirth, Igor Susmelj' +website_url = "https://www.lightly.ai/" +author = "Philipp Wirth, Igor Susmelj" # The full version, including alpha/beta/rc tags release = lightly.__version__ -master_doc = 'index' +master_doc = "index" # -- General configuration --------------------------------------------------- @@ -44,13 +45,16 @@ "sphinx_tabs.tabs", "sphinx_copybutton", "sphinx_design", - 'sphinx_reredirects' + "sphinx_reredirects", ] sphinx_gallery_conf = { - 'examples_dirs': ['tutorials_source/package', 'tutorials_source/platform'], - 'gallery_dirs': ['tutorials/package', 'tutorials/platform'], # path to where to save gallery generated output - 'filename_pattern': '/tutorial_', + "examples_dirs": ["tutorials_source/package", "tutorials_source/platform"], + "gallery_dirs": [ + "tutorials/package", + "tutorials/platform", + ], # path to where to save gallery generated output + "filename_pattern": "/tutorial_", } napoleon_google_docstring = True @@ -67,7 +71,7 @@ napoleon_type_aliases = None # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -80,28 +84,28 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" html_theme_options = { - 'collapse_navigation': False, # set to false to prevent menu item collapse - 'logo_only': True + "collapse_navigation": False, # set to false to prevent menu item collapse + "logo_only": True, } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ['_static'] +html_static_path = ["_static"] -html_favicon = 'favicon.png' +html_favicon = "favicon.png" -html_logo = '../logos/lightly_logo_crop_white_text.png' +html_logo = "../logos/lightly_logo_crop_white_text.png" -# Exposes variables so that they can be used by django +# Exposes variables so that they can be used by django html_context = { - 'copyright_year': copyright_year, - 'website_url': website_url, + "copyright_year": copyright_year, + "website_url": website_url, } redirects = { - "docker/advanced/active_learning": "../../docker/getting_started/selection.html" -} \ No newline at end of file + "docker/advanced/active_learning": "../../docker/getting_started/selection.html" +} diff --git a/docs/source/docker/advanced/code_examples/load_model_from_checkpoint.py b/docs/source/docker/advanced/code_examples/load_model_from_checkpoint.py index 1c9808ed8..93c907cfc 100644 --- a/docs/source/docker/advanced/code_examples/load_model_from_checkpoint.py +++ b/docs/source/docker/advanced/code_examples/load_model_from_checkpoint.py @@ -1,17 +1,20 @@ from collections import OrderedDict + import torch + import lightly -def load_ckpt(ckpt_path, model_name='resnet-18', model_width=1, map_location='cpu'): + +def load_ckpt(ckpt_path, model_name="resnet-18", model_width=1, map_location="cpu"): ckpt = torch.load(ckpt_path, map_location=map_location) - + state_dict = OrderedDict() - for key, value in ckpt['state_dict'].items(): - if ('projection_head' in key) or ('backbone.7' in key): - # drop layers used for projection head + for key, value in ckpt["state_dict"].items(): + if ("projection_head" in key) or ("backbone.7" in key): + # drop layers used for projection head continue - state_dict[key.replace('model.backbone.', '')] = value - + state_dict[key.replace("model.backbone.", "")] = value + resnet = lightly.models.ResNetGenerator(name=model_name, width=model_width) model = torch.nn.Sequential( lightly.models.batchnorm.get_norm_layer(3, 0), @@ -23,28 +26,28 @@ def load_ckpt(ckpt_path, model_name='resnet-18', model_width=1, map_location='cp model.load_state_dict(state_dict) except RuntimeError: raise RuntimeError( - f'It looks like you tried loading a checkpoint from a model that is not a {model_name} with width={model_width}! ' - f'Please set model_name and model_width to the lightly.model.name and lightly.model.width parameters from the ' - f'configuration you used to run Lightly. The configuration from a Lightly worker run can be found in output_dir/config/config.yaml' + f"It looks like you tried loading a checkpoint from a model that is not a {model_name} with width={model_width}! " + f"Please set model_name and model_width to the lightly.model.name and lightly.model.width parameters from the " + f"configuration you used to run Lightly. 
The configuration from a Lightly worker run can be found in output_dir/config/config.yaml" ) return model + # loading the model -model = load_ckpt('output_dir/lightly_epoch_X.ckpt') +model = load_ckpt("output_dir/lightly_epoch_X.ckpt") # example usage image_batch = torch.rand(16, 3, 224, 224) out = model(image_batch) -print(out.shape) # prints: torch.Size([16, 512]) +print(out.shape)  # prints: torch.Size([16, 512]) # creating a classifier from the pre-trained model num_classes = 10 classifier = torch.nn.Sequential( - model, - torch.nn.Linear(512, num_classes) # use 2048 instead of 512 for resnet-50 + model, torch.nn.Linear(512, num_classes)  # use 2048 instead of 512 for resnet-50 ) out = classifier(image_batch) -print(out.shape) # prints: torch.Size(16, 10) +print(out.shape)  # prints: torch.Size([16, 10]) diff --git a/docs/source/docker/advanced/code_examples/python_create_dataset_azure_example.py b/docs/source/docker/advanced/code_examples/python_create_dataset_azure_example.py index 5caa7f4cc..4badd92a2 100644 --- a/docs/source/docker/advanced/code_examples/python_create_dataset_azure_example.py +++ b/docs/source/docker/advanced/code_examples/python_create_dataset_azure_example.py @@ -1,29 +1,29 @@ import json + import lightly from lightly.openapi_generated.swagger_client.models.dataset_type import DatasetType -from lightly.openapi_generated.swagger_client.models.datasource_purpose import DatasourcePurpose - +from lightly.openapi_generated.swagger_client.models.datasource_purpose import ( + DatasourcePurpose, +) # Create the Lightly client to connect to the API. client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN") # Create a new dataset on the Lightly Platform. -client.create_dataset('pedestrian-videos-datapool', - dataset_type=DatasetType.VIDEOS) +client.create_dataset("pedestrian-videos-datapool", dataset_type=DatasetType.VIDEOS) # Azure Blob Storage # Input bucket client.set_azure_config( - container_name='my-container/input/', - account_name='ACCOUNT-NAME', - sas_token='SAS-TOKEN', - purpose=DatasourcePurpose.INPUT + container_name="my-container/input/", + account_name="ACCOUNT-NAME", + sas_token="SAS-TOKEN", + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_azure_config( - container_name='my-container/output/', - account_name='ACCOUNT-NAME', - sas_token='SAS-TOKEN', - purpose=DatasourcePurpose.LIGHTLY + container_name="my-container/output/", + account_name="ACCOUNT-NAME", + sas_token="SAS-TOKEN", + purpose=DatasourcePurpose.LIGHTLY, ) - diff --git a/docs/source/docker/advanced/code_examples/python_create_dataset_gcs_example.py b/docs/source/docker/advanced/code_examples/python_create_dataset_gcs_example.py index a7d29d164..f8d86cc56 100644 --- a/docs/source/docker/advanced/code_examples/python_create_dataset_gcs_example.py +++ b/docs/source/docker/advanced/code_examples/python_create_dataset_gcs_example.py @@ -1,29 +1,29 @@ import json + import lightly from lightly.openapi_generated.swagger_client.models.dataset_type import DatasetType -from lightly.openapi_generated.swagger_client.models.datasource_purpose import DatasourcePurpose - +from lightly.openapi_generated.swagger_client.models.datasource_purpose import ( + DatasourcePurpose, +) # Create the Lightly client to connect to the API. client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN") # Create a new dataset on the Lightly Platform.
-client.create_dataset('pedestrian-videos-datapool', - dataset_type=DatasetType.VIDEOS) +client.create_dataset("pedestrian-videos-datapool", dataset_type=DatasetType.VIDEOS) # Google Cloud Storage # Input bucket client.set_gcs_config( resource_path="gs://bucket/input/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials_read.json'))), - purpose=DatasourcePurpose.INPUT + credentials=json.dumps(json.load(open("credentials_read.json"))), + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_gcs_config( resource_path="gs://bucket/output/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials_write.json'))), - purpose=DatasourcePurpose.LIGHTLY + credentials=json.dumps(json.load(open("credentials_write.json"))), + purpose=DatasourcePurpose.LIGHTLY, ) - diff --git a/docs/source/docker/advanced/code_examples/python_create_dataset_s3_example.py b/docs/source/docker/advanced/code_examples/python_create_dataset_s3_example.py index b8741c6b1..8dea5a7b0 100644 --- a/docs/source/docker/advanced/code_examples/python_create_dataset_s3_example.py +++ b/docs/source/docker/advanced/code_examples/python_create_dataset_s3_example.py @@ -1,31 +1,31 @@ import json + import lightly from lightly.openapi_generated.swagger_client.models.dataset_type import DatasetType -from lightly.openapi_generated.swagger_client.models.datasource_purpose import DatasourcePurpose - +from lightly.openapi_generated.swagger_client.models.datasource_purpose import ( + DatasourcePurpose, +) # Create the Lightly client to connect to the API. client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN") # Create a new dataset on the Lightly Platform. -client.create_dataset('pedestrian-videos-datapool', - dataset_type=DatasetType.VIDEOS) +client.create_dataset("pedestrian-videos-datapool", dataset_type=DatasetType.VIDEOS) -# AWS S3 +# AWS S3 # Input bucket client.set_s3_config( resource_path="s3://bucket/input/", - region='eu-central-1', - access_key='S3-ACCESS-KEY', - secret_access_key='S3-SECRET-ACCESS-KEY', - purpose=DatasourcePurpose.INPUT + region="eu-central-1", + access_key="S3-ACCESS-KEY", + secret_access_key="S3-SECRET-ACCESS-KEY", + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_s3_config( resource_path="s3://bucket/output/", - region='eu-central-1', - access_key='S3-ACCESS-KEY', - secret_access_key='S3-SECRET-ACCESS-KEY', - purpose=DatasourcePurpose.LIGHTLY + region="eu-central-1", + access_key="S3-ACCESS-KEY", + secret_access_key="S3-SECRET-ACCESS-KEY", + purpose=DatasourcePurpose.LIGHTLY, ) - diff --git a/docs/source/docker/advanced/code_examples/python_create_frame_predictions.py b/docs/source/docker/advanced/code_examples/python_create_frame_predictions.py index df0d2a8fa..d71133616 100644 --- a/docs/source/docker/advanced/code_examples/python_create_frame_predictions.py +++ b/docs/source/docker/advanced/code_examples/python_create_frame_predictions.py @@ -1,19 +1,22 @@ -import av import json from pathlib import Path -from typing import List, Dict +from typing import Dict, List + +import av + +dataset_dir = Path("/datasets/my_dataset") +predictions_dir = dataset_dir / ".lightly" / "predictions" / "my_prediction_task" -dataset_dir = Path('/datasets/my_dataset') -predictions_dir = dataset_dir / '.lightly' / 'predictions' / 'my_prediction_task' def model_predict(frame) -> List[Dict]: # This function must be overwritten to generate predictions for a frame using # a prediction model of your choice. Here we just return an example prediction. 
- # See https://docs.lightly.ai/docker/advanced/datasource_predictions.html#prediction-format + # See https://docs.lightly.ai/docker/advanced/datasource_predictions.html#prediction-format # for possible prediction formats. - return [{'category_id': 0, 'bbox': [0, 10, 100, 30], 'score': 0.8}] + return [{"category_id": 0, "bbox": [0, 10, 100, 30], "score": 0.8}] -for video_path in dataset_dir.glob('**/*.mp4'): + +for video_path in dataset_dir.glob("**/*.mp4"): # get predictions for frames predictions = [] with av.open(str(video_path)) as container: @@ -21,51 +24,53 @@ def model_predict(frame) -> List[Dict]: for frame in container.decode(stream): predictions.append(model_predict(frame.to_image())) - # save predictions + # save predictions num_frames = len(predictions) zero_padding = len(str(num_frames)) for frame_index, frame_predictions in enumerate(predictions): - video_name = video_path.relative_to(dataset_dir).with_suffix('') - frame_name = Path(f'{video_name}-{frame_index:0{zero_padding}}-{video_path.suffix[1:]}.png') + video_name = video_path.relative_to(dataset_dir).with_suffix("") + frame_name = Path( + f"{video_name}-{frame_index:0{zero_padding}}-{video_path.suffix[1:]}.png" + ) prediction = { - 'file_name': str(frame_name), - 'predictions': frame_predictions, + "file_name": str(frame_name), + "predictions": frame_predictions, } - out_path = predictions_dir / frame_name.with_suffix('.json') + out_path = predictions_dir / frame_name.with_suffix(".json") out_path.parent.mkdir(parents=True, exist_ok=True) - with open(out_path, 'w') as file: + with open(out_path, "w") as file: json.dump(prediction, file) -# example directory structure before -# . -# ├── test -# │ └── video_0.mp4 -# └── train +# example directory structure before +# . +# ├── test +# │ └── video_0.mp4 +# └── train # ├── video_1.mp4 -#  └── video_2.mp4 +# └── video_2.mp4 # -# example directory structure after +# example directory structure after # . # ├── .lightly # │ └── predictions # │ └── my_prediction_task -# │ ├── test -# │ │ ├── video_0-000-mp4.json -# │ │ ├── video_0-001-mp4.json -# │ │ ├── video_0-002-mp4.json -# │ │ └── ... -# │ └── train -# │ ├── video_1-000-mp4.json -# │ ├── video_1-001-mp4.json -# │ ├── video_1-002-mp4.json -# | ├── ... -# | ├── video_2-000-mp4.json -# | ├── video_2-001-mp4.json -# | ├── video_2-002-mp4.json -# │ └── ... +# │ ├── test +# │ │ ├── video_0-000-mp4.json +# │ │ ├── video_0-001-mp4.json +# │ │ ├── video_0-002-mp4.json +# │ │ └── ... +# │ └── train +# │ ├── video_1-000-mp4.json +# │ ├── video_1-001-mp4.json +# │ ├── video_1-002-mp4.json +# | ├── ... +# | ├── video_2-000-mp4.json +# | ├── video_2-001-mp4.json +# | ├── video_2-002-mp4.json +# │ └── ... # ├── test -# │ └── video_0.mp4 -# └── train -#  ├── video_1.mp4 -#  └── video_2.mp4 +# │ └── video_0.mp4 +# └── train +# ├── video_1.mp4 +# └── video_2.mp4 diff --git a/docs/source/docker/advanced/code_examples/python_run_datapool_example.py b/docs/source/docker/advanced/code_examples/python_run_datapool_example.py index a986651ce..669c42ba3 100644 --- a/docs/source/docker/advanced/code_examples/python_run_datapool_example.py +++ b/docs/source/docker/advanced/code_examples/python_run_datapool_example.py @@ -1,77 +1,59 @@ import lightly - # Create the Lightly client to connect to the API. 
client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN") # Let's fetch the dataset we created above, by name -client.set_dataset_id_by_name('pedestrian-videos-datapool') +client.set_dataset_id_by_name("pedestrian-videos-datapool") # Schedule the compute run using our custom config. # We show here the full default config so you can easily edit the # values according to your needs. client.schedule_compute_worker_run( worker_config={ - 'enable_corruptness_check': True, - 'remove_exact_duplicates': True, - 'enable_training': False, - 'pretagging': False, - 'pretagging_debug': False, + "enable_corruptness_check": True, + "remove_exact_duplicates": True, + "enable_training": False, + "pretagging": False, + "pretagging_debug": False, }, selection_config={ "n_samples": 100, "strategies": [ { - "input": { - "type": "EMBEDDINGS" - }, + "input": {"type": "EMBEDDINGS"}, "strategy": { "type": "DIVERSITY", - "stopping_condition_minimum_distance": 0.1 - } + "stopping_condition_minimum_distance": 0.1, + }, } - ] + ], }, lightly_config={ - 'loader': { - 'batch_size': 128, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True - }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 + "loader": { + "batch_size": 128, + "shuffle": True, + "num_workers": -1, + "drop_last": True, }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 1, - 'precision': 16 + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 1, "precision": 16}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.0, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'criterion': { - 'temperature': 0.5 - }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.0, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } + }, ) - diff --git a/docs/source/docker/advanced/code_examples/python_run_datapool_example_2.py b/docs/source/docker/advanced/code_examples/python_run_datapool_example_2.py index 5280707c0..d625a3e15 100644 --- a/docs/source/docker/advanced/code_examples/python_run_datapool_example_2.py +++ b/docs/source/docker/advanced/code_examples/python_run_datapool_example_2.py @@ -1,77 +1,59 @@ import lightly - # Create the Lightly client to connect to the API. client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN") # Let's fetch the dataset we created above, by name -client.set_dataset_id_by_name('pedestrian-videos-datapool') +client.set_dataset_id_by_name("pedestrian-videos-datapool") # Schedule the compute run using our custom config. # We show here the full default config so you can easily edit the # values according to your needs. 
client.schedule_compute_worker_run( worker_config={ - 'enable_corruptness_check': True, - 'remove_exact_duplicates': True, - 'enable_training': False, - 'pretagging': False, - 'pretagging_debug': False, + "enable_corruptness_check": True, + "remove_exact_duplicates": True, + "enable_training": False, + "pretagging": False, + "pretagging_debug": False, }, selection_config={ "n_samples": 100, "strategies": [ { - "input": { - "type": "EMBEDDINGS" - }, + "input": {"type": "EMBEDDINGS"}, "strategy": { "type": "DIVERSITY", - "stopping_condition_minimum_distance": 0.2 - } + "stopping_condition_minimum_distance": 0.2, + }, } - ] + ], }, lightly_config={ - 'loader': { - 'batch_size': 128, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True - }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 + "loader": { + "batch_size": 128, + "shuffle": True, + "num_workers": -1, + "drop_last": True, }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 1, - 'precision': 16 + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 1, "precision": 16}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.0, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'criterion': { - 'temperature': 0.5 - }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.0, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } + }, ) - diff --git a/docs/source/docker/advanced/code_examples/python_run_object_level.py b/docs/source/docker/advanced/code_examples/python_run_object_level.py index fab9e1861..85239ae2b 100644 --- a/docs/source/docker/advanced/code_examples/python_run_object_level.py +++ b/docs/source/docker/advanced/code_examples/python_run_object_level.py @@ -1,32 +1,34 @@ import json + import lightly from lightly.openapi_generated.swagger_client.models.dataset_type import DatasetType -from lightly.openapi_generated.swagger_client.models.datasource_purpose import DatasourcePurpose +from lightly.openapi_generated.swagger_client.models.datasource_purpose import ( + DatasourcePurpose, +) # Create the Lightly client to connect to the API. client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN") # Create a new dataset on the Lightly Platform. 
-client.create_dataset('dataset-name', - dataset_type=DatasetType.IMAGES) +client.create_dataset("dataset-name", dataset_type=DatasetType.IMAGES) # Pick one of the following three blocks depending on where your data is -# AWS S3 +# AWS S3 # Input bucket client.set_s3_config( resource_path="s3://bucket/input/", - region='eu-central-1', - access_key='S3-ACCESS-KEY', - secret_access_key='S3-SECRET-ACCESS-KEY', - purpose=DatasourcePurpose.INPUT + region="eu-central-1", + access_key="S3-ACCESS-KEY", + secret_access_key="S3-SECRET-ACCESS-KEY", + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_s3_config( resource_path="s3://bucket/output/", - region='eu-central-1', - access_key='S3-ACCESS-KEY', - secret_access_key='S3-SECRET-ACCESS-KEY', - purpose=DatasourcePurpose.LIGHTLY + region="eu-central-1", + access_key="S3-ACCESS-KEY", + secret_access_key="S3-SECRET-ACCESS-KEY", + purpose=DatasourcePurpose.LIGHTLY, ) @@ -35,37 +37,37 @@ client.set_gcs_config( resource_path="gs://bucket/input/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials_read.json'))), - purpose=DatasourcePurpose.INPUT + credentials=json.dumps(json.load(open("credentials_read.json"))), + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_gcs_config( resource_path="gs://bucket/output/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials_write.json'))), - purpose=DatasourcePurpose.LIGHTLY + credentials=json.dumps(json.load(open("credentials_write.json"))), + purpose=DatasourcePurpose.LIGHTLY, ) # or Azure Blob Storage # Input bucket client.set_azure_config( - container_name='my-container/input/', - account_name='ACCOUNT-NAME', - sas_token='SAS-TOKEN', - purpose=DatasourcePurpose.INPUT + container_name="my-container/input/", + account_name="ACCOUNT-NAME", + sas_token="SAS-TOKEN", + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_azure_config( - container_name='my-container/output/', - account_name='ACCOUNT-NAME', - sas_token='SAS-TOKEN', - purpose=DatasourcePurpose.LIGHTLY + container_name="my-container/output/", + account_name="ACCOUNT-NAME", + sas_token="SAS-TOKEN", + purpose=DatasourcePurpose.LIGHTLY, ) -# Schedule the docker run with the "object_level.task_name" argument set. -# All other settings are default values and we show them so you can easily edit -# the values according to your need. +# Schedule the docker run with the "object_level.task_name" argument set. +# All other settings are default values and we show them so you can easily edit +# the values according to your need. client.schedule_compute_worker_run( worker_config={ "enable_corruptness_check": True, @@ -73,20 +75,18 @@ "enable_training": False, "pretagging": False, "pretagging_debug": False, - "object_level": { # used for object level workflow - "task_name": "vehicles_object_detections" + "object_level": { # used for object level workflow + "task_name": "vehicles_object_detections" }, }, selection_config={ "n_samples": 100, "strategies": [ { - "input": { - "type": "EMBEDDINGS" - }, + "input": {"type": "EMBEDDINGS"}, "strategy": { "type": "DIVERSITY", - } + }, }, # Optionally, you can combine diversity selection with active learning # to prefer selecting objects the model struggles with. 
@@ -102,48 +102,34 @@ "type": "WEIGHTS" } } - """ - ] + """, + ], }, lightly_config={ - 'loader': { - 'batch_size': 16, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True - }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 - }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 100, - 'precision': 32 + "loader": { + "batch_size": 16, + "shuffle": True, + "num_workers": -1, + "drop_last": True, }, - 'criterion': { - 'temperature': 0.5 + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 100, "precision": 32}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.5, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.5, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } + }, ) diff --git a/docs/source/docker/advanced/code_examples/python_run_object_level_pretagging.py b/docs/source/docker/advanced/code_examples/python_run_object_level_pretagging.py index ad65c8247..8dd55de20 100644 --- a/docs/source/docker/advanced/code_examples/python_run_object_level_pretagging.py +++ b/docs/source/docker/advanced/code_examples/python_run_object_level_pretagging.py @@ -1,33 +1,35 @@ import json + import lightly from lightly.openapi_generated.swagger_client.models.dataset_type import DatasetType -from lightly.openapi_generated.swagger_client.models.datasource_purpose import DatasourcePurpose +from lightly.openapi_generated.swagger_client.models.datasource_purpose import ( + DatasourcePurpose, +) # Create the Lightly client to connect to the API. client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN") # Create a new dataset on the Lightly Platform. 
-client.create_dataset('dataset-name', - dataset_type=DatasetType.IMAGES) +client.create_dataset("dataset-name", dataset_type=DatasetType.IMAGES) # Pick one of the following three blocks depending on where your data is -# AWS S3 +# AWS S3 # Input bucket client.set_s3_config( resource_path="s3://bucket/input/", - region='eu-central-1', - access_key='S3-ACCESS-KEY', - secret_access_key='S3-SECRET-ACCESS-KEY', - purpose=DatasourcePurpose.INPUT + region="eu-central-1", + access_key="S3-ACCESS-KEY", + secret_access_key="S3-SECRET-ACCESS-KEY", + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_s3_config( resource_path="s3://bucket/output/", - region='eu-central-1', - access_key='S3-ACCESS-KEY', - secret_access_key='S3-SECRET-ACCESS-KEY', - purpose=DatasourcePurpose.LIGHTLY + region="eu-central-1", + access_key="S3-ACCESS-KEY", + secret_access_key="S3-SECRET-ACCESS-KEY", + purpose=DatasourcePurpose.LIGHTLY, ) @@ -36,38 +38,38 @@ client.set_gcs_config( resource_path="gs://bucket/input/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials_read.json'))), - purpose=DatasourcePurpose.INPUT + credentials=json.dumps(json.load(open("credentials_read.json"))), + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_gcs_config( resource_path="gs://bucket/output/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials_write.json'))), - purpose=DatasourcePurpose.LIGHTLY + credentials=json.dumps(json.load(open("credentials_write.json"))), + purpose=DatasourcePurpose.LIGHTLY, ) # or Azure Blob Storage # Input bucket client.set_azure_config( - container_name='my-container/input/', - account_name='ACCOUNT-NAME', - sas_token='SAS-TOKEN', - purpose=DatasourcePurpose.INPUT + container_name="my-container/input/", + account_name="ACCOUNT-NAME", + sas_token="SAS-TOKEN", + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_azure_config( - container_name='my-container/output/', - account_name='ACCOUNT-NAME', - sas_token='SAS-TOKEN', - purpose=DatasourcePurpose.LIGHTLY + container_name="my-container/output/", + account_name="ACCOUNT-NAME", + sas_token="SAS-TOKEN", + purpose=DatasourcePurpose.LIGHTLY, ) # Schedule the docker run with the "object_level.task_name" argument set to # "lightly_pretagging" and with "pretagging" set to True. -# All other settings are default values and we show them so you can easily edit -# the values according to your need. +# All other settings are default values and we show them so you can easily edit +# the values according to your need. 
client.schedule_compute_worker_run( worker_config={ "enable_corruptness_check": True, @@ -75,62 +77,44 @@ "enable_training": False, "pretagging": True, "pretagging_debug": False, - "object_level": { - "task_name": "lightly_pretagging" - } + "object_level": {"task_name": "lightly_pretagging"}, }, selection_config={ "n_samples": 100, "strategies": [ { - "input": { - "type": "EMBEDDINGS" - }, + "input": {"type": "EMBEDDINGS"}, "strategy": { "type": "DIVERSITY", - } + }, } - ] + ], }, lightly_config={ - 'loader': { - 'batch_size': 16, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True + "loader": { + "batch_size": 16, + "shuffle": True, + "num_workers": -1, + "drop_last": True, }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 100, "precision": 32}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.5, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 100, - 'precision': 32 - }, - 'criterion': { - 'temperature': 0.5 - }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.5, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } + }, ) diff --git a/docs/source/docker/advanced/code_examples/python_run_pretagging.py b/docs/source/docker/advanced/code_examples/python_run_pretagging.py index c2c137080..fe7dedd7b 100644 --- a/docs/source/docker/advanced/code_examples/python_run_pretagging.py +++ b/docs/source/docker/advanced/code_examples/python_run_pretagging.py @@ -1,33 +1,35 @@ import json + import lightly from lightly.openapi_generated.swagger_client.models.dataset_type import DatasetType -from lightly.openapi_generated.swagger_client.models.datasource_purpose import DatasourcePurpose - +from lightly.openapi_generated.swagger_client.models.datasource_purpose import ( + DatasourcePurpose, +) # Create the Lightly client to connect to the API. client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN") # Create a new dataset on the Lightly Platform. In this example we use pretagging # on images. 
We can also use videos instead by setting dataset_type=DatasetType.VIDEOS -client.create_dataset('your-dataset-name', dataset_type=DatasetType.IMAGES) +client.create_dataset("your-dataset-name", dataset_type=DatasetType.IMAGES) # Pick one of the following three blocks depending on where your data is -# AWS S3 +# AWS S3 # Input bucket client.set_s3_config( resource_path="s3://bucket/input/", - region='eu-central-1', - access_key='S3-ACCESS-KEY', - secret_access_key='S3-SECRET-ACCESS-KEY', - purpose=DatasourcePurpose.INPUT + region="eu-central-1", + access_key="S3-ACCESS-KEY", + secret_access_key="S3-SECRET-ACCESS-KEY", + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_s3_config( resource_path="s3://bucket/output/", - region='eu-central-1', - access_key='S3-ACCESS-KEY', - secret_access_key='S3-SECRET-ACCESS-KEY', - purpose=DatasourcePurpose.LIGHTLY + region="eu-central-1", + access_key="S3-ACCESS-KEY", + secret_access_key="S3-SECRET-ACCESS-KEY", + purpose=DatasourcePurpose.LIGHTLY, ) @@ -36,32 +38,32 @@ client.set_gcs_config( resource_path="gs://bucket/input/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials_read.json'))), - purpose=DatasourcePurpose.INPUT + credentials=json.dumps(json.load(open("credentials_read.json"))), + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_gcs_config( resource_path="gs://bucket/output/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials_write.json'))), - purpose=DatasourcePurpose.LIGHTLY + credentials=json.dumps(json.load(open("credentials_write.json"))), + purpose=DatasourcePurpose.LIGHTLY, ) # or Azure Blob Storage # Input bucket client.set_azure_config( - container_name='my-container/input/', - account_name='ACCOUNT-NAME', - sas_token='SAS-TOKEN', - purpose=DatasourcePurpose.INPUT + container_name="my-container/input/", + account_name="ACCOUNT-NAME", + sas_token="SAS-TOKEN", + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_azure_config( - container_name='my-container/output/', - account_name='ACCOUNT-NAME', - sas_token='SAS-TOKEN', - purpose=DatasourcePurpose.LIGHTLY + container_name="my-container/output/", + account_name="ACCOUNT-NAME", + sas_token="SAS-TOKEN", + purpose=DatasourcePurpose.LIGHTLY, ) # Schedule the compute run using our custom config. @@ -69,65 +71,48 @@ # values according to your needs. 
client.schedule_compute_worker_run( worker_config={ - 'enable_corruptness_check': True, - 'remove_exact_duplicates': True, - 'enable_training': False, - 'pretagging': True, # to enable pretagging - 'pretagging_debug': True, # we also want debugging images in the report + "enable_corruptness_check": True, + "remove_exact_duplicates": True, + "enable_training": False, + "pretagging": True, # to enable pretagging + "pretagging_debug": True, # we also want debugging images in the report }, selection_config={ "n_samples": 100, "strategies": [ { - "input": { - "type": "EMBEDDINGS" - }, + "input": {"type": "EMBEDDINGS"}, "strategy": { "type": "DIVERSITY", - } + }, } - ] + ], }, lightly_config={ - 'loader': { - 'batch_size': 128, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True + "loader": { + "batch_size": 128, + "shuffle": True, + "num_workers": -1, + "drop_last": True, }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 1, "precision": 16}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.0, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 1, - 'precision': 16 - }, - 'criterion': { - 'temperature': 0.5 - }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.0, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } + }, ) - diff --git a/docs/source/docker/advanced/code_examples/python_run_sequence_selection.py b/docs/source/docker/advanced/code_examples/python_run_sequence_selection.py index c8f88c1e0..dfc405793 100644 --- a/docs/source/docker/advanced/code_examples/python_run_sequence_selection.py +++ b/docs/source/docker/advanced/code_examples/python_run_sequence_selection.py @@ -1,31 +1,34 @@ import json + import lightly from lightly.openapi_generated.swagger_client.models.dataset_type import DatasetType -from lightly.openapi_generated.swagger_client.models.datasource_purpose import DatasourcePurpose +from lightly.openapi_generated.swagger_client.models.datasource_purpose import ( + DatasourcePurpose, +) # Create the Lightly client to connect to the API. client = lightly.api.ApiWorkflowClient(token="YOUR_TOKEN") # Create a new dataset on the Lightly Platform. 
-client.create_dataset('pexels', dataset_type=DatasetType.VIDEOS) +client.create_dataset("pexels", dataset_type=DatasetType.VIDEOS) # Pick one of the following three blocks depending on where your data is -# AWS S3 +# AWS S3 # Input bucket client.set_s3_config( resource_path="s3://bucket/input/", - region='eu-central-1', - access_key='S3-ACCESS-KEY', - secret_access_key='S3-SECRET-ACCESS-KEY', - purpose=DatasourcePurpose.INPUT + region="eu-central-1", + access_key="S3-ACCESS-KEY", + secret_access_key="S3-SECRET-ACCESS-KEY", + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_s3_config( resource_path="s3://bucket/output/", - region='eu-central-1', - access_key='S3-ACCESS-KEY', - secret_access_key='S3-SECRET-ACCESS-KEY', - purpose=DatasourcePurpose.LIGHTLY + region="eu-central-1", + access_key="S3-ACCESS-KEY", + secret_access_key="S3-SECRET-ACCESS-KEY", + purpose=DatasourcePurpose.LIGHTLY, ) @@ -34,32 +37,32 @@ client.set_gcs_config( resource_path="gs://bucket/input/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials_read.json'))), - purpose=DatasourcePurpose.INPUT + credentials=json.dumps(json.load(open("credentials_read.json"))), + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_gcs_config( resource_path="gs://bucket/output/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials_write.json'))), - purpose=DatasourcePurpose.LIGHTLY + credentials=json.dumps(json.load(open("credentials_write.json"))), + purpose=DatasourcePurpose.LIGHTLY, ) # or Azure Blob Storage # Input bucket client.set_azure_config( - container_name='my-container/input/', - account_name='ACCOUNT-NAME', - sas_token='SAS-TOKEN', - purpose=DatasourcePurpose.INPUT + container_name="my-container/input/", + account_name="ACCOUNT-NAME", + sas_token="SAS-TOKEN", + purpose=DatasourcePurpose.INPUT, ) # Output bucket client.set_azure_config( - container_name='my-container/output/', - account_name='ACCOUNT-NAME', - sas_token='SAS-TOKEN', - purpose=DatasourcePurpose.LIGHTLY + container_name="my-container/output/", + account_name="ACCOUNT-NAME", + sas_token="SAS-TOKEN", + purpose=DatasourcePurpose.LIGHTLY, ) # Schedule the compute run using our custom config. @@ -67,58 +70,43 @@ # values according to your needs. 
client.schedule_compute_worker_run( worker_config={ - 'enable_corruptness_check': False, - 'remove_exact_duplicates': False, - 'enable_training': False, - 'pretagging': False, - 'pretagging_debug': False, - 'method': 'coreset', - 'stopping_condition': { - 'n_samples': 200, # select 200 frames of length 10 frames -> 20 sequences - 'min_distance': -1 + "enable_corruptness_check": False, + "remove_exact_duplicates": False, + "enable_training": False, + "pretagging": False, + "pretagging_debug": False, + "method": "coreset", + "stopping_condition": { + "n_samples": 200,  # select 200 frames in sequences of 10 frames -> 20 sequences + "min_distance": -1, }, - 'selected_sequence_length': 10 # we want sequences of 10 frames lenght + "selected_sequence_length": 10,  # we want sequences of 10 frames }, lightly_config={ - 'loader': { - 'batch_size': 128, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True - }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 - }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 1, - 'precision': 16 + "loader": { + "batch_size": 128, + "shuffle": True, + "num_workers": -1, + "drop_last": True, }, - 'criterion': { - 'temperature': 0.5 + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 1, "precision": 16}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.0, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.0, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } + }, ) - diff --git a/docs/source/docker/advanced/code_examples/semantic_segmentation_inference.py b/docs/source/docker/advanced/code_examples/semantic_segmentation_inference.py index 14c0b976b..62a4f379f 100644 --- a/docs/source/docker/advanced/code_examples/semantic_segmentation_inference.py +++ b/docs/source/docker/advanced/code_examples/semantic_segmentation_inference.py @@ -1,5 +1,5 @@ -import os import json +import os import numpy as np diff --git a/docs/source/docker/integration/examples/create_dataset.py b/docs/source/docker/integration/examples/create_dataset.py index 296c1f259..ce8e9aee8 100644 --- a/docs/source/docker/integration/examples/create_dataset.py +++ b/docs/source/docker/integration/examples/create_dataset.py @@ -4,11 +4,11 @@ client = lightly.api.ApiWorkflowClient(token="LIGHTLY_TOKEN") # Create a new dataset on the Lightly Platform. -client.create_dataset('dataset-name') +client.create_dataset("dataset-name") # Connect the dataset to your cloud bucket.
-# AWS S3 +# AWS S3 client.set_s3_config( resource_path="s3://bucket/dataset/", region="eu-central-1", @@ -19,17 +19,18 @@ # Google Cloud Storage import json + client.set_gcs_config( resource_path="gs://bucket/dataset/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials.json'))), + credentials=json.dumps(json.load(open("credentials.json"))), thumbnail_suffix=None, ) -# Azure Blob Storage +# Azure Blob Storage client.set_azure_config( container_name="container/dataset/", account_name="ACCOUNT-NAME", sas_token="SAS-TOKEN", thumbnail_suffix=None, -) \ No newline at end of file +) diff --git a/docs/source/docker/integration/examples/trigger_job.py b/docs/source/docker/integration/examples/trigger_job.py index fea017c39..e60831c24 100644 --- a/docs/source/docker/integration/examples/trigger_job.py +++ b/docs/source/docker/integration/examples/trigger_job.py @@ -1,6 +1,9 @@ import time -from lightly.openapi_generated.swagger_client import DockerRunScheduledState, DockerRunState +from lightly.openapi_generated.swagger_client import ( + DockerRunScheduledState, + DockerRunState, +) # You can reuse the client from previous scripts. If you want to create a new # one you can uncomment the following line: @@ -13,72 +16,55 @@ scheduled_run_id = client.schedule_compute_worker_run( worker_config={ - 'enable_corruptness_check': True, - 'remove_exact_duplicates': True, - 'enable_training': False, + "enable_corruptness_check": True, + "remove_exact_duplicates": True, + "enable_training": False, }, selection_config={ "n_samples": 50, "strategies": [ - { - "input": { - "type": "EMBEDDINGS" - }, - "strategy": { - "type": "DIVERSITY" - } - } - ] + {"input": {"type": "EMBEDDINGS"}, "strategy": {"type": "DIVERSITY"}} + ], }, lightly_config={ - 'loader': { - 'batch_size': 16, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True - }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 - }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 100, - 'precision': 32 + "loader": { + "batch_size": 16, + "shuffle": True, + "num_workers": -1, + "drop_last": True, }, - 'criterion': { - 'temperature': 0.5 + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 100, "precision": 32}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.5, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.5, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } + }, ) """ Optionally, You can use this code to track and print the state of the compute worker. The loop will end once the compute worker run has finished, was canceled or aborted/failed. 
""" -for run_info in client.compute_worker_run_info_generator(scheduled_run_id=scheduled_run_id): - print(f"Compute worker run is now in state='{run_info.state}' with message='{run_info.message}'") +for run_info in client.compute_worker_run_info_generator( + scheduled_run_id=scheduled_run_id +): + print( + f"Compute worker run is now in state='{run_info.state}' with message='{run_info.message}'" + ) if run_info.ended_successfully(): print("SUCCESS") diff --git a/docs/source/docker_archive/advanced/code_examples/python_run_active_learning.py b/docs/source/docker_archive/advanced/code_examples/python_run_active_learning.py index aefc9dfbc..9eba6c458 100644 --- a/docs/source/docker_archive/advanced/code_examples/python_run_active_learning.py +++ b/docs/source/docker_archive/advanced/code_examples/python_run_active_learning.py @@ -3,7 +3,7 @@ # Create the Lightly client to connect to the API. client = lightly.api.ApiWorkflowClient(token="LIGHTLY_TOKEN", dataset_id="DATASET_ID") -# Schedule the docker run with +# Schedule the docker run with # - "active_learning.task_name" set to your task name # - "method" set to "coral" # All other settings are default values and we show them so you can easily edit @@ -16,59 +16,39 @@ "pretagging": False, "pretagging_debug": False, "method": "coral", - "stopping_condition": { - "n_samples": 0.1, - "min_distance": -1 - }, + "stopping_condition": {"n_samples": 0.1, "min_distance": -1}, "scorer": "object-frequency", - "scorer_config": { - "frequency_penalty": 0.25, - "min_score": 0.9 - }, + "scorer_config": {"frequency_penalty": 0.25, "min_score": 0.9}, "active_learning": { - "task_name": "my-classification-task", - "score_name": "uncertainty_margin" - } + "task_name": "my-classification-task", + "score_name": "uncertainty_margin", + }, }, lightly_config={ - 'loader': { - 'batch_size': 16, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True - }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 + "loader": { + "batch_size": 16, + "shuffle": True, + "num_workers": -1, + "drop_last": True, }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 100, - 'precision': 32 + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 100, "precision": 32}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.5, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'criterion': { - 'temperature': 0.5 - }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.5, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } -) \ No newline at end of file + }, +) diff --git a/docs/source/docker_archive/advanced/code_examples/python_run_object_level.py b/docs/source/docker_archive/advanced/code_examples/python_run_object_level.py index b1bc99d68..08b73a1e5 100644 --- a/docs/source/docker_archive/advanced/code_examples/python_run_object_level.py +++ b/docs/source/docker_archive/advanced/code_examples/python_run_object_level.py @@ -3,73 +3,48 @@ # Create the Lightly client to connect to the API. 
client = lightly.api.ApiWorkflowClient(token="LIGHTLY_TOKEN", dataset_id="DATASET_ID") -# Schedule the docker run with the "object_level.task_name" argument set. -# All other settings are default values and we show them so you can easily edit -# the values according to your need. +# Schedule the docker run with the "object_level.task_name" argument set. +# All other settings are default values and we show them so you can easily edit +# the values according to your need. client.schedule_compute_worker_run( worker_config={ - "object_level": { - "task_name": "vehicles_object_detections" - }, + "object_level": {"task_name": "vehicles_object_detections"}, "enable_corruptness_check": True, "remove_exact_duplicates": True, "enable_training": False, "pretagging": False, "pretagging_debug": False, "method": "coreset", - "stopping_condition": { - "n_samples": 0.1, - "min_distance": -1 - }, + "stopping_condition": {"n_samples": 0.1, "min_distance": -1}, "scorer": "object-frequency", - "scorer_config": { - "frequency_penalty": 0.25, - "min_score": 0.9 - }, - "active_learning": { - "task_name": "", - "score_name": "uncertainty_margin" - } + "scorer_config": {"frequency_penalty": 0.25, "min_score": 0.9}, + "active_learning": {"task_name": "", "score_name": "uncertainty_margin"}, }, lightly_config={ - 'loader': { - 'batch_size': 16, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True - }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 + "loader": { + "batch_size": 16, + "shuffle": True, + "num_workers": -1, + "drop_last": True, + }, + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 100, "precision": 32}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.5, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 100, - 'precision': 32 - }, - 'criterion': { - 'temperature': 0.5 - }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.5, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } + }, ) diff --git a/docs/source/docker_archive/advanced/code_examples/python_run_object_level_pretagging.py b/docs/source/docker_archive/advanced/code_examples/python_run_object_level_pretagging.py index e0253e5f7..76e237aa8 100644 --- a/docs/source/docker_archive/advanced/code_examples/python_run_object_level_pretagging.py +++ b/docs/source/docker_archive/advanced/code_examples/python_run_object_level_pretagging.py @@ -5,72 +5,47 @@ # Schedule the docker run with the "object_level.task_name" argument set to # "lightly_pretagging" and with "pretagging" set to True. -# All other settings are default values and we show them so you can easily edit -# the values according to your need. +# All other settings are default values and we show them so you can easily edit +# the values according to your need. 
client.schedule_compute_worker_run( worker_config={ - "object_level": { - "task_name": "lightly_pretagging" - }, + "object_level": {"task_name": "lightly_pretagging"}, "enable_corruptness_check": True, "remove_exact_duplicates": True, "enable_training": False, "pretagging": True, "pretagging_debug": False, "method": "coreset", - "stopping_condition": { - "n_samples": 0.1, - "min_distance": -1 - }, + "stopping_condition": {"n_samples": 0.1, "min_distance": -1}, "scorer": "object-frequency", - "scorer_config": { - "frequency_penalty": 0.25, - "min_score": 0.9 - }, - "active_learning": { - "task_name": "", - "score_name": "uncertainty_margin" - } + "scorer_config": {"frequency_penalty": 0.25, "min_score": 0.9}, + "active_learning": {"task_name": "", "score_name": "uncertainty_margin"}, }, lightly_config={ - 'loader': { - 'batch_size': 16, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True - }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 + "loader": { + "batch_size": 16, + "shuffle": True, + "num_workers": -1, + "drop_last": True, + }, + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 100, "precision": 32}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.5, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 100, - 'precision': 32 - }, - 'criterion': { - 'temperature': 0.5 - }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.5, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } + }, ) diff --git a/docs/source/docker_archive/integration/examples/create_dataset.py b/docs/source/docker_archive/integration/examples/create_dataset.py index 296c1f259..ce8e9aee8 100644 --- a/docs/source/docker_archive/integration/examples/create_dataset.py +++ b/docs/source/docker_archive/integration/examples/create_dataset.py @@ -4,11 +4,11 @@ client = lightly.api.ApiWorkflowClient(token="LIGHTLY_TOKEN") # Create a new dataset on the Lightly Platform. -client.create_dataset('dataset-name') +client.create_dataset("dataset-name") # Connect the dataset to your cloud bucket. 
-# AWS S3 +# AWS S3 client.set_s3_config( resource_path="s3://bucket/dataset/", region="eu-central-1", @@ -19,17 +19,18 @@ # Google Cloud Storage import json + client.set_gcs_config( resource_path="gs://bucket/dataset/", project_id="PROJECT-ID", - credentials=json.dumps(json.load(open('credentials.json'))), + credentials=json.dumps(json.load(open("credentials.json"))), thumbnail_suffix=None, ) -# Azure Blob Storage +# Azure Blob Storage client.set_azure_config( container_name="container/dataset/", account_name="ACCOUNT-NAME", sas_token="SAS-TOKEN", thumbnail_suffix=None, -) \ No newline at end of file +) diff --git a/docs/source/docker_archive/integration/examples/trigger_job.py b/docs/source/docker_archive/integration/examples/trigger_job.py index 1e4b6a32c..f4749aae6 100644 --- a/docs/source/docker_archive/integration/examples/trigger_job.py +++ b/docs/source/docker_archive/integration/examples/trigger_job.py @@ -7,61 +7,41 @@ # values according to your needs. client.schedule_compute_worker_run( worker_config={ - 'enable_corruptness_check': True, - 'remove_exact_duplicates': True, - 'enable_training': False, - 'pretagging': False, - 'pretagging_debug': False, - 'method': 'coreset', - 'stopping_condition': { - 'n_samples': 0.1, - 'min_distance': -1 - }, - 'scorer': 'object-frequency', - 'scorer_config': { - 'frequency_penalty': 0.25, - 'min_score': 0.9 - } + "enable_corruptness_check": True, + "remove_exact_duplicates": True, + "enable_training": False, + "pretagging": False, + "pretagging_debug": False, + "method": "coreset", + "stopping_condition": {"n_samples": 0.1, "min_distance": -1}, + "scorer": "object-frequency", + "scorer_config": {"frequency_penalty": 0.25, "min_score": 0.9}, }, lightly_config={ - 'loader': { - 'batch_size': 16, - 'shuffle': True, - 'num_workers': -1, - 'drop_last': True - }, - 'model': { - 'name': 'resnet-18', - 'out_dim': 128, - 'num_ftrs': 32, - 'width': 1 - }, - 'trainer': { - 'gpus': 1, - 'max_epochs': 100, - 'precision': 32 + "loader": { + "batch_size": 16, + "shuffle": True, + "num_workers": -1, + "drop_last": True, }, - 'criterion': { - 'temperature': 0.5 + "model": {"name": "resnet-18", "out_dim": 128, "num_ftrs": 32, "width": 1}, + "trainer": {"gpus": 1, "max_epochs": 100, "precision": 32}, + "criterion": {"temperature": 0.5}, + "optimizer": {"lr": 1, "weight_decay": 0.00001}, + "collate": { + "input_size": 64, + "cj_prob": 0.8, + "cj_bright": 0.7, + "cj_contrast": 0.7, + "cj_sat": 0.7, + "cj_hue": 0.2, + "min_scale": 0.15, + "random_gray_scale": 0.2, + "gaussian_blur": 0.5, + "kernel_size": 0.1, + "vf_prob": 0, + "hf_prob": 0.5, + "rr_prob": 0, }, - 'optimizer': { - 'lr': 1, - 'weight_decay': 0.00001 - }, - 'collate': { - 'input_size': 64, - 'cj_prob': 0.8, - 'cj_bright': 0.7, - 'cj_contrast': 0.7, - 'cj_sat': 0.7, - 'cj_hue': 0.2, - 'min_scale': 0.15, - 'random_gray_scale': 0.2, - 'gaussian_blur': 0.5, - 'kernel_size': 0.1, - 'vf_prob': 0, - 'hf_prob': 0.5, - 'rr_prob': 0 - } - } + }, ) diff --git a/docs/source/getting_started/benchmarks/cifar10_benchmark.py b/docs/source/getting_started/benchmarks/cifar10_benchmark.py index 9673f4b9d..0187d6bf2 100644 --- a/docs/source/getting_started/benchmarks/cifar10_benchmark.py +++ b/docs/source/getting_started/benchmarks/cifar10_benchmark.py @@ -58,36 +58,37 @@ """ import copy import os - import time + import numpy as np import pytorch_lightning as pl import torch import torch.nn as nn import torchvision +from pytorch_lightning.loggers import TensorBoardLogger + from lightly.data import ( + 
DINOCollateFunction, LightlyDataset, SimCLRCollateFunction, SwaVCollateFunction, - DINOCollateFunction, collate, ) from lightly.loss import ( - NTXentLoss, - NegativeCosineSimilarity, - memory_bank, + BarlowTwinsLoss, DCLLoss, DCLWLoss, DINOLoss, - BarlowTwinsLoss, + NegativeCosineSimilarity, + NTXentLoss, SwaVLoss, + memory_bank, ) -from lightly.models import modules, ResNetGenerator, utils +from lightly.models import ResNetGenerator, modules, utils from lightly.models.modules import heads from lightly.utils.benchmarking import BenchmarkModule -from pytorch_lightning.loggers import TensorBoardLogger -logs_root_dir = os.path.join(os.getcwd(), 'benchmark_logs') +logs_root_dir = os.path.join(os.getcwd(), "benchmark_logs") # set max_epochs to 800 for long run (takes around 10h on a single V100) max_epochs = 200 @@ -96,30 +97,30 @@ knn_t = 0.1 classes = 10 -# Set to True to enable Distributed Data Parallel training. +# Set to True to enable Distributed Data Parallel training. distributed = False -# Set to True to enable Synchronized Batch Norm (requires distributed=True). +# Set to True to enable Synchronized Batch Norm (requires distributed=True). # If enabled the batch norm is calculated over all gpus, otherwise the batch # norm is only calculated from samples on the same gpu. sync_batchnorm = False -# Set to True to gather features from all gpus before calculating +# Set to True to gather features from all gpus before calculating # the loss (requires distributed=True). -# If enabled then the loss on every gpu is calculated with features from all +# If enabled then the loss on every gpu is calculated with features from all # gpus, otherwise only features from the same gpu are used. -gather_distributed = False +gather_distributed = False # benchmark -n_runs = 1 # optional, increase to create multiple runs and report mean + std +n_runs = 1 # optional, increase to create multiple runs and report mean + std batch_size = 128 -lr_factor = batch_size / 128 # scales the learning rate linearly with batch size +lr_factor = batch_size / 128 # scales the learning rate linearly with batch size # use a GPU if available gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 if distributed: - distributed_backend = 'ddp' + distributed_backend = "ddp" # reduce batch size for distributed training batch_size = batch_size // gpus else: @@ -133,7 +134,7 @@ # We assume we have a train folder with subfolders # for each class and .png images inside. # -# You can download `CIFAR-10 in folders from kaggle +# You can download `CIFAR-10 in folders from kaggle # `_. 
# The dataset structure should be like this: @@ -150,19 +151,19 @@ # L horse/ # L ship/ # L truck/ -path_to_train = '/datasets/cifar10/train/' -path_to_test = '/datasets/cifar10/test/' +path_to_train = "/datasets/cifar10/train/" +path_to_test = "/datasets/cifar10/test/" # Use SimCLR augmentations, additionally, disable blur for cifar10 collate_fn = SimCLRCollateFunction( input_size=32, - gaussian_blur=0., + gaussian_blur=0.0, ) # Multi crop augmentation for SwAV, additionally, disable blur for cifar10 swav_collate_fn = SwaVCollateFunction( crop_sizes=[32], - crop_counts=[2], # 2 crops @ 32x32px + crop_counts=[2], # 2 crops @ 32x32px crop_min_scales=[0.14], gaussian_blur=0, ) @@ -178,34 +179,29 @@ smog_collate_function = collate.SMoGCollateFunction( crop_sizes=[32, 32], crop_counts=[1, 1], - gaussian_blur_probs=[0., 0.], + gaussian_blur_probs=[0.0, 0.0], crop_min_scales=[0.2, 0.2], crop_max_scales=[1.0, 1.0], ) # No additional augmentations for the test set -test_transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize( - mean=collate.imagenet_normalize['mean'], - std=collate.imagenet_normalize['std'], - ) -]) - -dataset_train_ssl = LightlyDataset( - input_dir=path_to_train +test_transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + mean=collate.imagenet_normalize["mean"], + std=collate.imagenet_normalize["std"], + ), + ] ) +dataset_train_ssl = LightlyDataset(input_dir=path_to_train) + # we use test transformations for getting the feature for kNN on train data -dataset_train_kNN = LightlyDataset( - input_dir=path_to_train, - transform=test_transforms -) +dataset_train_kNN = LightlyDataset(input_dir=path_to_train, transform=test_transforms) + +dataset_test = LightlyDataset(input_dir=path_to_test, transform=test_transforms) -dataset_test = LightlyDataset( - input_dir=path_to_test, - transform=test_transforms -) def get_data_loaders(batch_size: int, model): """Helper method to create dataloaders for ssl, kNN train and kNN test @@ -226,7 +222,7 @@ def get_data_loaders(batch_size: int, model): shuffle=True, collate_fn=col_fn, drop_last=True, - num_workers=num_workers + num_workers=num_workers, ) dataloader_train_kNN = torch.utils.data.DataLoader( @@ -234,7 +230,7 @@ def get_data_loaders(batch_size: int, model): batch_size=batch_size, shuffle=False, drop_last=False, - num_workers=num_workers + num_workers=num_workers, ) dataloader_test = torch.utils.data.DataLoader( @@ -242,21 +238,21 @@ def get_data_loaders(batch_size: int, model): batch_size=batch_size, shuffle=False, drop_last=False, - num_workers=num_workers + num_workers=num_workers, ) return dataloader_train_ssl, dataloader_train_kNN, dataloader_test + class MocoModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the classification head num_splits = 0 if sync_batchnorm else 8 - resnet = ResNetGenerator('resnet-18', num_splits=num_splits) + resnet = ResNetGenerator("resnet-18", num_splits=num_splits) self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) # create a moco model based on ResNet @@ -271,7 +267,7 @@ def __init__(self, dataloader_kNN, num_classes): temperature=0.1, memory_bank_size=4096, ) - + def forward(self, x): x = self.backbone(x).flatten(start_dim=1) return self.projection_head(x) @@ -299,15 +295,17 
@@ def step(x0_, x1_): loss_2 = self.criterion(*step(x1, x0)) loss = 0.5 * (loss_1 + loss_2) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): - params = list(self.backbone.parameters()) + list(self.projection_head.parameters()) + params = list(self.backbone.parameters()) + list( + self.projection_head.parameters() + ) optim = torch.optim.SGD( - params, + params, lr=6e-2 * lr_factor, - momentum=0.9, + momentum=0.9, weight_decay=5e-4, ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) @@ -318,10 +316,9 @@ class SimCLRModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the classification head - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) self.projection_head = heads.SimCLRProjectionHead(512, 512, 128) self.criterion = NTXentLoss() @@ -336,15 +333,12 @@ def training_step(self, batch, batch_index): z0 = self.forward(x0) z1 = self.forward(x1) loss = self.criterion(z0, z1) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): optim = torch.optim.SGD( - self.parameters(), - lr=6e-2 * lr_factor, - momentum=0.9, - weight_decay=5e-4 + self.parameters(), lr=6e-2 * lr_factor, momentum=0.9, weight_decay=5e-4 ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] @@ -354,29 +348,20 @@ class SimSiamModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the classification head - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) self.prediction_head = heads.SimSiamPredictionHead(2048, 512, 2048) # use a 2-layer projection head for cifar10 as described in the paper - self.projection_head = heads.ProjectionHead([ - ( - 512, - 2048, - nn.BatchNorm1d(2048), - nn.ReLU(inplace=True) - ), - ( - 2048, - 2048, - nn.BatchNorm1d(2048), - None - ) - ]) + self.projection_head = heads.ProjectionHead( + [ + (512, 2048, nn.BatchNorm1d(2048), nn.ReLU(inplace=True)), + (2048, 2048, nn.BatchNorm1d(2048), None), + ] + ) self.criterion = NegativeCosineSimilarity() - + def forward(self, x): f = self.backbone(x).flatten(start_dim=1) z = self.projection_head(f) @@ -389,43 +374,35 @@ def training_step(self, batch, batch_idx): z0, p0 = self.forward(x0) z1, p1 = self.forward(x1) loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0)) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): optim = torch.optim.SGD( - self.parameters(), - lr=6e-2, # no lr-scaling, results in better training stability + self.parameters(), + lr=6e-2, # no lr-scaling, results in better training stability momentum=0.9, - weight_decay=5e-4 + weight_decay=5e-4, ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] + class BarlowTwinsModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the 
classification head - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) # use a 2-layer projection head for cifar10 as described in the paper - self.projection_head = heads.ProjectionHead([ - ( - 512, - 2048, - nn.BatchNorm1d(2048), - nn.ReLU(inplace=True) - ), - ( - 2048, - 2048, - None, - None - ) - ]) + self.projection_head = heads.ProjectionHead( + [ + (512, 2048, nn.BatchNorm1d(2048), nn.ReLU(inplace=True)), + (2048, 2048, None, None), + ] + ) self.criterion = BarlowTwinsLoss(gather_distributed=gather_distributed) @@ -439,27 +416,24 @@ def training_step(self, batch, batch_index): z0 = self.forward(x0) z1 = self.forward(x1) loss = self.criterion(z0, z1) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): optim = torch.optim.SGD( - self.parameters(), - lr=6e-2 * lr_factor, - momentum=0.9, - weight_decay=5e-4 + self.parameters(), lr=6e-2 * lr_factor, momentum=0.9, weight_decay=5e-4 ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] + class BYOLModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the classification head - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) # create a byol model based on ResNet @@ -488,41 +462,45 @@ def forward_momentum(self, x): def training_step(self, batch, batch_idx): utils.update_momentum(self.backbone, self.backbone_momentum, m=0.99) - utils.update_momentum(self.projection_head, self.projection_head_momentum, m=0.99) + utils.update_momentum( + self.projection_head, self.projection_head_momentum, m=0.99 + ) (x0, x1), _, _ = batch p0 = self.forward(x0) z0 = self.forward_momentum(x0) p1 = self.forward(x1) z1 = self.forward_momentum(x1) loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0)) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): - params = list(self.backbone.parameters()) \ - + list(self.projection_head.parameters()) \ + params = ( + list(self.backbone.parameters()) + + list(self.projection_head.parameters()) + list(self.prediction_head.parameters()) + ) optim = torch.optim.SGD( - params, + params, lr=6e-2 * lr_factor, - momentum=0.9, + momentum=0.9, weight_decay=5e-4, ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] + class SwaVModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the classification head - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) self.projection_head = heads.SwaVProjectionHead(512, 512, 128) - self.prototypes = heads.SwaVPrototypes(128, 512) # use 512 prototypes + self.prototypes = heads.SwaVPrototypes(128, 512) # use 512 prototypes self.criterion = SwaVLoss(sinkhorn_gather_distributed=gather_distributed) @@ -547,12 +525,9 @@ def training_step(self, batch, 
batch_idx): low_resolution_features = multi_crop_features[2:] # calculate the SwaV loss - loss = self.criterion( - high_resolution_features, - low_resolution_features - ) + loss = self.criterion(high_resolution_features, low_resolution_features) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): @@ -569,27 +544,18 @@ class NNCLRModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the classification head - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) self.prediction_head = heads.NNCLRPredictionHead(256, 4096, 256) # use only a 2-layer projection head for cifar10 - self.projection_head = heads.ProjectionHead([ - ( - 512, - 2048, - nn.BatchNorm1d(2048), - nn.ReLU(inplace=True) - ), - ( - 2048, - 256, - nn.BatchNorm1d(256), - None - ) - ]) + self.projection_head = heads.ProjectionHead( + [ + (512, 2048, nn.BatchNorm1d(2048), nn.ReLU(inplace=True)), + (2048, 256, nn.BatchNorm1d(256), None), + ] + ) self.criterion = NTXentLoss() self.memory_bank = modules.NNMemoryBankModule(size=4096) @@ -612,9 +578,9 @@ def training_step(self, batch, batch_idx): def configure_optimizers(self): optim = torch.optim.SGD( - self.parameters(), + self.parameters(), lr=6e-2 * lr_factor, - momentum=0.9, + momentum=0.9, weight_decay=5e-4, ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) @@ -625,10 +591,9 @@ class DINOModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the classification head - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) self.head = self._build_projection_head() self.teacher_backbone = copy.deepcopy(self.backbone) @@ -642,10 +607,12 @@ def __init__(self, dataloader_kNN, num_classes): def _build_projection_head(self): head = heads.DINOProjectionHead(512, 2048, 256, 2048, batch_norm=True) # use only 2 layers for cifar10 - head.layers = heads.ProjectionHead([ - (512, 2048, nn.BatchNorm1d(2048), nn.GELU()), - (2048, 256, None, None), - ]).layers + head.layers = heads.ProjectionHead( + [ + (512, 2048, nn.BatchNorm1d(2048), nn.GELU()), + (2048, 256, None, None), + ] + ).layers return head def forward(self, x): @@ -667,12 +634,11 @@ def training_step(self, batch, batch_idx): teacher_out = [self.forward_teacher(view) for view in global_views] student_out = [self.forward(view) for view in views] loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): - param = list(self.backbone.parameters()) \ - + list(self.head.parameters()) + param = list(self.backbone.parameters()) + list(self.head.parameters()) optim = torch.optim.SGD( param, lr=6e-2 * lr_factor, @@ -687,10 +653,9 @@ class DCL(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the classification head - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") 
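        # DCL reuses the SimCLR backbone, projection head, and training step;
        # only the loss (DCLLoss below) differs from the SimCLRModel above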
self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) self.projection_head = heads.SimCLRProjectionHead(512, 512, 128) self.criterion = DCLLoss() @@ -705,15 +670,12 @@ def training_step(self, batch, batch_index): z0 = self.forward(x0) z1 = self.forward(x1) loss = self.criterion(z0, z1) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): optim = torch.optim.SGD( - self.parameters(), - lr=6e-2 * lr_factor, - momentum=0.9, - weight_decay=5e-4 + self.parameters(), lr=6e-2 * lr_factor, momentum=0.9, weight_decay=5e-4 ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] @@ -723,10 +685,9 @@ class DCLW(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the classification head - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) self.projection_head = heads.SimCLRProjectionHead(512, 512, 128) self.criterion = DCLWLoss() @@ -741,15 +702,12 @@ def training_step(self, batch, batch_index): z0 = self.forward(x0) z1 = self.forward(x1) loss = self.criterion(z0, z1) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): optim = torch.optim.SGD( - self.parameters(), - lr=6e-2 * lr_factor, - momentum=0.9, - weight_decay=5e-4 + self.parameters(), lr=6e-2 * lr_factor, momentum=0.9, weight_decay=5e-4 ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] @@ -757,16 +715,15 @@ def configure_optimizers(self): from sklearn.cluster import KMeans -class SMoGModel(BenchmarkModule): +class SMoGModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes) # create a ResNet backbone and remove the classification head - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) # create a model based on ResNet @@ -809,7 +766,6 @@ def _reset_momentum_weights(self): utils.deactivate_requires_grad(self.projection_head_momentum) def training_step(self, batch, batch_idx): - if self.global_step > 0 and self.global_step % 300 == 0: # reset group features and weights every 300 iterations self._reset_group_features() @@ -817,7 +773,9 @@ def training_step(self, batch, batch_idx): else: # update momentum utils.update_momentum(self.backbone, self.backbone_momentum, 0.99) - utils.update_momentum(self.projection_head, self.projection_head_momentum, 0.99) + utils.update_momentum( + self.projection_head, self.projection_head_momentum, 0.99 + ) (x0, x1), _, _ = batch @@ -845,9 +803,13 @@ def training_step(self, batch, batch_idx): return loss def configure_optimizers(self): - params = list(self.backbone.parameters()) + list(self.projection_head.parameters()) + list(self.prediction_head.parameters()) + params = ( + list(self.backbone.parameters()) + + list(self.projection_head.parameters()) + + list(self.prediction_head.parameters()) + ) optim = torch.optim.SGD( - params, + params, lr=0.01, momentum=0.9, weight_decay=1e-6, @@ -856,8 
+818,6 @@ def configure_optimizers(self): return [optim], [scheduler] - - models = [ BarlowTwinsModel, BYOLModel, @@ -869,7 +829,7 @@ def configure_optimizers(self): SimCLRModel, SimSiamModel, SwaVModel, - SMoGModel + SMoGModel, ] bench_results = dict() @@ -877,21 +837,21 @@ def configure_optimizers(self): # loop through configurations and train models for BenchmarkModel in models: runs = [] - model_name = BenchmarkModel.__name__.replace('Model', '') + model_name = BenchmarkModel.__name__.replace("Model", "") for seed in range(n_runs): pl.seed_everything(seed) dataloader_train_ssl, dataloader_train_kNN, dataloader_test = get_data_loaders( - batch_size=batch_size, + batch_size=batch_size, model=BenchmarkModel, ) benchmark_model = BenchmarkModel(dataloader_train_kNN, classes) # Save logs to: {CWD}/benchmark_logs/cifar10/{experiment_version}/{model_name}/ # If multiple runs are specified a subdirectory for each run is created. - sub_dir = model_name if n_runs <= 1 else f'{model_name}/run{seed}' + sub_dir = model_name if n_runs <= 1 else f"{model_name}/run{seed}" logger = TensorBoardLogger( - save_dir=os.path.join(logs_root_dir, 'cifar10'), - name='', + save_dir=os.path.join(logs_root_dir, "cifar10"), + name="", sub_dir=sub_dir, version=experiment_version, ) @@ -899,32 +859,32 @@ def configure_optimizers(self): # Save results of all models under same version directory experiment_version = logger.version checkpoint_callback = pl.callbacks.ModelCheckpoint( - dirpath=os.path.join(logger.log_dir, 'checkpoints') + dirpath=os.path.join(logger.log_dir, "checkpoints") ) trainer = pl.Trainer( - max_epochs=max_epochs, + max_epochs=max_epochs, gpus=gpus, default_root_dir=logs_root_dir, strategy=distributed_backend, sync_batchnorm=sync_batchnorm, logger=logger, - callbacks=[checkpoint_callback] + callbacks=[checkpoint_callback], ) start = time.time() trainer.fit( benchmark_model, train_dataloaders=dataloader_train_ssl, - val_dataloaders=dataloader_test + val_dataloaders=dataloader_test, ) end = time.time() run = { - 'model': model_name, - 'batch_size': batch_size, - 'epochs': max_epochs, - 'max_accuracy': benchmark_model.max_accuracy, - 'runtime': end - start, - 'gpu_memory_usage': torch.cuda.max_memory_allocated(), - 'seed': seed, + "model": model_name, + "batch_size": batch_size, + "epochs": max_epochs, + "max_accuracy": benchmark_model.max_accuracy, + "runtime": end - start, + "gpu_memory_usage": torch.cuda.max_memory_allocated(), + "seed": seed, } runs.append(run) print(run) @@ -934,23 +894,23 @@ def configure_optimizers(self): del trainer torch.cuda.reset_peak_memory_stats() torch.cuda.empty_cache() - + bench_results[model_name] = runs -# print results table +# print results table header = ( f"| {'Model':<13} | {'Batch Size':>10} | {'Epochs':>6} " f"| {'KNN Test Accuracy':>18} | {'Time':>10} | {'Peak GPU Usage':>14} |" ) -print('-' * len(header)) +print("-" * len(header)) print(header) -print('-' * len(header)) +print("-" * len(header)) for model, results in bench_results.items(): - runtime = np.array([result['runtime'] for result in results]) - runtime = runtime.mean() / 60 # convert to min - accuracy = np.array([result['max_accuracy'] for result in results]) - gpu_memory_usage = np.array([result['gpu_memory_usage'] for result in results]) - gpu_memory_usage = gpu_memory_usage.max() / (1024**3) # convert to gbyte + runtime = np.array([result["runtime"] for result in results]) + runtime = runtime.mean() / 60 # convert to min + accuracy = np.array([result["max_accuracy"] for result in results]) + 
gpu_memory_usage = np.array([result["gpu_memory_usage"] for result in results]) + gpu_memory_usage = gpu_memory_usage.max() / (1024**3) # convert to gbyte if len(accuracy) > 1: accuracy_msg = f"{accuracy.mean():>8.3f} +- {accuracy.std():>4.3f}" @@ -961,6 +921,6 @@ def configure_optimizers(self): f"| {model:<13} | {batch_size:>10} | {max_epochs:>6} " f"| {accuracy_msg} | {runtime:>6.1f} Min " f"| {gpu_memory_usage:>8.1f} GByte |", - flush=True + flush=True, ) -print('-' * len(header)) +print("-" * len(header)) diff --git a/docs/source/getting_started/benchmarks/imagenet100_benchmark.py b/docs/source/getting_started/benchmarks/imagenet100_benchmark.py index d98036166..da2f65005 100644 --- a/docs/source/getting_started/benchmarks/imagenet100_benchmark.py +++ b/docs/source/getting_started/benchmarks/imagenet100_benchmark.py @@ -26,37 +26,37 @@ """ import copy import os - import time + import numpy as np import pytorch_lightning as pl import torch import torch.nn as nn -from torch.optim.lr_scheduler import LambdaLR import torchvision +from pl_bolts.optimizers.lars import LARS +from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay +from pytorch_lightning.loggers import TensorBoardLogger +from torch.optim.lr_scheduler import LambdaLR + from lightly.data import ( + DINOCollateFunction, LightlyDataset, SimCLRCollateFunction, SwaVCollateFunction, - DINOCollateFunction, collate, ) from lightly.loss import ( - NTXentLoss, - NegativeCosineSimilarity, - DINOLoss, BarlowTwinsLoss, + DINOLoss, + NegativeCosineSimilarity, + NTXentLoss, SwaVLoss, ) from lightly.models import modules, utils from lightly.models.modules import heads from lightly.utils.benchmarking import BenchmarkModule -from pl_bolts.optimizers.lars import LARS -from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay -from pytorch_lightning.loggers import TensorBoardLogger - -logs_root_dir = os.path.join(os.getcwd(), 'benchmark_logs') +logs_root_dir = os.path.join(os.getcwd(), "benchmark_logs") num_workers = 12 memory_bank_size = 2**16 @@ -66,33 +66,33 @@ knn_k = 20 knn_t = 0.1 classes = 100 -input_size=224 +input_size = 224 # Set to True to enable Distributed Data Parallel training. distributed = False -# Set to True to enable Synchronized Batch Norm (requires distributed=True). +# Set to True to enable Synchronized Batch Norm (requires distributed=True). # If enabled the batch norm is calculated over all gpus, otherwise the batch # norm is only calculated from samples on the same gpu. sync_batchnorm = False -# Set to True to gather features from all gpus before calculating +# Set to True to gather features from all gpus before calculating # the loss (requires distributed=True). -# If enabled then the loss on every gpu is calculated with features from all +# If enabled then the loss on every gpu is calculated with features from all # gpus, otherwise only features from the same gpu are used. 
-gather_distributed = False +gather_distributed = False # benchmark -n_runs = 1 # optional, increase to create multiple runs and report mean + std +n_runs = 1 # optional, increase to create multiple runs and report mean + std batch_size = 256 -lr_factor = batch_size / 256 # scales the learning rate linearly with batch size +lr_factor = batch_size / 256 # scales the learning rate linearly with batch size # use a GPU if available gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 if distributed: - distributed_backend = 'ddp' + distributed_backend = "ddp" # reduce batch size for distributed training batch_size = batch_size // gpus else: @@ -102,8 +102,8 @@ # The dataset structure should be like this: -path_to_train = '/datasets/imagenet100/train/' -path_to_test = '/datasets/imagenet100/val/' +path_to_train = "/datasets/imagenet100/train/" +path_to_test = "/datasets/imagenet100/val/" # Use SimCLR augmentations collate_fn = SimCLRCollateFunction( @@ -117,30 +117,24 @@ dino_collate_fn = DINOCollateFunction() # No additional augmentations for the test set -test_transforms = torchvision.transforms.Compose([ - torchvision.transforms.Resize(input_size), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize( - mean=collate.imagenet_normalize['mean'], - std=collate.imagenet_normalize['std'], - ) -]) - -dataset_train_ssl = LightlyDataset( - input_dir=path_to_train +test_transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize(input_size), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + mean=collate.imagenet_normalize["mean"], + std=collate.imagenet_normalize["std"], + ), + ] ) +dataset_train_ssl = LightlyDataset(input_dir=path_to_train) + # we use test transformations for getting the feature for kNN on train data -dataset_train_kNN = LightlyDataset( - input_dir=path_to_train, - transform=test_transforms -) +dataset_train_kNN = LightlyDataset(input_dir=path_to_train, transform=test_transforms) -dataset_test = LightlyDataset( - input_dir=path_to_test, - transform=test_transforms -) +dataset_test = LightlyDataset(input_dir=path_to_test, transform=test_transforms) steps_per_epoch = len(dataset_train_ssl) // batch_size @@ -161,7 +155,7 @@ def get_data_loaders(batch_size: int, model): shuffle=True, collate_fn=col_fn, drop_last=True, - num_workers=num_workers + num_workers=num_workers, ) dataloader_train_kNN = torch.utils.data.DataLoader( @@ -169,7 +163,7 @@ def get_data_loaders(batch_size: int, model): batch_size=batch_size, shuffle=False, drop_last=False, - num_workers=num_workers + num_workers=num_workers, ) dataloader_test = torch.utils.data.DataLoader( @@ -177,7 +171,7 @@ def get_data_loaders(batch_size: int, model): batch_size=batch_size, shuffle=False, drop_last=False, - num_workers=num_workers + num_workers=num_workers, ) return dataloader_train_ssl, dataloader_train_kNN, dataloader_test @@ -192,9 +186,7 @@ def __init__(self, dataloader_kNN, num_classes): # TODO: Add split batch norm to the resnet model resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) # create a moco model based on ResNet self.projection_head = heads.MoCoProjectionHead(feature_dim, 2048, 128) @@ -204,9 +196,7 @@ def __init__(self, dataloader_kNN, num_classes): 
utils.deactivate_requires_grad(self.projection_head_momentum) # create our loss with the optional memory bank - self.criterion = NTXentLoss( - temperature=0.07, - memory_bank_size=memory_bank_size) + self.criterion = NTXentLoss(temperature=0.07, memory_bank_size=memory_bank_size) def forward(self, x): x = self.backbone(x).flatten(start_dim=1) @@ -217,7 +207,9 @@ def training_step(self, batch, batch_idx): # update momentum utils.update_momentum(self.backbone, self.backbone_momentum, 0.999) - utils.update_momentum(self.projection_head, self.projection_head_momentum, 0.999) + utils.update_momentum( + self.projection_head, self.projection_head_momentum, 0.999 + ) def step(x0_, x1_): x1_, shuffle = utils.batch_shuffle(x1_, distributed=distributed) @@ -235,13 +227,15 @@ def step(x0_, x1_): loss_2 = self.criterion(*step(x1, x0)) loss = 0.5 * (loss_1 + loss_2) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): - params = list(self.backbone.parameters()) + list(self.projection_head.parameters()) + params = list(self.backbone.parameters()) + list( + self.projection_head.parameters() + ) optim = torch.optim.SGD( - params, + params, lr=0.03 * lr_factor, momentum=0.9, weight_decay=1e-4, @@ -256,9 +250,7 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = heads.SimCLRProjectionHead(feature_dim, feature_dim, 128) self.criterion = NTXentLoss(temperature=0.1) @@ -272,39 +264,38 @@ def training_step(self, batch, batch_index): z0 = self.forward(x0) z1 = self.forward(x1) loss = self.criterion(z0, z1) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): optim = LARS( - self.parameters(), + self.parameters(), lr=0.3 * lr_factor, - momentum=0.9, + momentum=0.9, weight_decay=1e-6, ) scheduler = { "scheduler": LambdaLR( optimizer=optim, lr_lambda=linear_warmup_decay( - warmup_steps=steps_per_epoch * 10, - total_steps=steps_per_epoch * max_epochs, + warmup_steps=steps_per_epoch * 10, + total_steps=steps_per_epoch * max_epochs, cosine=True, - ) + ), ), "interval": "step", "frequency": 1, } return [optim], [scheduler] + class SimSiamModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes, knn_k=knn_k, knn_t=knn_t) # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.prediction_head = heads.SimSiamPredictionHead(2048, 512, 2048) self.projection_head = heads.SimSiamProjectionHead(feature_dim, 512, 2048) self.criterion = NegativeCosineSimilarity() @@ -321,12 +312,12 @@ def training_step(self, batch, batch_idx): z0, p0 = self.forward(x0) z1, p1 = self.forward(x1) loss = 0.5 * (self.criterion(z0, p1) + self.criterion(z1, p0)) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): optim = torch.optim.SGD( - self.parameters(), + self.parameters(), lr=0.05 * lr_factor, momentum=0.9, weight_decay=1e-4, @@ -334,15 +325,14 @@ def 
configure_optimizers(self): scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] + class BarlowTwinsModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes, knn_k=knn_k, knn_t=knn_t) # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) # use a 2-layer projection head for cifar10 as described in the paper self.projection_head = heads.BarlowTwinsProjectionHead(feature_dim, 2048, 2048) @@ -358,39 +348,38 @@ def training_step(self, batch, batch_index): z0 = self.forward(x0) z1 = self.forward(x1) loss = self.criterion(z0, z1) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): optim = LARS( - self.parameters(), + self.parameters(), lr=0.2 * lr_factor, - momentum=0.9, + momentum=0.9, weight_decay=1.5 * 1e-6, ) scheduler = { "scheduler": LambdaLR( optimizer=optim, lr_lambda=linear_warmup_decay( - warmup_steps=steps_per_epoch * 10, - total_steps=steps_per_epoch * max_epochs, + warmup_steps=steps_per_epoch * 10, + total_steps=steps_per_epoch * max_epochs, cosine=True, - ) + ), ), "interval": "step", "frequency": 1, } return [optim], [scheduler] + class BYOLModel(BenchmarkModule): def __init__(self, dataloader_kNN, num_classes): super().__init__(dataloader_kNN, num_classes, knn_k=knn_k, knn_t=knn_t) # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) # create a byol model based on ResNet self.projection_head = heads.BYOLProjectionHead(feature_dim, 4096, 256) @@ -418,20 +407,24 @@ def forward_momentum(self, x): def training_step(self, batch, batch_idx): utils.update_momentum(self.backbone, self.backbone_momentum, m=0.999) - utils.update_momentum(self.projection_head, self.projection_head_momentum, m=0.999) + utils.update_momentum( + self.projection_head, self.projection_head_momentum, m=0.999 + ) (x0, x1), _, _ = batch p0 = self.forward(x0) z0 = self.forward_momentum(x0) p1 = self.forward(x1) z1 = self.forward_momentum(x1) loss = 0.5 * (self.criterion(p0, z1) + self.criterion(p1, z0)) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): - params = list(self.backbone.parameters()) \ - + list(self.projection_head.parameters()) \ + params = ( + list(self.backbone.parameters()) + + list(self.projection_head.parameters()) + list(self.prediction_head.parameters()) + ) optim = LARS( params, lr=0.2 * lr_factor, @@ -442,10 +435,10 @@ def configure_optimizers(self): "scheduler": LambdaLR( optimizer=optim, lr_lambda=linear_warmup_decay( - warmup_steps=steps_per_epoch * 10, - total_steps=steps_per_epoch * max_epochs, + warmup_steps=steps_per_epoch * 10, + total_steps=steps_per_epoch * max_epochs, cosine=True, - ) + ), ), "interval": "step", "frequency": 1, @@ -459,9 +452,7 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - 
*list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.prediction_head = heads.NNCLRPredictionHead(256, 4096, 256) self.projection_head = heads.NNCLRProjectionHead(feature_dim, 4096, 256) @@ -486,9 +477,9 @@ def training_step(self, batch, batch_idx): def configure_optimizers(self): optim = LARS( - self.parameters(), + self.parameters(), lr=0.3 * lr_factor, - momentum=0.9, + momentum=0.9, weight_decay=1e-6, ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) @@ -501,12 +492,10 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = heads.SwaVProjectionHead(feature_dim, 2048, 128) - self.prototypes = heads.SwaVPrototypes(128, 3000) # use 3000 prototypes + self.prototypes = heads.SwaVPrototypes(128, 3000) # use 3000 prototypes self.criterion = SwaVLoss(sinkhorn_gather_distributed=gather_distributed) @@ -517,7 +506,6 @@ def forward(self, x): return self.prototypes(x) def training_step(self, batch, batch_idx): - # normalize the prototypes so they are on the unit sphere self.prototypes.normalize() @@ -532,12 +520,9 @@ def training_step(self, batch, batch_idx): low_resolution_features = multi_crop_features[2:] # calculate the SwaV loss - loss = self.criterion( - high_resolution_features, - low_resolution_features - ) + loss = self.criterion(high_resolution_features, low_resolution_features) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): @@ -551,10 +536,10 @@ def configure_optimizers(self): "scheduler": LambdaLR( optimizer=optim, lr_lambda=linear_warmup_decay( - warmup_steps=steps_per_epoch * 10, - total_steps=steps_per_epoch * max_epochs, + warmup_steps=steps_per_epoch * 10, + total_steps=steps_per_epoch * max_epochs, cosine=True, - ) + ), ), "interval": "step", "frequency": 1, @@ -568,12 +553,14 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) + self.head = heads.DINOProjectionHead( + feature_dim, 2048, 256, 2048, batch_norm=True ) - self.head = heads.DINOProjectionHead(feature_dim, 2048, 256, 2048, batch_norm=True) self.teacher_backbone = copy.deepcopy(self.backbone) - self.teacher_head = heads.DINOProjectionHead(feature_dim, 2048, 256, 2048, batch_norm=True) + self.teacher_head = heads.DINOProjectionHead( + feature_dim, 2048, 256, 2048, batch_norm=True + ) utils.deactivate_requires_grad(self.teacher_backbone) utils.deactivate_requires_grad(self.teacher_head) @@ -599,12 +586,11 @@ def training_step(self, batch, batch_idx): teacher_out = [self.forward_teacher(view) for view in global_views] student_out = [self.forward(view) for view in views] loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch) - self.log('train_loss_ssl', loss) + self.log("train_loss_ssl", loss) return loss def configure_optimizers(self): - param = list(self.backbone.parameters()) \ - + list(self.head.parameters()) + param = list(self.backbone.parameters()) + list(self.head.parameters()) 
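        # only the student backbone and head receive gradients here; the
        # teacher copies have requires_grad deactivated and are refreshed
        # solely through the momentum update during training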
optim = LARS( param, lr=0.3 * lr_factor, @@ -615,10 +601,10 @@ def configure_optimizers(self): "scheduler": LambdaLR( optimizer=optim, lr_lambda=linear_warmup_decay( - warmup_steps=steps_per_epoch * 10, - total_steps=steps_per_epoch * max_epochs, + warmup_steps=steps_per_epoch * 10, + total_steps=steps_per_epoch * max_epochs, cosine=True, - ) + ), ), "interval": "step", "frequency": 1, @@ -627,7 +613,7 @@ def configure_optimizers(self): models = [ - BarlowTwinsModel, + BarlowTwinsModel, BYOLModel, DINOModel, MocoModel, @@ -642,21 +628,21 @@ def configure_optimizers(self): # loop through configurations and train models for BenchmarkModel in models: runs = [] - model_name = BenchmarkModel.__name__.replace('Model', '') + model_name = BenchmarkModel.__name__.replace("Model", "") for seed in range(n_runs): pl.seed_everything(seed) dataloader_train_ssl, dataloader_train_kNN, dataloader_test = get_data_loaders( - batch_size=batch_size, + batch_size=batch_size, model=BenchmarkModel, ) benchmark_model = BenchmarkModel(dataloader_train_kNN, classes) # Save logs to: {CWD}/benchmark_logs/imagenet/{experiment_version}/{model_name}/ # If multiple runs are specified a subdirectory for each run is created. - sub_dir = model_name if n_runs <= 1 else f'{model_name}/run{seed}' + sub_dir = model_name if n_runs <= 1 else f"{model_name}/run{seed}" logger = TensorBoardLogger( - save_dir=os.path.join(logs_root_dir, 'imagenet'), - name='', + save_dir=os.path.join(logs_root_dir, "imagenet"), + name="", sub_dir=sub_dir, version=experiment_version, ) @@ -664,32 +650,32 @@ def configure_optimizers(self): # Save results of all models under same version directory experiment_version = logger.version checkpoint_callback = pl.callbacks.ModelCheckpoint( - dirpath=os.path.join(logger.log_dir, 'checkpoints') + dirpath=os.path.join(logger.log_dir, "checkpoints") ) trainer = pl.Trainer( - max_epochs=max_epochs, + max_epochs=max_epochs, gpus=gpus, default_root_dir=logs_root_dir, strategy=distributed_backend, sync_batchnorm=sync_batchnorm, logger=logger, - callbacks=[checkpoint_callback] + callbacks=[checkpoint_callback], ) start = time.time() trainer.fit( benchmark_model, train_dataloaders=dataloader_train_ssl, - val_dataloaders=dataloader_test + val_dataloaders=dataloader_test, ) end = time.time() run = { - 'model': model_name, - 'batch_size': batch_size, - 'epochs': max_epochs, - 'max_accuracy': benchmark_model.max_accuracy, - 'runtime': end - start, - 'gpu_memory_usage': torch.cuda.max_memory_allocated(), - 'seed': seed, + "model": model_name, + "batch_size": batch_size, + "epochs": max_epochs, + "max_accuracy": benchmark_model.max_accuracy, + "runtime": end - start, + "gpu_memory_usage": torch.cuda.max_memory_allocated(), + "seed": seed, } runs.append(run) print(run) @@ -707,15 +693,15 @@ def configure_optimizers(self): f"| {'Model':<13} | {'Batch Size':>10} | {'Epochs':>6} " f"| {'KNN Test Accuracy':>18} | {'Time':>10} | {'Peak GPU Usage':>14} |" ) -print('-' * len(header)) +print("-" * len(header)) print(header) -print('-' * len(header)) +print("-" * len(header)) for model, results in bench_results.items(): - runtime = np.array([result['runtime'] for result in results]) - runtime = runtime.mean() / 60 # convert to min - accuracy = np.array([result['max_accuracy'] for result in results]) - gpu_memory_usage = np.array([result['gpu_memory_usage'] for result in results]) - gpu_memory_usage = gpu_memory_usage.max() / (1024**3) # convert to gbyte + runtime = np.array([result["runtime"] for result in results]) + runtime = 
runtime.mean() / 60 # convert to min + accuracy = np.array([result["max_accuracy"] for result in results]) + gpu_memory_usage = np.array([result["gpu_memory_usage"] for result in results]) + gpu_memory_usage = gpu_memory_usage.max() / (1024**3) # convert to gbyte if len(accuracy) > 1: accuracy_msg = f"{accuracy.mean():>8.3f} +- {accuracy.std():>4.3f}" @@ -726,6 +712,6 @@ def configure_optimizers(self): f"| {model:<13} | {batch_size:>10} | {max_epochs:>6} " f"| {accuracy_msg} | {runtime:>6.1f} Min " f"| {gpu_memory_usage:>8.1f} GByte |", - flush=True + flush=True, ) -print('-' * len(header)) \ No newline at end of file +print("-" * len(header)) diff --git a/docs/source/getting_started/benchmarks/imagenette_benchmark.py b/docs/source/getting_started/benchmarks/imagenette_benchmark.py index d3e354f09..abd3604ac 100644 --- a/docs/source/getting_started/benchmarks/imagenette_benchmark.py +++ b/docs/source/getting_started/benchmarks/imagenette_benchmark.py @@ -61,43 +61,44 @@ """ import copy import os - import time + import numpy as np import pytorch_lightning as pl import torch import torch.nn as nn import torchvision +from pl_bolts.optimizers.lars import LARS +from pytorch_lightning.loggers import TensorBoardLogger + from lightly.data import ( + DINOCollateFunction, LightlyDataset, + MAECollateFunction, + MSNCollateFunction, SimCLRCollateFunction, SwaVCollateFunction, - DINOCollateFunction, - MSNCollateFunction, - MAECollateFunction, VICRegLCollateFunction, collate, ) from lightly.loss import ( - NTXentLoss, - NegativeCosineSimilarity, - DINOLoss, BarlowTwinsLoss, - SwaVLoss, - MSNLoss, DCLLoss, DCLWLoss, + DINOLoss, + MSNLoss, + NegativeCosineSimilarity, + NTXentLoss, + SwaVLoss, + TiCoLoss, VICRegLLoss, VICRegLoss, - TiCoLoss, memory_bank, ) from lightly.models import modules, utils from lightly.models.modules import heads, masked_autoencoder -from lightly.utils.benchmarking import BenchmarkModule from lightly.utils import scheduler -from pytorch_lightning.loggers import TensorBoardLogger -from pl_bolts.optimizers.lars import LARS +from lightly.utils.benchmarking import BenchmarkModule logs_root_dir = os.path.join(os.getcwd(), "benchmark_logs") @@ -202,13 +203,9 @@ dataset_train_ssl = LightlyDataset(input_dir=path_to_train) # we use test transformations for getting the feature for kNN on train data -dataset_train_kNN = LightlyDataset( - input_dir=path_to_train, transform=test_transforms -) +dataset_train_kNN = LightlyDataset(input_dir=path_to_train, transform=test_transforms) -dataset_test = LightlyDataset( - input_dir=path_to_test, transform=test_transforms -) +dataset_test = LightlyDataset(input_dir=path_to_test, transform=test_transforms) def get_data_loaders(batch_size: int, model): @@ -269,9 +266,7 @@ def __init__(self, dataloader_kNN, num_classes): # TODO: Add split batch norm to the resnet model resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) # create a moco model based on ResNet self.projection_head = heads.MoCoProjectionHead(feature_dim, 2048, 128) @@ -281,9 +276,7 @@ def __init__(self, dataloader_kNN, num_classes): utils.deactivate_requires_grad(self.projection_head_momentum) # create our loss with the optional memory bank - self.criterion = NTXentLoss( - temperature=0.1, memory_bank_size=memory_bank_size - ) + self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size) def forward(self, 
x): x = self.backbone(x).flatten(start_dim=1) @@ -335,9 +328,7 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = heads.SimCLRProjectionHead(feature_dim, feature_dim, 128) self.criterion = NTXentLoss() @@ -368,9 +359,7 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = heads.SimSiamProjectionHead(feature_dim, 2048, 2048) self.prediction_head = heads.SimSiamPredictionHead(2048, 512, 2048) self.criterion = NegativeCosineSimilarity() @@ -407,15 +396,11 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) # use a 2-layer projection head for cifar10 as described in the paper self.projection_head = heads.BarlowTwinsProjectionHead(feature_dim, 2048, 2048) - self.criterion = BarlowTwinsLoss( - gather_distributed=gather_distributed - ) + self.criterion = BarlowTwinsLoss(gather_distributed=gather_distributed) def forward(self, x): x = self.backbone(x).flatten(start_dim=1) @@ -444,9 +429,7 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) # create a byol model based on ResNet self.projection_head = heads.BYOLProjectionHead(feature_dim, 4096, 256) @@ -508,9 +491,7 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = heads.NNCLRProjectionHead(feature_dim, 2048, 256) self.prediction_head = heads.NNCLRPredictionHead(256, 4096, 256) @@ -550,16 +531,12 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = heads.SwaVProjectionHead(feature_dim, 2048, 128) self.prototypes = heads.SwaVPrototypes(128, 3000) # use 3000 prototypes - self.criterion = SwaVLoss( - sinkhorn_gather_distributed=gather_distributed - ) + self.criterion = SwaVLoss(sinkhorn_gather_distributed=gather_distributed) def forward(self, x): x = self.backbone(x).flatten(start_dim=1) @@ -568,7 +545,6 @@ def forward(self, x): return self.prototypes(x) def 
training_step(self, batch, batch_idx): - # normalize the prototypes so they are on the unit sphere self.prototypes.normalize() @@ -604,9 +580,7 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.head = heads.DINOProjectionHead( feature_dim, 2048, 256, 2048, batch_norm=True ) @@ -660,9 +634,7 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = heads.SimCLRProjectionHead(feature_dim, feature_dim, 128) self.criterion = DCLLoss() @@ -693,9 +665,7 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() feature_dim = list(resnet.children())[-1].in_features - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = heads.SimCLRProjectionHead(feature_dim, feature_dim, 128) self.criterion = DCLWLoss() @@ -884,9 +854,7 @@ def __init__(self, dataloader_kNN, num_classes): # create a ResNet backbone and remove the classification head resnet = torchvision.models.resnet18() - self.backbone = nn.Sequential( - *list(resnet.children())[:-1] - ) + self.backbone = nn.Sequential(*list(resnet.children())[:-1]) # create a model based on ResNet self.projection_head = heads.SMoGProjectionHead(512, 2048, 128) @@ -899,9 +867,7 @@ def __init__(self, dataloader_kNN, num_classes): # smog self.n_groups = 300 memory_bank_size = 10000 - self.memory_bank = memory_bank.MemoryBankModule( - size=memory_bank_size - ) + self.memory_bank = memory_bank.MemoryBankModule(size=memory_bank_size) # create our loss group_features = torch.nn.functional.normalize( torch.rand(self.n_groups, 128), dim=1 @@ -930,7 +896,6 @@ def _reset_momentum_weights(self): utils.deactivate_requires_grad(self.projection_head_momentum) def training_step(self, batch, batch_idx): - if self.global_step > 0 and self.global_step % 300 == 0: # reset group features and weights every 300 iterations self._reset_group_features() @@ -1244,7 +1209,6 @@ def _subforward(self, input): @torch.no_grad() def _get_queue_prototypes(self, high_resolution_features): - if len(high_resolution_features) != len(self.queues): raise ValueError( f"The number of queues ({len(self.queues)}) should be equal to the number of high " diff --git a/docs/source/getting_started/code_examples/plot_image_augmentations.py b/docs/source/getting_started/code_examples/plot_image_augmentations.py index d67ce957a..e672dab0b 100644 --- a/docs/source/getting_started/code_examples/plot_image_augmentations.py +++ b/docs/source/getting_started/code_examples/plot_image_augmentations.py @@ -1,9 +1,11 @@ import glob + from PIL import Image + import lightly # let's get all jpg filenames from a folder -glob_to_data = '/datasets/clothing-dataset/images/*.jpg' +glob_to_data = "/datasets/clothing-dataset/images/*.jpg" fnames = glob.glob(glob_to_data) # load the first two images using pillow @@ -21,4 +23,4 @@ # we can also use the DINO 
collate function instead collate_fn_dino = lightly.data.DINOCollateFunction() -fig = lightly.utils.debug.plot_augmented_images(input_images, collate_fn_dino) \ No newline at end of file +fig = lightly.utils.debug.plot_augmented_images(input_images, collate_fn_dino) diff --git a/docs/source/tutorials_source/package/tutorial_custom_augmentations.py b/docs/source/tutorials_source/package/tutorial_custom_augmentations.py index e90c895a3..e3e97ada2 100644 --- a/docs/source/tutorials_source/package/tutorial_custom_augmentations.py +++ b/docs/source/tutorials_source/package/tutorial_custom_augmentations.py @@ -28,37 +28,41 @@ """ +import copy + # %% # Imports # ------- # # Import the Python frameworks we need for this tutorial. import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas +import pytorch_lightning as pl import torch import torch.nn as nn import torchvision -import pytorch_lightning as pl -import matplotlib.pyplot as plt +from PIL import Image from sklearn.neighbors import NearestNeighbors from sklearn.preprocessing import normalize -from PIL import Image -import numpy as np -import pandas -import copy -from lightly.data import LightlyDataset, BaseCollateFunction +from lightly.data import BaseCollateFunction, LightlyDataset from lightly.loss import NTXentLoss from lightly.models.modules.heads import MoCoProjectionHead -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum -from lightly.models.utils import batch_shuffle -from lightly.models.utils import batch_unshuffle +from lightly.models.utils import ( + batch_shuffle, + batch_unshuffle, + deactivate_requires_grad, + update_momentum, +) # %% # Configuration # ------------- # Let's set the configuration parameters for our experiments. -# +# # We will use eight workers to fetch the data from disc and a batch size of 128. # The input size of the images is set to 128. With these settings, the training # requires 2.5GB of GPU memory. @@ -77,7 +81,7 @@ # %% # Set the path to our dataset -path_to_data = '/datasets/vinbigdata/train_small' +path_to_data = "/datasets/vinbigdata/train_small" # %% # Setup custom data augmentations @@ -91,6 +95,7 @@ # Let's write an augmentation, which takes as input a numpy array with 16-bit input # depth and returns a histogram normalized 8-bit PIL image. + class HistogramNormalize: """Performs histogram normalization on numpy array and returns 8-bit image. @@ -103,7 +108,6 @@ def __init__(self, number_bins: int = 256): self.number_bins = number_bins def __call__(self, image: np.array) -> Image: - # get image histogram image_histogram, bins = np.histogram( image.flatten(), self.number_bins, density=True @@ -115,11 +119,13 @@ def __call__(self, image: np.array) -> Image: image_equalized = np.interp(image.flatten(), bins[:-1], cdf) return Image.fromarray(image_equalized.reshape(image.shape)) + # %% # Since we can't use color jitter on X-ray images, let's replace it and add some # Gaussian noise instead. It's easiest to apply this after the image has been # converted to a PyTorch tensor. + class GaussianNoise: """Applies random Gaussian noise to a tensor. 
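As a quick aside between the hunks above: the histogram normalization defined earlier in this file can be exercised in isolation. The sketch below is editorial, not part of the diff; it assumes the `HistogramNormalize` class from this tutorial is in scope, and the random 16-bit array is a hypothetical stand-in for a real X-ray.

```python
# Minimal sketch, assuming the HistogramNormalize class defined above is in
# scope; the random 16-bit array is a made-up stand-in for an X-ray image.
import numpy as np

normalize = HistogramNormalize(number_bins=256)
fake_xray = np.random.randint(0, 2**16, size=(128, 128)).astype(np.uint16)
equalized = normalize(fake_xray)  # image with an equalized histogram, as returned by __call__
```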
@@ -135,6 +141,7 @@ def __call__(self, sample: torch.Tensor) -> torch.Tensor: noise = torch.normal(torch.zeros(sample.shape), sigma) return sample + noise + # %% # Now that we have implemented our custom augmentations, we can combine them # with available augmentations from the torchvision library to get to the same @@ -148,24 +155,26 @@ def __call__(self, sample: torch.Tensor) -> torch.Tensor: # is used. # compose the custom augmentations with available augmentations -transform = torchvision.transforms.Compose([ - HistogramNormalize(), - torchvision.transforms.Grayscale(num_output_channels=3), - torchvision.transforms.RandomResizedCrop(size=input_size, scale=(0.2, 1.0)), - torchvision.transforms.RandomHorizontalFlip(p=0.5), - torchvision.transforms.RandomVerticalFlip(p=0.5), - torchvision.transforms.GaussianBlur(21), - torchvision.transforms.ToTensor(), - GaussianNoise(), -]) +transform = torchvision.transforms.Compose( + [ + HistogramNormalize(), + torchvision.transforms.Grayscale(num_output_channels=3), + torchvision.transforms.RandomResizedCrop(size=input_size, scale=(0.2, 1.0)), + torchvision.transforms.RandomHorizontalFlip(p=0.5), + torchvision.transforms.RandomVerticalFlip(p=0.5), + torchvision.transforms.GaussianBlur(21), + torchvision.transforms.ToTensor(), + GaussianNoise(), + ] +) # %% # Let's take a look at what our augmentation pipeline does to an image! -# We plot the original image on the left and two random augmentations on the +# We plot the original image on the left and two random augmentations on the # right. -example_image_name = '55e8e3db7309febee415515d06418171.tiff' +example_image_name = "55e8e3db7309febee415515d06418171.tiff" example_image_path = os.path.join(path_to_data, example_image_name) example_image = np.array(Image.open(example_image_path)) @@ -177,7 +186,7 @@ def __call__(self, sample: torch.Tensor) -> torch.Tensor: axs[0].imshow(example_image) axs[0].set_axis_off() -axs[0].set_title('Original Image') +axs[0].set_title("Original Image") axs[1].imshow(augmented_image_1) axs[1].set_axis_off() @@ -197,7 +206,7 @@ def __call__(self, sample: torch.Tensor) -> torch.Tensor: # ------------------------------ # # We create a dataset which points to the images in the input directory. Since -# the input images are 16 bits deep, we need to overwrite the image loader such +# the input images are 16 bits deep, we need to overwrite the image loader such # that it doesn't convert the images to RGB (and hence to 8-bit) automatically. # # .. note:: The `LightlyDataset` uses a torchvision dataset underneath, which in turn uses @@ -205,14 +214,14 @@ def __call__(self, sample: torch.Tensor) -> torch.Tensor: # grayscale image is loaded that way, all pixel values above 255 are simply clamped. # Therefore, we overwrite the default image loader with our custom one. -def tiff_loader(f): - """Loads a 16-bit tiff image and returns it as a numpy array. - """ - with open(f, 'rb') as f: +def tiff_loader(f): + """Loads a 16-bit tiff image and returns it as a numpy array.""" + with open(f, "rb") as f: image = Image.open(f) return np.array(image) + # create the dataset and overwrite the image loader dataset_train = LightlyDataset(input_dir=path_to_data) dataset_train.dataset.loader = tiff_loader @@ -225,7 +234,7 @@ def tiff_loader(f): shuffle=True, collate_fn=collate_fn, drop_last=True, - num_workers=num_workers + num_workers=num_workers, ) # %% @@ -242,6 +251,7 @@ def tiff_loader(f): # The choice of the optimizer is left to the user. 
Here, we go with simple stochastic # gradient descent with momentum. + class MoCoModel(pl.LightningModule): def __init__(self): super().__init__() @@ -264,18 +274,14 @@ def __init__(self): deactivate_requires_grad(self.projection_head_momentum) # create our loss with the memory bank - self.criterion = NTXentLoss( - temperature=0.1, memory_bank_size=4096 - ) + self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=4096) def training_step(self, batch, batch_idx): (x_q, x_k), _, _ = batch # update momentum update_momentum(self.backbone, self.backbone_momentum, 0.99) - update_momentum( - self.projection_head, self.projection_head_momentum, 0.99 - ) + update_momentum(self.projection_head, self.projection_head_momentum, 0.99) # get queries q = self.backbone(x_q).flatten(start_dim=1) @@ -299,9 +305,7 @@ def configure_optimizers(self): momentum=0.9, weight_decay=1e-4, ) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optim, max_epochs - ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] @@ -337,29 +341,25 @@ def configure_optimizers(self): # test transforms differ from training transforms as they do not introduce # additional noise -test_transforms = torchvision.transforms.Compose([ - HistogramNormalize(), - torchvision.transforms.Grayscale(num_output_channels=3), - torchvision.transforms.Resize(input_size), - torchvision.transforms.ToTensor(), -]) +test_transforms = torchvision.transforms.Compose( + [ + HistogramNormalize(), + torchvision.transforms.Grayscale(num_output_channels=3), + torchvision.transforms.Resize(input_size), + torchvision.transforms.ToTensor(), + ] +) # create the dataset and overwrite the image loader as before -dataset_test = LightlyDataset( - input_dir=path_to_data, - transform=test_transforms -) +dataset_test = LightlyDataset(input_dir=path_to_data, transform=test_transforms) dataset_test.dataset.loader = tiff_loader # create the test dataloader dataloader_test = torch.utils.data.DataLoader( - dataset_test, - batch_size=1, - shuffle=False, - drop_last=False, - num_workers=num_workers + dataset_test, batch_size=1, shuffle=False, drop_last=False, num_workers=num_workers ) + # next we add a small helper function to generate embeddings of our images def generate_embeddings(model, dataloader): """Generates representations for all images in the dataloader""" @@ -391,22 +391,22 @@ def generate_embeddings(model, dataloader): # of the critical findings in the nearest neighbor images (light blue) as bar plots. 
# transform the original bounding box annotations to multiclass labels -fnames = [fname.split('.')[0] for fname in fnames] +fnames = [fname.split(".")[0] for fname in fnames] -df = pandas.read_csv('/datasets/vinbigdata/train.csv') +df = pandas.read_csv("/datasets/vinbigdata/train.csv") classes = list(np.unique(df.class_name)) filenames = list(np.unique(df.image_id)) # iterate over all bounding boxes and add a one-hot label if an image contains -# a bounding box of a given class, after that, the array "multilabels" will -# contain a row for every image in the input dataset and each row of the +# a bounding box of a given class, after that, the array "multilabels" will +# contain a row for every image in the input dataset and each row of the # array contains a one-hot vector of critical findings for this image multilabels = np.zeros((len(dataset_test.get_filenames()), len(classes))) for filename, label in zip(df.image_id, df.class_name): try: - i = fnames.index(filename.split('.')[0]) + i = fnames.index(filename.split(".")[0]) j = classes.index(label) - multilabels[i, j] = 1. + multilabels[i, j] = 1.0 except Exception: pass @@ -428,20 +428,18 @@ def plot_knn_multilabels( # loop through our randomly picked samples for idx in samples_idx: fig = plt.figure() - + bars1 = multilabels[idx] bars2 = np.mean(multilabels[indices[idx]], axis=0) plt.title(filenames[idx]) - plt.bar(r1, bars1, color='steelblue', edgecolor='black', width=bar_width) - plt.bar(r2, bars2, color='lightsteelblue', edgecolor='black', width=bar_width) + plt.bar(r1, bars1, color="steelblue", edgecolor="black", width=bar_width) + plt.bar(r2, bars2, color="lightsteelblue", edgecolor="black", width=bar_width) plt.xticks(0.5 * (r1 + r2), classes, rotation=90) plt.tight_layout() -# plot the distribution of the multilabels of the k nearest neighbors of +# plot the distribution of the multilabels of the k nearest neighbors of # the three example images at index 4111, 3340, 1796 k = 20 -plot_knn_multilabels( - embeddings, multilabels, [4111, 3340, 1796], fnames, n_neighbors=k -) +plot_knn_multilabels(embeddings, multilabels, [4111, 3340, 1796], fnames, n_neighbors=k) diff --git a/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py b/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py index 1960d6fbb..19b2b993b 100644 --- a/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py +++ b/docs/source/tutorials_source/package/tutorial_moco_memory_bank.py @@ -42,25 +42,28 @@ # # pip install lightly +import copy + +import pytorch_lightning as pl import torch import torch.nn as nn import torchvision -import pytorch_lightning as pl -import copy from lightly.data import LightlyDataset, SimCLRCollateFunction, collate from lightly.loss import NTXentLoss from lightly.models import ResNetGenerator from lightly.models.modules.heads import MoCoProjectionHead -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum -from lightly.models.utils import batch_shuffle -from lightly.models.utils import batch_unshuffle +from lightly.models.utils import ( + batch_shuffle, + batch_unshuffle, + deactivate_requires_grad, + update_momentum, +) # %% # Configuration # ------------- -# +# # We set some configuration parameters for our experiment. # Feel free to change them and analyze the effect. # @@ -80,7 +83,7 @@ # We assume we have a train folder with subfolders # for each class and .png images inside. 
# -# You can download `CIFAR-10 in folders from Kaggle +# You can download `CIFAR-10 in folders from Kaggle # `_. # The dataset structure should be like this: @@ -97,8 +100,8 @@ # L horse/ # L ship/ # L truck/ -path_to_train = '/datasets/cifar10/train/' -path_to_test = '/datasets/cifar10/test/' +path_to_train = "/datasets/cifar10/train/" +path_to_test = "/datasets/cifar10/test/" # %% # Let's set the seed to ensure reproducibility of the experiments @@ -113,9 +116,9 @@ # from the MOCO paper using the collate functions provided by lightly. For MoCo v2, # we can use the same augmentations as SimCLR but override the input size and blur. # Images from the CIFAR-10 dataset have a resolution of 32x32 pixels. Let's use -# this resolution to train our model. +# this resolution to train our model. # -# .. note:: We could use a higher input resolution to train our model. However, +# .. note:: We could use a higher input resolution to train our model. However, # since the original resolution of CIFAR-10 images is low there is no real value # in increasing the resolution. A higher resolution results in higher memory # consumption and to compensate for that we would need to reduce the batch size. @@ -123,7 +126,7 @@ # MoCo v2 uses SimCLR augmentations, additionally, disable blur collate_fn = SimCLRCollateFunction( input_size=32, - gaussian_blur=0., + gaussian_blur=0.0, ) # %% @@ -133,48 +136,46 @@ # the same way as we do with the training data. # Augmentations typically used to train on cifar-10 -train_classifier_transforms = torchvision.transforms.Compose([ - torchvision.transforms.RandomCrop(32, padding=4), - torchvision.transforms.RandomHorizontalFlip(), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize( - mean=collate.imagenet_normalize['mean'], - std=collate.imagenet_normalize['std'], - ) -]) +train_classifier_transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.RandomCrop(32, padding=4), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + mean=collate.imagenet_normalize["mean"], + std=collate.imagenet_normalize["std"], + ), + ] +) # No additional augmentations for the test set -test_transforms = torchvision.transforms.Compose([ - torchvision.transforms.Resize((32, 32)), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize( - mean=collate.imagenet_normalize['mean'], - std=collate.imagenet_normalize['std'], - ) -]) +test_transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize((32, 32)), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + mean=collate.imagenet_normalize["mean"], + std=collate.imagenet_normalize["std"], + ), + ] +) # We use the moco augmentations for training moco -dataset_train_moco = LightlyDataset( - input_dir=path_to_train -) +dataset_train_moco = LightlyDataset(input_dir=path_to_train) # Since we also train a linear classifier on the pre-trained moco model we -# reuse the test augmentations here (MoCo augmentations are very strong and +# reuse the test augmentations here (MoCo augmentations are very strong and # usually reduce accuracy of models which are not used for contrastive learning. # Our linear layer will be trained using cross entropy loss and labels provided # by the dataset. Therefore we chose light augmentations.) 
dataset_train_classifier = LightlyDataset( - input_dir=path_to_train, - transform=train_classifier_transforms + input_dir=path_to_train, transform=train_classifier_transforms ) -dataset_test = LightlyDataset( - input_dir=path_to_test, - transform=test_transforms -) +dataset_test = LightlyDataset(input_dir=path_to_test, transform=test_transforms) # %% -# Create the dataloaders to load and preprocess the data +# Create the dataloaders to load and preprocess the data # in the background. dataloader_train_moco = torch.utils.data.DataLoader( @@ -183,7 +184,7 @@ shuffle=True, collate_fn=collate_fn, drop_last=True, - num_workers=num_workers + num_workers=num_workers, ) dataloader_train_classifier = torch.utils.data.DataLoader( @@ -191,7 +192,7 @@ batch_size=batch_size, shuffle=True, drop_last=True, - num_workers=num_workers + num_workers=num_workers, ) dataloader_test = torch.utils.data.DataLoader( @@ -199,9 +200,10 @@ batch_size=batch_size, shuffle=False, drop_last=False, - num_workers=num_workers + num_workers=num_workers, ) + # %% # Create the MoCo Lightning Module # -------------------------------- @@ -220,9 +222,9 @@ class MocoModel(pl.LightningModule): def __init__(self): super().__init__() - + # create a ResNet backbone and remove the classification head - resnet = ResNetGenerator('resnet-18', 1, num_splits=8) + resnet = ResNetGenerator("resnet-18", 1, num_splits=8) self.backbone = nn.Sequential( *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1), @@ -236,18 +238,14 @@ def __init__(self): deactivate_requires_grad(self.projection_head_momentum) # create our loss with the optional memory bank - self.criterion = NTXentLoss( - temperature=0.1, - memory_bank_size=memory_bank_size) + self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size) def training_step(self, batch, batch_idx): (x_q, x_k), _, _ = batch # update momentum update_momentum(self.backbone, self.backbone_momentum, 0.99) - update_momentum( - self.projection_head, self.projection_head_momentum, 0.99 - ) + update_momentum(self.projection_head, self.projection_head_momentum, 0.99) # get queries q = self.backbone(x_q).flatten(start_dim=1) @@ -270,8 +268,7 @@ def training_epoch_end(self, outputs): # which is useful for debugging. def custom_histogram_weights(self): for name, params in self.named_parameters(): - self.logger.experiment.add_histogram( - name, params, self.current_epoch) + self.logger.experiment.add_histogram(name, params, self.current_epoch) def configure_optimizers(self): optim = torch.optim.SGD( @@ -280,9 +277,7 @@ def configure_optimizers(self): momentum=0.9, weight_decay=5e-4, ) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optim, max_epochs - ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] @@ -292,6 +287,7 @@ def configure_optimizers(self): # We create a linear classifier using the features we extract using MoCo # and train it on the dataset + class Classifier(pl.LightningModule): def __init__(self, backbone): super().__init__() @@ -325,9 +321,7 @@ def training_epoch_end(self, outputs): # which is useful for debugging. 
def custom_histogram_weights(self): for name, params in self.named_parameters(): - self.logger.experiment.add_histogram( - name, params, self.current_epoch - ) + self.logger.experiment.add_histogram(name, params, self.current_epoch) def validation_step(self, batch, batch_idx): x, y, _ = batch @@ -352,7 +346,7 @@ def validation_epoch_end(self, outputs): self.log("val_acc", acc, on_epoch=True, prog_bar=True) def configure_optimizers(self): - optim = torch.optim.SGD(self.fc.parameters(), lr=30.) + optim = torch.optim.SGD(self.fc.parameters(), lr=30.0) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] @@ -368,36 +362,27 @@ def configure_optimizers(self): gpus = 1 if torch.cuda.is_available() else 0 model = MocoModel() -trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, - progress_bar_refresh_rate=100) -trainer.fit( - model, - dataloader_train_moco -) +trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, progress_bar_refresh_rate=100) +trainer.fit(model, dataloader_train_moco) # %% # Train the Classifier model.eval() classifier = Classifier(model.backbone) -trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, - progress_bar_refresh_rate=100) -trainer.fit( - classifier, - dataloader_train_classifier, - dataloader_test -) +trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, progress_bar_refresh_rate=100) +trainer.fit(classifier, dataloader_train_classifier, dataloader_test) # %% # Check out the tensorboard logs while the model is training. # # Run `tensorboard --logdir lightning_logs/` to start tensorboard -# +# # .. note:: If you run the code on a remote machine you can't just # access the tensorboard logs. You need to forward the port. # You can do this by using an editor such as Visual Studio Code # which has a port forwarding functionality (make sure # the remote extensions are installed and connected to your machine). -# +# # Or you can use a shell command similar to this one to forward port # 6006 from your remote machine to your local machine: # diff --git a/docs/source/tutorials_source/package/tutorial_pretrain_detectron2.py b/docs/source/tutorials_source/package/tutorial_pretrain_detectron2.py index b69ec382e..776b7f64b 100644 --- a/docs/source/tutorials_source/package/tutorial_pretrain_detectron2.py +++ b/docs/source/tutorials_source/package/tutorial_pretrain_detectron2.py @@ -67,7 +67,6 @@ from lightly.loss import NTXentLoss from lightly.models.modules import SimCLRProjectionHead - # %% # Configuration # ------------- @@ -89,7 +88,7 @@ max_epochs = 5 # use cuda if possible -device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = "cuda" if torch.cuda.is_available() else "cpu" # %% # Set the path to the dataset accordingly. Additionally, make sure to set the # path to the config file of the Detectron2 model you want to use. # We will be using an RCNN with a feature pyramid network (FPN). -data_path = '/datasets/freiburg_groceries_dataset/images' -cfg_path = './Base-RCNN-FPN.yaml' +data_path = "/datasets/freiburg_groceries_dataset/images" +cfg_path = "./Base-RCNN-FPN.yaml" + # %% # Initialize the Detectron2 Model # -------------------------------- # -# # The output of the Detectron2 ResNet50 backbone is a dictionary with the keys # `res1` through `res5` (see the `documentation `_). # The keys correspond to the different stages of the ResNet.
In this tutorial, we are only @@ -112,13 +112,14 @@ class SelectStage(torch.nn.Module): """Selects features from a given stage.""" - def __init__(self, stage: str = 'res5'): + def __init__(self, stage: str = "res5"): super().__init__() self.stage = stage def forward(self, x): return x[self.stage] + # %% # Let's load the config file and make some adjustments to ensure smooth training. cfg = config.get_cfg() @@ -141,7 +142,7 @@ def forward(self, x): simclr_backbone = torch.nn.Sequential( detmodel.backbone.bottom_up, - SelectStage('res5'), + SelectStage("res5"), # res5 has shape bsz x 2048 x 4 x 4 torch.nn.AdaptiveAvgPool2d(1), ).to(device) @@ -149,7 +150,7 @@ def forward(self, x): # %% # # -#.. note:: +# .. note:: # # The Detectron2 ResNet is missing the average pooling layer used to get a tensor of shape bsz x 2048. # Therefore, we add an average pooling as in the `PyTorch ResNet `_. @@ -170,7 +171,7 @@ def forward(self, x): # # We start by defining the augmentations which should be used for training. # We use the same ones as in the SimCLR paper but change the input size and -# minimum scale of the random crop to adjust to our dataset. +# minimum scale of the random crop to adjust to our dataset. # # We don't go into detail here about using the optimal augmentations. # You can learn more about the different augmentations and learned invariances # here: :ref:`lightly-advanced`. @@ -185,7 +186,7 @@ def forward(self, x): shuffle=True, collate_fn=collate_fn, drop_last=True, - num_workers=num_workers + num_workers=num_workers, ) # %% @@ -201,10 +202,8 @@ def forward(self, x): for e in range(max_epochs): - - mean_loss = 0. + mean_loss = 0.0 for (x0, x1), _, _ in dataloader_train_simclr: - x0 = x0.to(device) x1 = x1.to(device) @@ -221,7 +220,7 @@ def forward(self, x): # update average loss mean_loss += loss.detach().cpu().item() / len(dataloader_train_simclr) - print(f'[Epoch {e:2d}] Mean Loss = {mean_loss:.2f}') + print(f"[Epoch {e:2d}] Mean Loss = {mean_loss:.2f}") # %% @@ -237,8 +236,8 @@ def forward(self, x): # L AdaptiveAvgPool2d detmodel.backbone.bottom_up = simclr_backbone[0] -checkpointer = DetectionCheckpointer(detmodel, save_dir='./') -checkpointer.save('my_model') +checkpointer = DetectionCheckpointer(detmodel, save_dir="./") +checkpointer.save("my_model") # %% @@ -251,7 +250,7 @@ def forward(self, x): # # %% -#.. code-block:: none +# .. code-block:: none # # python train_net.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \ # MODEL.WEIGHTS path/to/my_model.pth \ @@ -261,16 +260,16 @@ def forward(self, x): # # %% -# +# # The :py:class:`lightly.data.collate.SimCLRCollateFunction` applies an ImageNet # normalization of the input images by default. Therefore, we have to normalize # the input images at training time, too. Since Detectron2 uses an input space # in the range 0 - 255, we use the numbers above. -# +# # %% # -#.. note:: +# .. note:: # # Since the model was pre-trained with images in the RGB input format, it's # necessary to permute the order of the pixel mean and pixel std as shown above. diff --git a/docs/source/tutorials_source/package/tutorial_simclr_clothing.py b/docs/source/tutorials_source/package/tutorial_simclr_clothing.py index 842554e69..59c80fe9b 100644 --- a/docs/source/tutorials_source/package/tutorial_simclr_clothing.py +++ b/docs/source/tutorials_source/package/tutorial_simclr_clothing.py @@ -34,22 +34,23 @@ # # Import the Python frameworks we need for this tutorial.
import os + +import matplotlib.pyplot as plt +import numpy as np +import pytorch_lightning as pl import torch import torch.nn as nn import torchvision -import pytorch_lightning as pl -import matplotlib.pyplot as plt +from PIL import Image from sklearn.neighbors import NearestNeighbors from sklearn.preprocessing import normalize -from PIL import Image -import numpy as np from lightly.data import LightlyDataset, SimCLRCollateFunction, collate # %% # Configuration # ------------- -# +# # We set some configuration parameters for our experiment. # Feel free to change them and analyze the effect. # @@ -68,49 +69,42 @@ # %% # Make sure `path_to_data` points to the downloaded clothing dataset. -# You can download it using +# You can download it using # `git clone https://github.com/alexeygrigorev/clothing-dataset.git` -path_to_data = '/datasets/clothing-dataset/images' +path_to_data = "/datasets/clothing-dataset/images" # %% # Setup data augmentations and loaders # ------------------------------------ # -# The images from the dataset have been taken from above when the clothing was +# The images from the dataset have been taken from above when the clothing was # on a table, bed or floor. Therefore, we can make use of additional augmentations -# such as vertical flip or random rotation (90 degrees). -# By adding these augmentations we teach our model invariance with respect to the +# such as vertical flip or random rotation (90 degrees). +# By adding these augmentations we teach our model invariance with respect to the # orientation of the clothing piece. E.g. we don't care if a shirt is upside down # but rather about the structure which makes it a shirt. -# +# # You can learn more about the different augmentations and learned invariances # here: :ref:`lightly-advanced`. -collate_fn = SimCLRCollateFunction( - input_size=input_size, - vf_prob=0.5, - rr_prob=0.5 -) +collate_fn = SimCLRCollateFunction(input_size=input_size, vf_prob=0.5, rr_prob=0.5) -# We create a torchvision transformation for embedding the dataset after +# We create a torchvision transformation for embedding the dataset after # training -test_transforms = torchvision.transforms.Compose([ - torchvision.transforms.Resize((input_size, input_size)), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize( - mean=collate.imagenet_normalize['mean'], - std=collate.imagenet_normalize['std'], - ) -]) - -dataset_train_simclr = LightlyDataset( - input_dir=path_to_data +test_transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize((input_size, input_size)), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + mean=collate.imagenet_normalize["mean"], + std=collate.imagenet_normalize["std"], + ), + ] ) -dataset_test = LightlyDataset( - input_dir=path_to_data, - transform=test_transforms -) +dataset_train_simclr = LightlyDataset(input_dir=path_to_data) + +dataset_test = LightlyDataset(input_dir=path_to_data, transform=test_transforms) dataloader_train_simclr = torch.utils.data.DataLoader( dataset_train_simclr, @@ -118,7 +112,7 @@ shuffle=True, collate_fn=collate_fn, drop_last=True, - num_workers=num_workers + num_workers=num_workers, ) dataloader_test = torch.utils.data.DataLoader( @@ -126,7 +120,7 @@ batch_size=batch_size, shuffle=False, drop_last=False, - num_workers=num_workers + num_workers=num_workers, ) # %% @@ -138,8 +132,8 @@ # and `NTXentLoss` classes. We can simply import them and combine the building # blocks in the module.
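Because the hunks below only touch formatting, most of the `SimCLRModel` body never appears in this diff. For orientation, here is a hedged sketch of how the imported building blocks typically combine into such a module; it mirrors the tutorial's naming, but the projection dimensions and the default loss settings are illustrative assumptions, not the file's actual values.

```python
# Hedged sketch of a SimCLR-style module built from the imported blocks;
# the 512/128 dimensions and default NTXentLoss settings are assumptions.
class MinimalSimCLR(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        # drop the classification head; resnet18 features are 512-dimensional
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SimCLRProjectionHead(512, 512, 128)
        self.criterion = NTXentLoss()

    def forward(self, x):
        h = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(h)

    def training_step(self, batch, batch_idx):
        (x0, x1), _, _ = batch  # the collate function yields two views per image
        loss = self.criterion(self(x0), self(x1))
        return loss
```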
-from lightly.models.modules.heads import SimCLRProjectionHead from lightly.loss import NTXentLoss +from lightly.models.modules.heads import SimCLRProjectionHead class SimCLRModel(pl.LightningModule): @@ -172,9 +166,7 @@ def configure_optimizers(self): optim = torch.optim.SGD( self.parameters(), lr=6e-2, momentum=0.9, weight_decay=5e-4 ) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optim, max_epochs - ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) return [optim], [scheduler] @@ -185,9 +177,7 @@ def configure_optimizers(self): gpus = 1 if torch.cuda.is_available() else 0 model = SimCLRModel() -trainer = pl.Trainer( - max_epochs=max_epochs, gpus=gpus, progress_bar_refresh_rate=100 -) +trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, progress_bar_refresh_rate=100) trainer.fit(model, dataloader_train_simclr) # %% @@ -222,22 +212,21 @@ def generate_embeddings(model, dataloader): # %% # Visualize Nearest Neighbors -#---------------------------- +# ---------------------------- -# Let's look at the trained embedding and visualize the nearest neighbors for +# Let's look at the trained embedding and visualize the nearest neighbors for # a few random samples. # # We create some helper functions to simplify the work + def get_image_as_np_array(filename: str): - """Returns an image as a numpy array - """ + """Returns an image as a numpy array""" img = Image.open(filename) return np.asarray(img) def plot_knn_examples(embeddings, filenames, n_neighbors=3, num_examples=6): - """Plots multiple rows of random images with their nearest neighbors - """ + """Plots multiple rows of random images with their nearest neighbors""" # let's look at the nearest neighbors for some samples # we use the sklearn library nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(embeddings) @@ -258,9 +247,9 @@ def plot_knn_examples(embeddings, filenames, n_neighbors=3, num_examples=6): # plot the image plt.imshow(get_image_as_np_array(fname)) # set the title to the distance of the neighbor - ax.set_title(f'd={distances[idx][plot_x_offset]:.3f}') + ax.set_title(f"d={distances[idx][plot_x_offset]:.3f}") # let's disable the axis - plt.axis('off') + plt.axis("off") # %% @@ -272,16 +261,12 @@ def plot_knn_examples(embeddings, filenames, n_neighbors=3, num_examples=6): # %% # Color Invariance # --------------------- -# Let's train again without color augmentation. This will force our model to +# Let's train again without color augmentation. This will force our model to # respect the colors in the images.
# Set color jitter and gray scale probability to 0 new_collate_fn = SimCLRCollateFunction( - input_size=input_size, - vf_prob=0.5, - rr_prob=0.5, - cj_prob=0.0, - random_gray_scale=0.0 + input_size=input_size, vf_prob=0.5, rr_prob=0.5, cj_prob=0.0, random_gray_scale=0.0 ) # let's update our collate method and reuse our dataloader @@ -289,9 +274,7 @@ def plot_knn_examples(embeddings, filenames, n_neighbors=3, num_examples=6): # then train a new model model = SimCLRModel() -trainer = pl.Trainer( - max_epochs=max_epochs, gpus=gpus, progress_bar_refresh_rate=100 -) +trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus, progress_bar_refresh_rate=100) trainer.fit(model, dataloader_train_simclr) # and generate again embeddings from the test set @@ -309,10 +292,8 @@ def plot_knn_examples(embeddings, filenames, n_neighbors=3, num_examples=6): pretrained_resnet_backbone = model.backbone # you can also store the backbone and use it in another code -state_dict = { - 'resnet18_parameters': pretrained_resnet_backbone.state_dict() -} -torch.save(state_dict, 'model.pth') +state_dict = {"resnet18_parameters": pretrained_resnet_backbone.state_dict()} +torch.save(state_dict, "model.pth") # %% # THIS COULD BE IN A NEW FILE (e.g. inference.py) @@ -325,8 +306,8 @@ def plot_knn_examples(embeddings, filenames, n_neighbors=3, num_examples=6): # note that we need to create exactly the same backbone in order to load the weights backbone_new = nn.Sequential(*list(resnet18_new.children())[:-1]) -ckpt = torch.load('model.pth') -backbone_new.load_state_dict(ckpt['resnet18_parameters']) +ckpt = torch.load("model.pth") +backbone_new.load_state_dict(ckpt["resnet18_parameters"]) # %% # Next Steps diff --git a/docs/source/tutorials_source/package/tutorial_simsiam_esa.py b/docs/source/tutorials_source/package/tutorial_simsiam_esa.py index d56eb1a38..1c0f0a43a 100644 --- a/docs/source/tutorials_source/package/tutorial_simsiam_esa.py +++ b/docs/source/tutorials_source/package/tutorial_simsiam_esa.py @@ -35,22 +35,21 @@ import math + +import numpy as np import torch import torch.nn as nn import torchvision -import numpy as np -from lightly.data import LightlyDataset, collate, ImageCollateFunction -from lightly.models.modules.heads import SimSiamPredictionHead -from lightly.models.modules.heads import SimSiamProjectionHead +from lightly.data import ImageCollateFunction, LightlyDataset, collate from lightly.loss import NegativeCosineSimilarity - +from lightly.models.modules.heads import SimSiamPredictionHead, SimSiamProjectionHead # %% # Configuration # ------------- -# -# We set some configuration parameters for our experiment. +# +# We set some configuration parameters for our experiment. # # The default configuration with a batch size and input resolution of 256 # requires 16GB of GPU memory. @@ -77,14 +76,14 @@ np.random.seed(0) # set the path to the dataset -path_to_data = '/datasets/sentinel-2-italy-v1/' +path_to_data = "/datasets/sentinel-2-italy-v1/" # %% # Setup data augmentations and loaders # ------------------------------------ # Since we're working on satellite images, it makes sense to use horizontal and -# vertical flips as well as random rotation transformations. We apply weak color +# vertical flips as well as random rotation transformations. We apply weak color # jitter to learn an invariance of the model with respect to slight changes in # the color of the water. 
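The construction of the SimSiam collate function itself falls between the hunks of this diff. As a hedged sketch of what the setup described above could look like with flips, rotation, and weak color jitter; the probability and strength values are illustrative assumptions, not the tutorial's actual numbers:

```python
# Hedged sketch of the augmentation setup described above; all parameter
# values here are assumptions.
collate_fn = ImageCollateFunction(
    input_size=input_size,
    hf_prob=0.5,  # horizontal flip
    vf_prob=0.5,  # vertical flip
    rr_prob=0.5,  # random rotation by 90 degrees
    cj_prob=0.2,  # apply color jitter only rarely, keeping it weak
)
```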
# @@ -109,9 +108,7 @@ # create a lightly dataset for training, since the augmentations are handled # by the collate function, there is no need to apply additional ones here -dataset_train_simsiam = LightlyDataset( - input_dir=path_to_data -) +dataset_train_simsiam = LightlyDataset(input_dir=path_to_data) # create a dataloader for training dataloader_train_simsiam = torch.utils.data.DataLoader( @@ -120,26 +117,25 @@ shuffle=True, collate_fn=collate_fn, drop_last=True, - num_workers=num_workers + num_workers=num_workers, ) # create a torchvision transformation for embedding the dataset after training # here, we resize the images to match the input size during training and apply # a normalization of the color channel based on statistics from imagenet -test_transforms = torchvision.transforms.Compose([ - torchvision.transforms.Resize((input_size, input_size)), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize( - mean=collate.imagenet_normalize['mean'], - std=collate.imagenet_normalize['std'], - ) -]) +test_transforms = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize((input_size, input_size)), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + mean=collate.imagenet_normalize["mean"], + std=collate.imagenet_normalize["std"], + ), + ] +) # create a lightly dataset for embedding -dataset_test = LightlyDataset( - input_dir=path_to_data, - transform=test_transforms -) +dataset_test = LightlyDataset(input_dir=path_to_data, transform=test_transforms) # create a dataloader for embedding dataloader_test = torch.utils.data.DataLoader( @@ -147,7 +143,7 @@ batch_size=batch_size, shuffle=False, drop_last=False, - num_workers=num_workers + num_workers=num_workers, ) # %% @@ -156,18 +152,13 @@ # # Create a ResNet backbone and remove the classification head + class SimSiam(nn.Module): - def __init__( - self, backbone, num_ftrs, proj_hidden_dim, pred_hidden_dim, out_dim - ): + def __init__(self, backbone, num_ftrs, proj_hidden_dim, pred_hidden_dim, out_dim): super().__init__() self.backbone = backbone - self.projection_head = SimSiamProjectionHead( - num_ftrs, proj_hidden_dim, out_dim - ) - self.prediction_head = SimSiamPredictionHead( - out_dim, pred_hidden_dim, out_dim - ) + self.projection_head = SimSiamProjectionHead(num_ftrs, proj_hidden_dim, out_dim) + self.prediction_head = SimSiamPredictionHead(out_dim, pred_hidden_dim, out_dim) def forward(self, x): # get representations @@ -194,43 +185,36 @@ def forward(self, x): # SimSiam uses a symmetric negative cosine similarity loss criterion = NegativeCosineSimilarity() -# scale the learning rate +# scale the learning rate lr = 0.05 * batch_size / 256 # use SGD with momentum and weight decay -optimizer = torch.optim.SGD( - model.parameters(), - lr=lr, - momentum=0.9, - weight_decay=5e-4 -) +optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4) # %% # Train SimSiam # -------------------- -# +# # To train the SimSiam model, you can use a classic PyTorch training loop: # For every epoch, iterate over all batches in the training data, extract # the two transforms of every image, pass them through the model, and calculate # the loss. Then, simply update the weights with the optimizer. Don't forget to # reset the gradients! # -# Since SimSiam doesn't require negative samples, it is a good idea to check +# Since SimSiam doesn't require negative samples, it is a good idea to check # whether the outputs of the model have collapsed into a single direction. 
For # this we can simply check the standard deviation of the L2 normalized output -# vectors. If it is close to one divided by the square root of the output +# vectors. If it is close to one divided by the square root of the output # dimension, everything is fine (you can read # up on this idea `here `_). -device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) -avg_loss = 0. -avg_output_std = 0. +avg_loss = 0.0 +avg_output_std = 0.0 for e in range(epochs): - for (x0, x1), _, _ in dataloader_train_simsiam: - # move images to the gpu x0 = x0.to(device) x1 = x1.to(device) @@ -253,7 +237,7 @@ def forward(self, x): # we can use this later to check whether the embeddings are collapsing output = p0.detach() output = torch.nn.functional.normalize(output, dim=1) - + output_std = torch.std(output, 0) output_std = output_std.mean() @@ -264,11 +248,13 @@ def forward(self, x): # the level of collapse is large if the standard deviation of the l2 # normalized output is much smaller than 1 / sqrt(dim) - collapse_level = max(0., 1 - math.sqrt(out_dim) * avg_output_std) + collapse_level = max(0.0, 1 - math.sqrt(out_dim) * avg_output_std) # print intermediate results - print(f'[Epoch {e:3d}] ' - f'Loss = {avg_loss:.2f} | ' - f'Collapse Level: {collapse_level:.2f} / 1.00') + print( + f"[Epoch {e:3d}] " + f"Loss = {avg_loss:.2f} | " + f"Collapse Level: {collapse_level:.2f} / 1.00" + ) # %% @@ -302,24 +288,24 @@ def forward(self, x): # Now that we have the embeddings, we can visualize the data with a scatter plot. # Further down, we also check out the nearest neighbors of a few example images. # -# As a first step, we make a few additional imports. +# As a first step, we make a few additional imports. # for plotting import os -from PIL import Image -import matplotlib.pyplot as plt import matplotlib.offsetbox as osb -from matplotlib import rcParams as rcp +import matplotlib.pyplot as plt # for resizing images to thumbnails import torchvision.transforms.functional as functional +from matplotlib import rcParams as rcp +from PIL import Image # for clustering and 2d representations from sklearn import random_projection # %% -# Then, we transform the embeddings using UMAP and rescale them to fit in the +# Then, we transform the embeddings using UMAP and rescale them to fit in the # [0, 1] square. # @@ -338,16 +324,16 @@ def forward(self, x): # Let's start with a nice scatter plot of our dataset! The helper function # below will create one. + def get_scatter_plot_with_thumbnails(): - """Creates a scatter plot with image overlays. - """ + """Creates a scatter plot with image overlays.""" # initialize empty figure and add subplot fig = plt.figure() - fig.suptitle('Scatter Plot of the Sentinel-2 Dataset') + fig.suptitle("Scatter Plot of the Sentinel-2 Dataset") ax = fig.add_subplot(1, 1, 1) # shuffle images and find out which images to show shown_images_idx = [] - shown_images = np.array([[1., 1.]]) + shown_images = np.array([[1.0, 1.0]]) iterator = [i for i in range(embeddings_2d.shape[0])] np.random.shuffle(iterator) for i in iterator: @@ -360,7 +346,7 @@ def get_scatter_plot_with_thumbnails(): # plot image overlays for idx in shown_images_idx: - thumbnail_size = int(rcp['figure.figsize'][0] * 2.) 
+ thumbnail_size = int(rcp["figure.figsize"][0] * 2.0) path = os.path.join(path_to_data, filenames[idx]) img = Image.open(path) img = functional.resize(img, thumbnail_size) @@ -373,8 +359,8 @@ def get_scatter_plot_with_thumbnails(): ax.add_artist(img_box) # set aspect ratio - ratio = 1. / ax.get_data_ratio() - ax.set_aspect(ratio, adjustable='box') + ratio = 1.0 / ax.get_data_ratio() + ax.set_aspect(ratio, adjustable="box") # get a scatter plot with thumbnail overlays @@ -386,32 +372,28 @@ def get_scatter_plot_with_thumbnails(): # embeddings generated above). This is a very simple approach to find more images # of a certain type where a few examples are already available. For example, # when a subset of the data is already labelled and one class of images is clearly -# underrepresented, one can easily query more images of this class from the +# underrepresented, one can easily query more images of this class from the # unlabelled dataset. # # Let's get to work! The plots are shown below. example_images = [ - 'S2B_MSIL1C_20200526T101559_N0209_R065_T31TGE/tile_00154.png', # water 1 - 'S2B_MSIL1C_20200526T101559_N0209_R065_T32SLJ/tile_00527.png', # water 2 - 'S2B_MSIL1C_20200526T101559_N0209_R065_T32TNL/tile_00556.png', # land - 'S2B_MSIL1C_20200526T101559_N0209_R065_T31SGD/tile_01731.png', # clouds 1 - 'S2B_MSIL1C_20200526T101559_N0209_R065_T32SMG/tile_00238.png', # clouds 2 + "S2B_MSIL1C_20200526T101559_N0209_R065_T31TGE/tile_00154.png", # water 1 + "S2B_MSIL1C_20200526T101559_N0209_R065_T32SLJ/tile_00527.png", # water 2 + "S2B_MSIL1C_20200526T101559_N0209_R065_T32TNL/tile_00556.png", # land + "S2B_MSIL1C_20200526T101559_N0209_R065_T31SGD/tile_01731.png", # clouds 1 + "S2B_MSIL1C_20200526T101559_N0209_R065_T32SMG/tile_00238.png", # clouds 2 ] def get_image_as_np_array(filename: str): - """Loads the image with filename and returns it as a numpy array. - - """ + """Loads the image with filename and returns it as a numpy array.""" img = Image.open(filename) return np.asarray(img) def get_image_as_np_array_with_frame(filename: str, w: int = 5): - """Returns an image as a numpy array with a black frame of width w. - - """ + """Returns an image as a numpy array with a black frame of width w.""" img = get_image_as_np_array(filename) ny, nx, _ = img.shape # create an empty image with padding for the frame @@ -423,9 +405,7 @@ def get_image_as_np_array_with_frame(filename: str, w: int = 5): def plot_nearest_neighbors_3x3(example_image: str, i: int): - """Plots the example image and its eight nearest neighbors. - - """ + """Plots the example image and its eight nearest neighbors.""" n_subplots = 9 # initialize empty figure fig = plt.figure() diff --git a/docs/source/tutorials_source/platform/tutorial_active_learning.py b/docs/source/tutorials_source/platform/tutorial_active_learning.py index 1c04511fb..c7a9d1b63 100644 --- a/docs/source/tutorials_source/platform/tutorial_active_learning.py +++ b/docs/source/tutorials_source/platform/tutorial_active_learning.py @@ -126,9 +126,10 @@ # # Import the Python frameworks we need for this tutorial. 
-import os import csv -from typing import List, Dict, Tuple +import os +from typing import Dict, List, Tuple + import numpy as np from sklearn.linear_model import LogisticRegression @@ -147,17 +148,17 @@ class CSVEmbeddingDataset: def __init__(self, path_to_embeddings_csv: str): - with open(path_to_embeddings_csv, 'r') as f: + with open(path_to_embeddings_csv, "r") as f: data = csv.reader(f) rows = list(data) header_row = rows[0] rows_without_header = rows[1:] - index_filenames = header_row.index('filenames') + index_filenames = header_row.index("filenames") filenames = [row[index_filenames] for row in rows_without_header] - index_labels = header_row.index('labels') + index_labels = header_row.index("labels") labels = [row[index_labels] for row in rows_without_header] embeddings = rows_without_header @@ -167,9 +168,12 @@ def __init__(self, path_to_embeddings_csv: str): del embedding_row[index_to_delete] # create the dataset as a dictionary mapping from the filename to a tuple of the embedding and the label - self.dataset: Dict[str, Tuple[np.ndarray, int]] = \ - dict([(filename, (np.array(embedding_row, dtype=float), int(label))) - for filename, embedding_row, label in zip(filenames, embeddings, labels)]) + self.dataset: Dict[str, Tuple[np.ndarray, int]] = dict( + [ + (filename, (np.array(embedding_row, dtype=float), int(label))) + for filename, embedding_row, label in zip(filenames, embeddings, labels) + ] + ) def get_features(self, filenames: List[str]) -> np.ndarray: features_array = np.array([self.dataset[filename][0] for filename in filenames]) @@ -183,7 +187,9 @@ def get_labels(self, filenames: List[str]) -> np.ndarray: # %% # First we read the variables we set before as environment variables via the console token = os.getenv("LIGHTLY_TOKEN", default="YOUR_TOKEN") -path_to_embeddings_csv = os.getenv("LIGHTLY_EMBEDDINGS_CSV", default="path_to_your_embeddings_csv") +path_to_embeddings_csv = os.getenv( + "LIGHTLY_EMBEDDINGS_CSV", default="path_to_your_embeddings_csv" +) # We define the client to the Lightly Platform API api_workflow_client = ApiWorkflowClient(token=token) @@ -199,7 +205,9 @@ def get_labels(self, filenames: List[str]) -> np.ndarray: # 1. Choose an initial subset of your dataset. # We want to start with 200 samples and use the CORESET selection strategy for selecting them. print("Starting the initial selection") -selection_config = SelectionConfig(n_samples=200, method=SamplingMethod.CORESET, name='initial-selection') +selection_config = SelectionConfig( + n_samples=200, method=SamplingMethod.CORESET, name="initial-selection" +) agent.query(selection_config=selection_config) print(f"There are {len(agent.labeled_set)} samples in the labeled set.") @@ -221,7 +229,9 @@ def get_labels(self, filenames: List[str]) -> np.ndarray: # %% # 5. Use an active learning agent to choose the next samples to be labeled based on the active learning scores. # We want to sample another 100 samples to have 300 samples in total and use the active learning strategy CORAL for it. 
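Steps 2 to 4, where `active_learning_scorer` is built, fall between the hunks above and are therefore invisible in this diff. As a hedged sketch of the idea: fit the classifier on the current labeled set, predict class probabilities on the query set, and wrap them in a classification scorer (the `max_iter` value is an assumption):

```python
# Hedged sketch of the elided scorer construction; max_iter is an assumption.
from lightly.active_learning.scorers import ScorerClassification

classifier = LogisticRegression(max_iter=1000)
classifier.fit(
    X=dataset.get_features(agent.labeled_set),
    y=dataset.get_labels(agent.labeled_set),
)
predictions = classifier.predict_proba(X=dataset.get_features(agent.query_set))
active_learning_scorer = ScorerClassification(model_output=predictions)
```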
-selection_config = SelectionConfig(n_samples=300, method=SamplingMethod.CORAL, name='al-iteration-1') +selection_config = SelectionConfig( + n_samples=300, method=SamplingMethod.CORAL, name="al-iteration-1" +) agent.query(selection_config=selection_config, al_scorer=active_learning_scorer) print(f"There are {len(agent.labeled_set)} samples in the labeled set.") @@ -240,15 +250,17 @@ def get_labels(self, filenames: List[str]) -> np.ndarray: # evaluate on unlabeled set unlabeled_set_features = dataset.get_features(agent.unlabeled_set) unlabeled_set_labels = dataset.get_labels(agent.unlabeled_set) -accuracy_on_unlabeled_set = classifier.score(X=unlabeled_set_features, y=unlabeled_set_labels) +accuracy_on_unlabeled_set = classifier.score( + X=unlabeled_set_features, y=unlabeled_set_labels +) print(f"accuracy on unlabeled set: {accuracy_on_unlabeled_set}") # %% # Optional: here we created tags 'initial-selection' and 'al-iteration-1' for our dataset ("active_learning_clothing_dataset"). # These can be viewed on the `Lightly Platform `_. -# To re-use the dataset without tags from past experiments, we can (optionally!) remove +# To re-use the dataset without tags from past experiments, we can (optionally!) remove # tags other than the initial-tag: for tag in api_workflow_client.get_all_tags(): - if tag.prev_tag_id is not None: - api_workflow_client.delete_tag_by_id(tag.id) + if tag.prev_tag_id is not None: + api_workflow_client.delete_tag_by_id(tag.id) diff --git a/docs/source/tutorials_source/platform/tutorial_active_learning_detectron2.py b/docs/source/tutorials_source/platform/tutorial_active_learning_detectron2.py index 9d6799fc5..d87f0d26c 100644 --- a/docs/source/tutorials_source/platform/tutorial_active_learning_detectron2.py +++ b/docs/source/tutorials_source/platform/tutorial_active_learning_detectron2.py @@ -76,28 +76,37 @@ # Setup detectron2 logger import detectron2 from detectron2.utils.logger import setup_logger + setup_logger() +import gc +import glob +import json +import os +import random + +import cv2 +import matplotlib.pyplot as plt + # import some common libraries import numpy as np -import os, json, cv2, random, glob -import tqdm, gc -import matplotlib.pyplot as plt +import tqdm # import some common detectron2 utilities from detectron2 import model_zoo -from detectron2.engine import DefaultPredictor from detectron2.config import get_cfg +from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.engine import DefaultPredictor from detectron2.utils.visualizer import Visualizer -from detectron2.data import MetadataCatalog, DatasetCatalog + +from lightly.active_learning.agents import ActiveLearningAgent +from lightly.active_learning.config import SelectionConfig +from lightly.active_learning.scorers import ScorerObjectDetection # imports for lightly from lightly.active_learning.utils.bounding_box import BoundingBox from lightly.active_learning.utils.object_detection_output import ObjectDetectionOutput -from lightly.active_learning.scorers import ScorerObjectDetection from lightly.api.api_workflow_client import ApiWorkflowClient -from lightly.active_learning.agents import ActiveLearningAgent -from lightly.active_learning.config import SelectionConfig from lightly.openapi_generated.swagger_client import SamplingMethod # %% @@ -105,17 +114,17 @@ # ------------------------- # # To work with the Lightly Platform and use the active learning feature we -# need to upload the dataset. -# -# First, head over to `the Lightly Platform `_ and +# need to upload the dataset. 
+# +# First, head over to `the Lightly Platform `_ and # create a new dataset. # -# We can now upload the data using the command line interface. Replace +# We can now upload the data using the command line interface. Replace # **yourToken** and **yourDatasetId** with the two provided values from the web app. # Don't forget to adjust the **input_dir** to the location of your dataset. # -# .. code:: -# +# .. code:: +# # lightly-magic token="yourToken" dataset_id="yourDatasetId" \ # input_dir='/datasets/comma10k/imgs/' trainer.max_epochs=20 \ # loader.batch_size=64 loader.num_workers=8 @@ -129,32 +138,33 @@ YOUR_TOKEN = "yourToken" # your token of the web platform YOUR_DATASET_ID = "yourDatasetId" # the id of your dataset on the web platform -DATASET_ROOT = '/datasets/comma10k/imgs/' +DATASET_ROOT = "/datasets/comma10k/imgs/" + # allow setting of token and dataset_id from environment variables def try_get_token_and_id_from_env(): - token = os.getenv('LIGHTLY_TOKEN', YOUR_TOKEN) - dataset_id = os.getenv('AL_TUTORIAL_DATASET_ID', YOUR_DATASET_ID) + token = os.getenv("LIGHTLY_TOKEN", YOUR_TOKEN) + dataset_id = os.getenv("AL_TUTORIAL_DATASET_ID", YOUR_DATASET_ID) return token, dataset_id -YOUR_TOKEN, YOUR_DATASET_ID = try_get_token_and_id_from_env() +YOUR_TOKEN, YOUR_DATASET_ID = try_get_token_and_id_from_env() # %% # Inference on unlabeled data # ---------------------------- # -# In active learning, we want to pick the new data for which our model struggles -# the most. If we have an image with a single car in it and our model has -# high confidence that there is a car we don't gain a lot by including -# this example in our training data. However, if we focus on images where the -# model is not sure whether the object is a car or a building we want +# In active learning, we want to pick the new data for which our model struggles +# the most. If we have an image with a single car in it and our model has +# high confidence that there is a car we don't gain a lot by including +# this example in our training data. However, if we focus on images where the +# model is not sure whether the object is a car or a building we want # to include these images to refine the decision boundary. # -# First, we need to create an active learning agent in order to -# provide lightly with the model predictions. -# We can use the ApiWorkflowClient for this. Make sure that we use the +# First, we need to create an active learning agent in order to +# provide lightly with the model predictions. +# We can use the ApiWorkflowClient for this. Make sure that we use the # right dataset_id and token. # create Lightly API client @@ -173,9 +183,9 @@ def try_get_token_and_id_from_env(): # Note, that our active learning agent already synchronized with the Lightly # Platform and knows the filenames present in our dataset. # -# Let's verify the length of the `query_set`. The `query_set` is the set of +# Let's verify the length of the `query_set`. The `query_set` is the set of # images from which we want to query. By default this is our full -# dataset uploaded to Lightly. You can learn more about the different sets we +# dataset uploaded to Lightly. 
You can learn more about the different sets we # can access through the active learning agent here # :py:class:`lightly.api.api_workflow_client.ApiWorkflowClient` @@ -188,22 +198,27 @@ def try_get_token_and_id_from_env(): # Create our Detectron2 model # ---------------------------- # -# Next, we create a detectron2 config and a detectron2 `DefaultPredictor` to +# Next, we create a detectron2 config and a detectron2 `DefaultPredictor` to # run predictions on the new images. -# +# # - We use a pre-trained Faster R-CNN with a ResNet-50 backbone # - We use an MS COCO pre-trained model from detectron2 cfg = get_cfg() # add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library ###cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")) -cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")) +cfg.merge_from_file( + model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml") +) cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model # Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well ###cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml") -cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml") +cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url( + "COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml" +) predictor = DefaultPredictor(cfg) + # %% # We use this little helper method to overlay the model predictions on a # given image. @@ -212,37 +227,45 @@ def predict_and_overlay(model, filename): im = cv2.imread(filename) out = model(im) # We can use `Visualizer` to draw the predictions on the image. - v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2) + v = Visualizer( + im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2 + ) out = v.draw_instance_predictions(out["instances"].to("cpu")) - plt.figure(figsize=(16,12)) + plt.figure(figsize=(16, 12)) plt.imshow(out.get_image()[:, :, ::-1]) - plt.axis('off') + plt.axis("off") plt.tight_layout() + # %% -# The lightly framework expects a certain bounding box and prediction format. -# We create another helper method to convert the detectron2 output into +# The lightly framework expects a certain bounding box and prediction format. +# We create another helper method to convert the detectron2 output into # the desired format. 
def convert_bbox_detectron2lightly(outputs): # convert detectron2 predictions into lightly format - height, width = outputs['instances'].image_size + height, width = outputs["instances"].image_size boxes = [] - for (bbox_raw, score, class_idx) in zip(outputs['instances'].pred_boxes.tensor, - outputs['instances'].scores, - outputs['instances'].pred_classes): + for bbox_raw, score, class_idx in zip( + outputs["instances"].pred_boxes.tensor, + outputs["instances"].scores, + outputs["instances"].pred_classes, + ): x0, y0, x1, y1 = bbox_raw.cpu().numpy() x0 /= width y0 /= height x1 /= width y1 /= height - + boxes.append(BoundingBox(x0, y0, x1, y1)) output = ObjectDetectionOutput.from_scores( - boxes, outputs['instances'].scores.cpu().numpy(), - outputs['instances'].pred_classes.cpu().numpy().tolist()) + boxes, + outputs["instances"].scores.cpu().numpy(), + outputs["instances"].pred_classes.cpu().numpy().tolist(), + ) return output + # %% # Get Model Predictions # ---------------------- @@ -254,34 +277,34 @@ def convert_bbox_detectron2lightly(outputs): obj_detection_outputs = [] pbar = tqdm.tqdm(al_agent.query_set, miniters=500, mininterval=60, maxinterval=120) for fname in pbar: - fname_full = os.path.join(DATASET_ROOT, fname) - im = cv2.imread(fname_full) - out = predictor(im) - obj_detection_output = convert_bbox_detectron2lightly(out) - obj_detection_outputs.append(obj_detection_output) + fname_full = os.path.join(DATASET_ROOT, fname) + im = cv2.imread(fname_full) + out = predictor(im) + obj_detection_output = convert_bbox_detectron2lightly(out) + obj_detection_outputs.append(obj_detection_output) # %% # Now, we need to turn the predictions into scores. -# The scorer assigns scores between 0.0 and 1.0 to +# The scorer assigns scores between 0.0 and 1.0 to # each sample and for each scoring method. scorer = ScorerObjectDetection(obj_detection_outputs) scores = scorer.calculate_scores() -# %% +# %% # Let's have a look at the sample with the highest # uncertainty_margin score. # # .. note:: -# A high uncertainty margin means that the image contains at least one -# bounding box for which the model is unsure about the class of the object -# in the bounding box. Read more about how our active learning scores are +# A high uncertainty margin means that the image contains at least one +# bounding box for which the model is unsure about the class of the object +# in the bounding box. Read more about how our active learning scores are # calculated here: # :py:class:`lightly.active_learning.scorers.detection.ScorerObjectDetection` -max_score = scores['uncertainty_margin'].max() -idx = scores['uncertainty_margin'].argmax() -print(f'Highest uncertainty_margin score found for idx {idx}: {max_score}') +max_score = scores["uncertainty_margin"].max() +idx = scores["uncertainty_margin"].argmax() +print(f"Highest uncertainty_margin score found for idx {idx}: {max_score}") # %% -# Let's have a look at this particular image and show the model +# Let's have a look at this particular image and show the model # prediction for it. fname = os.path.join(DATASET_ROOT, al_agent.query_set[idx]) predict_and_overlay(predictor, fname) @@ -296,9 +319,7 @@ def convert_bbox_detectron2lightly(outputs): # the image diversity based on the embeddings, active learning aims at selecting # images where our model struggles the most. 
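Before running the selection, it helps to make the `uncertainty_margin` score discussed above concrete. The exact computation lives in `ScorerObjectDetection`; the following is only a rough sketch of the underlying idea, with the per-image aggregation (taking the least certain box) being an assumption:

```python
import numpy as np

def uncertainty_margin_sketch(class_probs: np.ndarray) -> float:
    # class_probs: (n_boxes, n_classes) softmax probabilities for one image,
    # one row per predicted bounding box.
    if class_probs.size == 0:
        return 0.0  # no predictions, nothing the model is uncertain about
    sorted_probs = np.sort(class_probs, axis=1)
    # gap between the two most confident classes of each box
    margins = sorted_probs[:, -1] - sorted_probs[:, -2]
    # aggregate over boxes via the least certain one; near 1 = very uncertain
    return float(1.0 - margins.min())
```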
config = SelectionConfig( - n_samples=100, - method=SamplingMethod.CORAL, - name='active-learning-loop-1' + n_samples=100, method=SamplingMethod.CORAL, name="active-learning-loop-1" ) al_agent.query(config, scorer) @@ -314,41 +335,41 @@ def convert_bbox_detectron2lightly(outputs): # Let's show model predictions for the first 5 images. to_label = [os.path.join(DATASET_ROOT, x) for x in al_agent.added_set] for i in range(5): - predict_and_overlay(predictor, to_label[i]) + predict_and_overlay(predictor, to_label[i]) # %% # Samples selected in the step above were placed in the 'active-learning-loop-1' tag. # This can be viewed on the `Lightly Platform `_. # %% -# To re-use a dataset without tags from past experiments, we can (optionally!) remove +# To re-use a dataset without tags from past experiments, we can (optionally!) remove # tags other than the initial-tag: for tag in api_client.get_all_tags(): - if tag.prev_tag_id is not None: - api_client.delete_tag_by_id(tag.id) + if tag.prev_tag_id is not None: + api_client.delete_tag_by_id(tag.id) # %% # Next Steps # ------------- # -# We showed in this tutorial how you can use Lightly Active Learning to discover -# the images you should label next. You can close the loop by annotating -# the 100 images and re-training your model. Then start the next iteration +# In this tutorial, we showed how you can use Lightly Active Learning to discover +# the images you should label next. You can close the loop by annotating +# the 100 images and re-training your model. Then start the next iteration # by making new model predictions on the `query_set`. # # Using Lightly Active Learning has two advantages: # -# - By letting the model chose the next batch of images to label we achieve +# - By letting the model choose the next batch of images to label, we achieve # a higher accuracy faster. We're only labeling the images that have the greatest impact. -# -# - By combining the model predictions with the image embeddings we make sure we -# don't select many similar images. Imagine the model being very bad at small -# red cars and the 100 images therefore would only contain this set of images. -# We might overfit the model because it suddenly has too many training examples +# +# - By combining the model predictions with the image embeddings, we make sure we +# don't select many similar images. Imagine the model being very bad at small +# red cars; the 100 images would therefore only contain this set of images. +# We might overfit the model because it suddenly has too many training examples # of small red cars. # %% -# After re-training our model on the newly labeled 100 images +# After re-training our model on the newly labeled 100 images # we can do another active learning iteration by running predictions on the # `query_set`. diff --git a/docs/source/tutorials_source/platform/tutorial_cropped_objects_metadata.py b/docs/source/tutorials_source/platform/tutorial_cropped_objects_metadata.py index c202020a3..5487797ca 100644 --- a/docs/source/tutorials_source/platform/tutorial_cropped_objects_metadata.py +++ b/docs/source/tutorials_source/platform/tutorial_cropped_objects_metadata.py @@ -184,4 +184,4 @@ or using the CLI command.
-""" \ No newline at end of file +""" diff --git a/docs/source/tutorials_source/platform/tutorial_pizza_filter.py b/docs/source/tutorials_source/platform/tutorial_pizza_filter.py index 2e1376fac..eddaae7ca 100644 --- a/docs/source/tutorials_source/platform/tutorial_pizza_filter.py +++ b/docs/source/tutorials_source/platform/tutorial_pizza_filter.py @@ -86,17 +86,17 @@ # Now we can start training our model using PyTorch Lightning # We start by importing the necessary dependencies import os -import torch + import pytorch_lightning as pl -from torchvision.datasets import ImageFolder +import torch +import torchmetrics from torch.utils.data import DataLoader, random_split from torchvision import transforms +from torchvision.datasets import ImageFolder from torchvision.models import resnet18 -import torchmetrics - # %% -# We use a small batch size to make sure we can run the training on all kinds +# We use a small batch size to make sure we can run the training on all kinds # of machines. Feel free to adjust the value to one that works on your machine. batch_size = 8 seed = 42 @@ -105,21 +105,25 @@ # Set the seed to make the experiment reproducible pl.seed_everything(seed) -#%% +# %% # Let's set up the augmentations for the train and the test data. -train_transform = transforms.Compose([ - transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) -]) +train_transform = transforms.Compose( + [ + transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] +) # we don't do any resizing or mirroring for the test data -test_transform = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) -]) +test_transform = transforms.Compose( + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] +) # %% @@ -130,7 +134,7 @@ # pizzas # L salami # L margherita -dset = ImageFolder('pizzas', transform=train_transform) +dset = ImageFolder("pizzas", transform=train_transform) # to use the random_split method we need to obtain the length # of the train and test set @@ -140,8 +144,8 @@ dataset_train, dataset_test = random_split(dset, [train_len, test_len]) dataset_test.transforms = test_transform -print('Training set consists of {} images'.format(len(dataset_train))) -print('Test set consists of {} images'.format(len(dataset_test))) +print("Training set consists of {} images".format(len(dataset_train))) +print("Test set consists of {} images".format(len(dataset_test))) # %% # We can create our data loaders to fetch the data from the training and test @@ -149,8 +153,9 @@ dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True) dataloader_test = DataLoader(dataset_test, batch_size=batch_size) + # %% -# PyTorch Lightning allows us to pack the loss as well as the +# PyTorch Lightning allows us to pack the loss as well as the # optimizer into a single module. 
class MyModel(pl.LightningModule): def __init__(self, num_classes=2): @@ -173,7 +178,7 @@ def training_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = torch.nn.functional.cross_entropy(y_hat, y) - self.log('train_loss', loss, prog_bar=True) + self.log("train_loss", loss, prog_bar=True) return loss def validation_step(self, batch, batch_idx): @@ -182,8 +187,8 @@ def validation_step(self, batch, batch_idx): loss = torch.nn.functional.cross_entropy(y_hat, y) y_hat = torch.nn.functional.softmax(y_hat, dim=1) self.accuracy(y_hat, y) - self.log('val_loss', loss, on_epoch=True, prog_bar=True) - self.log('val_acc', self.accuracy.compute(), on_epoch=True, prog_bar=True) + self.log("val_loss", loss, on_epoch=True, prog_bar=True) + self.log("val_acc", self.accuracy.compute(), on_epoch=True, prog_bar=True) def configure_optimizers(self): return torch.optim.SGD(self.model.fc.parameters(), lr=0.001, momentum=0.9) diff --git a/examples/pytorch/barlowtwins.py b/examples/pytorch/barlowtwins.py index cf39c332b..13f0353fb 100644 --- a/examples/pytorch/barlowtwins.py +++ b/examples/pytorch/barlowtwins.py @@ -1,15 +1,14 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. import torch -from torch import nn import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import ImageCollateFunction -from lightly.models.modules import BarlowTwinsProjectionHead +from lightly.data import ImageCollateFunction, LightlyDataset from lightly.loss import BarlowTwinsLoss +from lightly.models.modules import BarlowTwinsProjectionHead class BarlowTwins(nn.Module): diff --git a/examples/pytorch/byol.py b/examples/pytorch/byol.py index 79fdf270b..e4e26e78a 100644 --- a/examples/pytorch/byol.py +++ b/examples/pytorch/byol.py @@ -2,18 +2,17 @@ # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. +import copy + import torch -from torch import nn import torchvision -import copy +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NegativeCosineSimilarity -from lightly.models.modules import BYOLProjectionHead, BYOLPredictionHead +from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum class BYOL(nn.Module): diff --git a/examples/pytorch/dcl.py b/examples/pytorch/dcl.py index bdbdc062e..2221870f0 100644 --- a/examples/pytorch/dcl.py +++ b/examples/pytorch/dcl.py @@ -1,13 +1,12 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
import torch -from torch import nn import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import DCLLoss, DCLWLoss from lightly.models.modules import SimCLRProjectionHead @@ -38,7 +37,7 @@ def forward(self, x): collate_fn = SimCLRCollateFunction( input_size=32, - gaussian_blur=0., + gaussian_blur=0.0, ) dataloader = torch.utils.data.DataLoader( @@ -52,7 +51,7 @@ def forward(self, x): criterion = DCLLoss() # or use the weighted DCLW loss: -# criterion = DCLWLoss() +# criterion = DCLWLoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.06) diff --git a/examples/pytorch/dino.py b/examples/pytorch/dino.py index 1fc45fded..847ed7fba 100644 --- a/examples/pytorch/dino.py +++ b/examples/pytorch/dino.py @@ -2,18 +2,17 @@ # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. +import copy + import torch -from torch import nn import torchvision -import copy +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import DINOCollateFunction +from lightly.data import DINOCollateFunction, LightlyDataset from lightly.loss import DINOLoss from lightly.models.modules import DINOProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum class DINO(torch.nn.Module): diff --git a/examples/pytorch/mae.py b/examples/pytorch/mae.py index 187671cb1..74d6bde94 100644 --- a/examples/pytorch/mae.py +++ b/examples/pytorch/mae.py @@ -1,10 +1,10 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. import torch -from torch import nn import torchvision +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import MAECollateFunction @@ -15,7 +15,7 @@ class MAE(nn.Module): def __init__(self, vit): super().__init__() - + decoder_dim = 512 self.mask_ratio = 0.75 self.patch_size = vit.patch_size @@ -29,7 +29,7 @@ def __init__(self, vit): embed_input_dim=vit.hidden_dim, hidden_dim=decoder_dim, mlp_dim=decoder_dim * 4, - out_dim=vit.patch_size ** 2 * 3, + out_dim=vit.patch_size**2 * 3, dropout=0, attention_dropout=0, ) @@ -41,7 +41,9 @@ def forward_decoder(self, x_encoded, idx_keep, idx_mask): # build decoder input batch_size = x_encoded.shape[0] x_decode = self.decoder.embed(x_encoded) - x_masked = utils.repeat_token(self.mask_token, (batch_size, self.sequence_length)) + x_masked = utils.repeat_token( + self.mask_token, (batch_size, self.sequence_length) + ) x_masked = utils.set_at_index(x_masked, idx_keep, x_decode) # decoder forward pass diff --git a/examples/pytorch/moco.py b/examples/pytorch/moco.py index f01ec3e4f..c9dcc8652 100644 --- a/examples/pytorch/moco.py +++ b/examples/pytorch/moco.py @@ -2,18 +2,17 @@ # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
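The MAE example above relies on `utils.random_token_mask` to split the token sequence into kept and masked indices. As a conceptual sketch of what such a mask generator does (an illustration of the idea, not lightly's implementation, which additionally always keeps the prepended class token):

```python
import torch

def random_token_mask_sketch(batch_size: int, seq_length: int, mask_ratio: float = 0.75):
    # draw a random permutation of token positions per sample ...
    n_mask = int(seq_length * mask_ratio)
    perm = torch.rand(batch_size, seq_length).argsort(dim=1)
    # ... and split it into masked and kept index sets
    idx_mask, idx_keep = perm[:, :n_mask], perm[:, n_mask:]
    return idx_keep, idx_mask
```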
+import copy + import torch -from torch import nn import torchvision -import copy +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import MoCoCollateFunction +from lightly.data import LightlyDataset, MoCoCollateFunction from lightly.loss import NTXentLoss from lightly.models.modules import MoCoProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum class MoCo(nn.Module): diff --git a/examples/pytorch/msn.py b/examples/pytorch/msn.py index cca452e5f..b94e4ff53 100644 --- a/examples/pytorch/msn.py +++ b/examples/pytorch/msn.py @@ -1,17 +1,17 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. import copy import torch -from torch import nn import torchvision +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import MSNCollateFunction from lightly.loss import MSNLoss -from lightly.models.modules.heads import MSNProjectionHead from lightly.models import utils +from lightly.models.modules.heads import MSNProjectionHead from lightly.models.modules.masked_autoencoder import MAEBackbone @@ -46,7 +46,8 @@ def forward_masked(self, images): out = self.anchor_backbone(images, idx_keep) return self.anchor_projection_head(out) -# ViT small configuration (ViT-S/16) + +# ViT small configuration (ViT-S/16) vit = torchvision.models.VisionTransformer( image_size=224, patch_size=16, @@ -56,9 +57,9 @@ def forward_masked(self, images): mlp_dim=384 * 4, ) model = MSN(vit) -# # or use a torchvision ViT backbone: +# # or use a torchvision ViT backbone: # vit = torchvision.models.vit_b_32(pretrained=False) -# moel = MSN(vit) +# model = MSN(vit) device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) @@ -95,7 +96,9 @@ def forward_masked(self, images): total_loss = 0 for views, _, _ in dataloader: utils.update_momentum(model.anchor_backbone, model.backbone, 0.996) - utils.update_momentum(model.anchor_projection_head, model.projection_head, 0.996) + utils.update_momentum( + model.anchor_projection_head, model.projection_head, 0.996 + ) views = [view.to(device, non_blocking=True) for view in views] targets = views[0] diff --git a/examples/pytorch/nnclr.py b/examples/pytorch/nnclr.py index 4f87427ed..1c66e64b3 100644 --- a/examples/pytorch/nnclr.py +++ b/examples/pytorch/nnclr.py @@ -1,17 +1,18 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU.
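Several of these examples (MSN above; also MoCo, BYOL, DINO, and TiCo) maintain a momentum copy of the backbone that trails the online weights as an exponential moving average. Conceptually, `utils.update_momentum` boils down to the following sketch (an illustration of the EMA update, not the library code):

```python
import torch
from torch import nn

@torch.no_grad()
def update_momentum_sketch(model: nn.Module, model_ema: nn.Module, m: float = 0.996) -> None:
    # move each EMA parameter a small step towards its online counterpart
    for p, p_ema in zip(model.parameters(), model_ema.parameters()):
        p_ema.data = m * p_ema.data + (1.0 - m) * p.data
```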
import torch -from torch import nn import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NTXentLoss -from lightly.models.modules import NNCLRProjectionHead -from lightly.models.modules import NNCLRPredictionHead -from lightly.models.modules import NNMemoryBankModule +from lightly.models.modules import ( + NNCLRPredictionHead, + NNCLRProjectionHead, + NNMemoryBankModule, +) class NNCLR(nn.Module): @@ -75,4 +76,4 @@ def forward(self, x): optimizer.step() optimizer.zero_grad() avg_loss = total_loss / len(dataloader) - print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") \ No newline at end of file + print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}") diff --git a/examples/pytorch/simclr.py b/examples/pytorch/simclr.py index e5e93614e..8827faa37 100644 --- a/examples/pytorch/simclr.py +++ b/examples/pytorch/simclr.py @@ -1,13 +1,12 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. import torch -from torch import nn import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NTXentLoss from lightly.models.modules import SimCLRProjectionHead @@ -38,7 +37,7 @@ def forward(self, x): collate_fn = SimCLRCollateFunction( input_size=32, - gaussian_blur=0., + gaussian_blur=0.0, ) dataloader = torch.utils.data.DataLoader( diff --git a/examples/pytorch/simmim.py b/examples/pytorch/simmim.py index 2bea12d05..54f7b02f5 100644 --- a/examples/pytorch/simmim.py +++ b/examples/pytorch/simmim.py @@ -1,9 +1,9 @@ import torch -from torch import nn import torchvision +from torch import nn from lightly.data import LightlyDataset -from lightly.data.collate import MAECollateFunction # Same collate as MAE +from lightly.data.collate import MAECollateFunction # Same collate as MAE from lightly.models import utils from lightly.models.modules import masked_autoencoder @@ -11,7 +11,7 @@ class SimMIM(nn.Module): def __init__(self, vit): super().__init__() - + decoder_dim = vit.hidden_dim self.mask_ratio = 0.75 self.patch_size = vit.patch_size @@ -19,16 +19,15 @@ def __init__(self, vit): self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) # same backbone as MAE - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) + self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) # the decoder is a simple linear layer - self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size ** 2 * 3) - + self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size**2 * 3) def forward_encoder(self, images, batch_size, idx_mask): # pass all the tokens to the encoder, both masked and non masked ones tokens = self.backbone.images_to_tokens(images, prepend_class_token=True) - tokens_masked = utils.mask_at_index(tokens, idx_mask , self.mask_token) + tokens_masked = utils.mask_at_index(tokens, idx_mask, self.mask_token) return self.backbone.encoder(tokens_masked) def forward_decoder(self, x_encoded): @@ -41,7 +40,7 @@ def forward(self, images): mask_ratio=self.mask_ratio, device=images.device, ) - + # Encoding... 
x_encoded = self.forward_encoder(images, batch_size, idx_mask) x_encoded_masked = utils.get_at_index(x_encoded, idx_mask) @@ -51,7 +50,7 @@ def forward(self, images): # get image patches for masked tokens patches = utils.patchify(images, self.patch_size) - + # must adjust idx_mask for missing class token target = utils.get_at_index(patches, idx_mask - 1) @@ -93,7 +92,7 @@ def forward(self, images): for images, _, _ in dataloader: images = images.to(device) predictions, targets = model(images) - + loss = criterion(predictions, targets) total_loss += loss.detach() loss.backward() diff --git a/examples/pytorch/simsiam.py b/examples/pytorch/simsiam.py index 2399c87c0..ec881a684 100644 --- a/examples/pytorch/simsiam.py +++ b/examples/pytorch/simsiam.py @@ -1,16 +1,14 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. import torch -from torch import nn import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NegativeCosineSimilarity -from lightly.models.modules import SimSiamProjectionHead -from lightly.models.modules import SimSiamPredictionHead +from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead class SimSiam(nn.Module): diff --git a/examples/pytorch/smog.py b/examples/pytorch/smog.py index 7be43e980..2cb4ba831 100644 --- a/examples/pytorch/smog.py +++ b/examples/pytorch/smog.py @@ -3,20 +3,21 @@ # run on a small dataset with a single GPU. 
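The reconstruction targets in SimMIM (and MAE) come from `utils.patchify`, which turns an image into a sequence of flattened pixel patches. A minimal sketch, assuming a row-major patch ordering (lightly's exact channel/patch layout may differ):

```python
import torch

def patchify_sketch(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, num_patches, patch_size**2 * C)
    B, C, H, W = images.shape
    p = patch_size
    x = images.reshape(B, C, H // p, p, W // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)  # (B, H//p, W//p, p, p, C)
    return x.reshape(B, (H // p) * (W // p), p * p * C)
```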
import copy + import torch -from torch import nn import torchvision from sklearn.cluster import KMeans - +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import SMoGCollateFunction from lightly.loss.memory_bank import MemoryBankModule -from lightly.models.modules.heads import SMoGProjectionHead -from lightly.models.modules.heads import SMoGPredictionHead -from lightly.models.modules.heads import SMoGPrototypes from lightly.models import utils - +from lightly.models.modules.heads import ( + SMoGPredictionHead, + SMoGProjectionHead, + SMoGPrototypes, +) class SMoGModel(nn.Module): @@ -93,7 +94,7 @@ def forward_momentum(self, x): collate_fn = SMoGCollateFunction( crop_sizes=[32, 32], crop_counts=[1, 1], - gaussian_blur_probs=[0., 0.], + gaussian_blur_probs=[0.0, 0.0], crop_min_scales=[0.2, 0.2], crop_max_scales=[1.0, 1.0], ) @@ -109,10 +110,7 @@ def forward_momentum(self, x): criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD( - model.parameters(), - lr=0.01, - momentum=0.9, - weight_decay=1e-6 + model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-6 ) global_step = 0 @@ -121,7 +119,6 @@ def forward_momentum(self, x): for epoch in range(10): total_loss = 0 for batch_idx, batch in enumerate(dataloader): - (x0, x1), _, _ = batch if batch_idx % 2: @@ -138,7 +135,9 @@ def forward_momentum(self, x): else: # update momentum utils.update_momentum(model.backbone, model.backbone_momentum, 0.99) - utils.update_momentum(model.projection_head, model.projection_head_momentum, 0.99) + utils.update_momentum( + model.projection_head, model.projection_head_momentum, 0.99 + ) x0_encoded, x0_predicted = model(x0) x1_encoded = model.forward_momentum(x1) diff --git a/examples/pytorch/swav.py b/examples/pytorch/swav.py index 65bf6d87f..6293b8680 100644 --- a/examples/pytorch/swav.py +++ b/examples/pytorch/swav.py @@ -1,16 +1,14 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. import torch -from torch import nn import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SwaVCollateFunction +from lightly.data import LightlyDataset, SwaVCollateFunction from lightly.loss import SwaVLoss -from lightly.models.modules import SwaVProjectionHead -from lightly.models.modules import SwaVPrototypes +from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes class SwaV(nn.Module): diff --git a/examples/pytorch/tico.py b/examples/pytorch/tico.py index ef1831e51..cef1f786e 100644 --- a/examples/pytorch/tico.py +++ b/examples/pytorch/tico.py @@ -2,18 +2,17 @@ # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. 
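A defining piece of the SMoG example above is the periodic reset of its group prototypes from banked features, which is why `KMeans` is imported. The reset itself is hidden inside the model; the following is only a hedged sketch of that clustering step (the group count and the normalization are assumptions):

```python
import torch
from sklearn.cluster import KMeans

def reset_group_features_sketch(bank_features: torch.Tensor, n_groups: int = 300) -> torch.Tensor:
    # cluster the features collected in the memory bank and use the
    # centroids as the new group prototypes
    kmeans = KMeans(n_clusters=n_groups).fit(bank_features.cpu().numpy())
    centroids = torch.from_numpy(kmeans.cluster_centers_).float()
    return torch.nn.functional.normalize(centroids, dim=1)
```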
+import copy + import torch -from torch import nn import torchvision -import copy +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum class TiCo(nn.Module): diff --git a/examples/pytorch/vicreg.py b/examples/pytorch/vicreg.py index e4ba67912..fbcd99d6c 100644 --- a/examples/pytorch/vicreg.py +++ b/examples/pytorch/vicreg.py @@ -1,13 +1,14 @@ import torch -from torch import nn import torchvision +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import VICRegCollateFunction -from lightly.models.modules import BarlowTwinsProjectionHead ## The projection head is the same as the Barlow Twins one from lightly.loss import VICRegLoss +from lightly.models.modules import BarlowTwinsProjectionHead + class VICReg(nn.Module): def __init__(self, backbone): diff --git a/examples/pytorch/vicregl.py b/examples/pytorch/vicregl.py index f0758530f..cf336dfad 100644 --- a/examples/pytorch/vicregl.py +++ b/examples/pytorch/vicregl.py @@ -1,13 +1,15 @@ import torch -from torch import nn import torchvision +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import VICRegLCollateFunction +from lightly.loss import VICRegLLoss + ## The global projection head is the same as the Barlow Twins one from lightly.models.modules import BarlowTwinsProjectionHead from lightly.models.modules.heads import VicRegLLocalProjectionHead -from lightly.loss import VICRegLLoss + class VICRegL(nn.Module): def __init__(self, backbone): @@ -21,10 +23,11 @@ def forward(self, x): x = self.backbone(x) y = self.average_pool(x).flatten(start_dim=1) z = self.projection_head(y) - y_local = x.permute(0, 2, 3, 1) # (B, D, W, H) to (B, W, H, D) - z_local = self.local_projection_head(y_local) + y_local = x.permute(0, 2, 3, 1) # (B, D, W, H) to (B, W, H, D) + z_local = self.local_projection_head(y_local) return z, z_local + resnet = torchvision.models.resnet18() backbone = nn.Sequential(*list(resnet.children())[:-2]) model = VICRegL(backbone) @@ -41,7 +44,7 @@ def forward(self, x): dataloader = torch.utils.data.DataLoader( dataset, - batch_size=128, #2048 from the paper if enough memory + batch_size=128, # 2048 from the paper if enough memory collate_fn=collate_fn, shuffle=True, drop_last=True, @@ -62,12 +65,12 @@ def forward(self, x): z_global, z_global_local_features = model(view_global) z_local, z_local_local_features = model(view_local) loss = criterion( - z_global=z_global, - z_local=z_local, - z_global_local_features=z_global_local_features, - z_local_local_features=z_local_local_features, - grid_global=grid_global, - grid_local=grid_local + z_global=z_global, + z_local=z_local, + z_global_local_features=z_global_local_features, + z_local_local_features=z_local_local_features, + grid_global=grid_global, + grid_local=grid_local, ) total_loss += loss.detach() loss.backward() diff --git a/examples/pytorch_lightning/barlowtwins.py b/examples/pytorch_lightning/barlowtwins.py index ab27d134b..60b7a7e20 100644 --- a/examples/pytorch_lightning/barlowtwins.py +++ b/examples/pytorch_lightning/barlowtwins.py @@ -1,14 +1,13 @@ 
-# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import ImageCollateFunction +from lightly.data import ImageCollateFunction, LightlyDataset from lightly.loss import BarlowTwinsLoss from lightly.models.modules import BarlowTwinsProjectionHead diff --git a/examples/pytorch_lightning/byol.py b/examples/pytorch_lightning/byol.py index 7a5a1356d..49d71f0a4 100644 --- a/examples/pytorch_lightning/byol.py +++ b/examples/pytorch_lightning/byol.py @@ -2,19 +2,18 @@ # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. -import torch -from torch import nn -import torchvision import copy + import pytorch_lightning as pl +import torch +import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NegativeCosineSimilarity -from lightly.models.modules import BYOLProjectionHead, BYOLPredictionHead +from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum class BYOL(pl.LightningModule): diff --git a/examples/pytorch_lightning/dcl.py b/examples/pytorch_lightning/dcl.py index cc5ba3e4e..e270161eb 100644 --- a/examples/pytorch_lightning/dcl.py +++ b/examples/pytorch_lightning/dcl.py @@ -1,14 +1,13 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
+import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import DCLLoss, DCLWLoss from lightly.models.modules import SimCLRProjectionHead @@ -21,7 +20,7 @@ def __init__(self): self.projection_head = SimCLRProjectionHead(512, 2048, 2048) self.criterion = DCLLoss() # or use the weighted DCLW loss: - # self.criterion = DCLWLoss() + # self.criterion = DCLWLoss() def forward(self, x): x = self.backbone(x).flatten(start_dim=1) @@ -49,7 +48,7 @@ def configure_optimizers(self): collate_fn = SimCLRCollateFunction( input_size=32, - gaussian_blur=0., + gaussian_blur=0.0, ) dataloader = torch.utils.data.DataLoader( diff --git a/examples/pytorch_lightning/dino.py b/examples/pytorch_lightning/dino.py index a8ff840d1..42f7e474d 100644 --- a/examples/pytorch_lightning/dino.py +++ b/examples/pytorch_lightning/dino.py @@ -1,21 +1,19 @@ import copy -# Note: The model and training settings do not follow the reference settings -# from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. - +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import DINOCollateFunction +from lightly.data import DINOCollateFunction, LightlyDataset from lightly.loss import DINOLoss from lightly.models.modules import DINOProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum + +# Note: The model and training settings do not follow the reference settings +# from the paper. The settings are chosen such that the example can easily be +# run on a small dataset with a single GPU. class DINO(pl.LightningModule): diff --git a/examples/pytorch_lightning/mae.py b/examples/pytorch_lightning/mae.py index 08ad30897..6edf067cb 100644 --- a/examples/pytorch_lightning/mae.py +++ b/examples/pytorch_lightning/mae.py @@ -1,11 +1,11 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
+import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import MAECollateFunction @@ -16,7 +16,7 @@ class MAE(pl.LightningModule): def __init__(self): super().__init__() - + decoder_dim = 512 vit = torchvision.models.vit_b_32(pretrained=False) self.mask_ratio = 0.75 @@ -31,7 +31,7 @@ def __init__(self): embed_input_dim=vit.hidden_dim, hidden_dim=decoder_dim, mlp_dim=decoder_dim * 4, - out_dim=vit.patch_size ** 2 * 3, + out_dim=vit.patch_size**2 * 3, dropout=0, attention_dropout=0, ) @@ -44,7 +44,9 @@ def forward_decoder(self, x_encoded, idx_keep, idx_mask): # build decoder input batch_size = x_encoded.shape[0] x_decode = self.decoder.embed(x_encoded) - x_masked = utils.repeat_token(self.mask_token, (batch_size, self.sequence_length)) + x_masked = utils.repeat_token( + self.mask_token, (batch_size, self.sequence_length) + ) x_masked = utils.set_at_index(x_masked, idx_keep, x_decode) # decoder forward pass @@ -57,7 +59,7 @@ def forward_decoder(self, x_encoded, idx_keep, idx_mask): def training_step(self, batch, batch_idx): images, _, _ = batch - + batch_size = images.shape[0] idx_keep, idx_mask = utils.random_token_mask( size=(batch_size, self.sequence_length), @@ -71,7 +73,7 @@ def training_step(self, batch, batch_idx): patches = utils.patchify(images, self.patch_size) # must adjust idx_mask for missing class token target = utils.get_at_index(patches, idx_mask - 1) - + loss = self.criterion(x_pred, target) return loss diff --git a/examples/pytorch_lightning/moco.py b/examples/pytorch_lightning/moco.py index 27a20cf83..f4626d574 100644 --- a/examples/pytorch_lightning/moco.py +++ b/examples/pytorch_lightning/moco.py @@ -2,19 +2,18 @@ # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. -import torch -from torch import nn -import torchvision import copy + import pytorch_lightning as pl +import torch +import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import MoCoCollateFunction +from lightly.data import LightlyDataset, MoCoCollateFunction from lightly.loss import NTXentLoss from lightly.models.modules import MoCoProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum class MoCo(pl.LightningModule): diff --git a/examples/pytorch_lightning/msn.py b/examples/pytorch_lightning/msn.py index 18a76b7ec..f5b463c4c 100644 --- a/examples/pytorch_lightning/msn.py +++ b/examples/pytorch_lightning/msn.py @@ -1,18 +1,18 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
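The momentum-based examples import `cosine_schedule` from `lightly.utils.scheduler` to ramp the EMA momentum towards 1.0 over the course of training. The shape of such a schedule, sketched under an assumed `(step, max_steps, start_value, end_value)` signature:

```python
import math

def cosine_schedule_sketch(step: int, max_steps: int, start_value: float, end_value: float) -> float:
    # cosine interpolation: start_value at step 0, end_value at max_steps
    # (e.g. momentum 0.996 -> 1.0 over training)
    progress = step / max(1, max_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (start_value - end_value) * cosine
```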
import copy +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import MSNCollateFunction from lightly.loss import MSNLoss -from lightly.models.modules.heads import MSNProjectionHead from lightly.models import utils +from lightly.models.modules.heads import MSNProjectionHead from lightly.models.modules.masked_autoencoder import MAEBackbone @@ -20,7 +20,7 @@ class MSN(pl.LightningModule): def __init__(self): super().__init__() - # ViT small configuration (ViT-S/16) + # ViT small configuration (ViT-S/16) self.mask_ratio = 0.15 self.backbone = MAEBackbone( image_size=224, @@ -30,7 +30,7 @@ def __init__(self): hidden_dim=384, mlp_dim=384 * 4, ) - # or use a torchvision ViT backbone: + # or use a torchvision ViT backbone: # vit = torchvision.models.vit_b_32(pretrained=False) # self.backbone = MAEBackbone.from_vit(vit) self.projection_head = MSNProjectionHead(384) diff --git a/examples/pytorch_lightning/nnclr.py b/examples/pytorch_lightning/nnclr.py index cfdc2af8f..5c5fd2c19 100644 --- a/examples/pytorch_lightning/nnclr.py +++ b/examples/pytorch_lightning/nnclr.py @@ -1,18 +1,19 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NTXentLoss -from lightly.models.modules import NNCLRProjectionHead -from lightly.models.modules import NNCLRPredictionHead -from lightly.models.modules import NNMemoryBankModule +from lightly.models.modules import ( + NNCLRPredictionHead, + NNCLRProjectionHead, + NNMemoryBankModule, +) class NNCLR(pl.LightningModule): diff --git a/examples/pytorch_lightning/simclr.py b/examples/pytorch_lightning/simclr.py index 4cc75d579..11efe5025 100644 --- a/examples/pytorch_lightning/simclr.py +++ b/examples/pytorch_lightning/simclr.py @@ -1,14 +1,13 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
+import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NTXentLoss from lightly.models.modules import SimCLRProjectionHead @@ -47,7 +46,7 @@ def configure_optimizers(self): collate_fn = SimCLRCollateFunction( input_size=32, - gaussian_blur=0., + gaussian_blur=0.0, ) dataloader = torch.utils.data.DataLoader( diff --git a/examples/pytorch_lightning/simmim.py b/examples/pytorch_lightning/simmim.py index 782d33526..e05808626 100644 --- a/examples/pytorch_lightning/simmim.py +++ b/examples/pytorch_lightning/simmim.py @@ -1,10 +1,10 @@ +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn from lightly.data import LightlyDataset -from lightly.data.collate import MAECollateFunction # Same collate as MAE +from lightly.data.collate import MAECollateFunction # Same collate as MAE from lightly.models import utils from lightly.models.modules import masked_autoencoder @@ -12,7 +12,7 @@ class SimMIM(pl.LightningModule): def __init__(self): super().__init__() - + vit = torchvision.models.vit_b_32(pretrained=False) decoder_dim = vit.hidden_dim self.mask_ratio = 0.75 @@ -21,19 +21,18 @@ def __init__(self): self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim)) # same backbone as MAE - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) + self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) # the decoder is a simple linear layer - self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size ** 2 * 3) + self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size**2 * 3) # L1 loss as paper suggestion self.criterion = nn.L1Loss() - def forward_encoder(self, images, batch_size, idx_mask): # pass all the tokens to the encoder, both masked and non masked ones tokens = self.backbone.images_to_tokens(images, prepend_class_token=True) - tokens_masked = utils.mask_at_index(tokens, idx_mask , self.mask_token) + tokens_masked = utils.mask_at_index(tokens, idx_mask, self.mask_token) return self.backbone.encoder(tokens_masked) def forward_decoder(self, x_encoded): @@ -48,7 +47,7 @@ def training_step(self, batch, batch_idx): mask_ratio=self.mask_ratio, device=images.device, ) - + # Encoding... x_encoded = self.forward_encoder(images, batch_size, idx_mask) x_encoded_masked = utils.get_at_index(x_encoded, idx_mask) @@ -58,13 +57,13 @@ def training_step(self, batch, batch_idx): # get image patches for masked tokens patches = utils.patchify(images, self.patch_size) - + # must adjust idx_mask for missing class token target = utils.get_at_index(patches, idx_mask - 1) loss = self.criterion(x_out, target) return loss - + def configure_optimizers(self): optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4) return optim diff --git a/examples/pytorch_lightning/simsiam.py b/examples/pytorch_lightning/simsiam.py index e9bf1f934..dac6d6efe 100644 --- a/examples/pytorch_lightning/simsiam.py +++ b/examples/pytorch_lightning/simsiam.py @@ -1,17 +1,15 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
+import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NegativeCosineSimilarity -from lightly.models.modules import SimSiamProjectionHead -from lightly.models.modules import SimSiamPredictionHead +from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead class SimSiam(pl.LightningModule): diff --git a/examples/pytorch_lightning/smog.py b/examples/pytorch_lightning/smog.py index 1cc51049b..e925cc38d 100644 --- a/examples/pytorch_lightning/smog.py +++ b/examples/pytorch_lightning/smog.py @@ -2,30 +2,26 @@ # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. -import torch -from torch import nn -import torchvision import copy + import pytorch_lightning as pl +import torch +import torchvision +from sklearn.cluster import KMeans +from torch import nn -from lightly import data -from lightly.models.modules import heads +from lightly import data, loss, models from lightly.models import utils -from lightly import loss -from lightly import models - -from sklearn.cluster import KMeans +from lightly.models.modules import heads class SMoGModel(pl.LightningModule): - def __init__(self): super().__init__() # create a ResNet backbone and remove the classification head - resnet = models.ResNetGenerator('resnet-18') + resnet = models.ResNetGenerator("resnet-18") self.backbone = nn.Sequential( - *list(resnet.children())[:-1], - nn.AdaptiveAvgPool2d(1) + *list(resnet.children())[:-1], nn.AdaptiveAvgPool2d(1) ) # create a model based on ResNet @@ -68,7 +64,6 @@ def _reset_momentum_weights(self): utils.deactivate_requires_grad(self.projection_head_momentum) def training_step(self, batch, batch_idx): - if self.global_step > 0 and self.global_step % 300 == 0: # reset group features and weights every 300 iterations self._reset_group_features() @@ -76,7 +71,9 @@ def training_step(self, batch, batch_idx): else: # update momentum utils.update_momentum(self.backbone, self.backbone_momentum, 0.99) - utils.update_momentum(self.projection_head, self.projection_head_momentum, 0.99) + utils.update_momentum( + self.projection_head, self.projection_head_momentum, 0.99 + ) (x0, x1), _, _ = batch @@ -103,11 +100,14 @@ def training_step(self, batch, batch_idx): return loss - def configure_optimizers(self): - params = list(self.backbone.parameters()) + list(self.projection_head.parameters()) + list(self.prediction_head.parameters()) + params = ( + list(self.backbone.parameters()) + + list(self.projection_head.parameters()) + + list(self.prediction_head.parameters()) + ) optim = torch.optim.SGD( - params, + params, lr=0.01, momentum=0.9, weight_decay=1e-6, @@ -125,7 +125,7 @@ def configure_optimizers(self): collate_fn = data.collate.SMoGCollateFunction( crop_sizes=[32, 32], crop_counts=[1, 1], - gaussian_blur_probs=[0., 0.], + gaussian_blur_probs=[0.0, 0.0], crop_min_scales=[0.2, 0.2], crop_max_scales=[1.0, 1.0], ) @@ -145,4 +145,3 @@ def configure_optimizers(self): trainer.fit(model=model, train_dataloaders=dataloader) - diff --git a/examples/pytorch_lightning/swav.py b/examples/pytorch_lightning/swav.py index 8180e8bb4..84e111556 100644 --- a/examples/pytorch_lightning/swav.py +++ b/examples/pytorch_lightning/swav.py @@ -1,17 +1,15 @@ -# Note: The model and training settings do not 
follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SwaVCollateFunction +from lightly.data import LightlyDataset, SwaVCollateFunction from lightly.loss import SwaVLoss -from lightly.models.modules import SwaVProjectionHead -from lightly.models.modules import SwaVPrototypes +from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes class SwaV(pl.LightningModule): diff --git a/examples/pytorch_lightning/swav_queue.py b/examples/pytorch_lightning/swav_queue.py index a21688314..8a33cd740 100644 --- a/examples/pytorch_lightning/swav_queue.py +++ b/examples/pytorch_lightning/swav_queue.py @@ -2,17 +2,15 @@ # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SwaVCollateFunction +from lightly.data import LightlyDataset, SwaVCollateFunction from lightly.loss import SwaVLoss from lightly.loss.memory_bank import MemoryBankModule -from lightly.models.modules import SwaVProjectionHead -from lightly.models.modules import SwaVPrototypes +from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes class SwaV(pl.LightningModule): @@ -58,7 +56,6 @@ def _subforward(self, input): @torch.no_grad() def _get_queue_prototypes(self, high_resolution_features): - if len(high_resolution_features) != len(self.queues): raise ValueError( f"The number of queues ({len(self.queues)}) should be equal to the number of high " @@ -83,7 +80,9 @@ def _get_queue_prototypes(self, high_resolution_features): return None # Assign prototypes - queue_prototypes = [self.prototypes(x, self.current_epoch) for x in queue_features] + queue_prototypes = [ + self.prototypes(x, self.current_epoch) for x in queue_features + ] return queue_prototypes diff --git a/examples/pytorch_lightning/tico.py b/examples/pytorch_lightning/tico.py index 4fe4f1c78..e4572c9b7 100644 --- a/examples/pytorch_lightning/tico.py +++ b/examples/pytorch_lightning/tico.py @@ -1,16 +1,15 @@ -import torch -from torch import nn -import torchvision import copy + import pytorch_lightning as pl +import torch +import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum class TiCo(pl.LightningModule): diff --git a/examples/pytorch_lightning/vicreg.py b/examples/pytorch_lightning/vicreg.py index e8eb0c7a7..8066802dc 100644 --- a/examples/pytorch_lightning/vicreg.py +++ b/examples/pytorch_lightning/vicreg.py @@ -1,11 +1,11 @@ -# Note: The model and training settings do not follow the reference 
settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import VICRegCollateFunction diff --git a/examples/pytorch_lightning/vicregl.py b/examples/pytorch_lightning/vicregl.py index 2840b2175..fdd9aa6fc 100644 --- a/examples/pytorch_lightning/vicregl.py +++ b/examples/pytorch_lightning/vicregl.py @@ -1,19 +1,19 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import VICRegLCollateFunction +from lightly.loss import VICRegLLoss + ## The global projection head is the same as the Barlow Twins one from lightly.models.modules import BarlowTwinsProjectionHead from lightly.models.modules.heads import VicRegLLocalProjectionHead -from lightly.loss import VICRegLLoss - class VICRegL(pl.LightningModule): @@ -30,22 +30,21 @@ def forward(self, x): x = self.backbone(x) y = self.average_pool(x).flatten(start_dim=1) z = self.projection_head(y) - y_local = x.permute(0, 2, 3, 1) # (B, D, W, H) to (B, W, H, D) - z_local = self.local_projection_head(y_local) + y_local = x.permute(0, 2, 3, 1) # (B, D, W, H) to (B, W, H, D) + z_local = self.local_projection_head(y_local) return z, z_local - def training_step(self, batch, batch_index): (view_global, view_local, grid_global, grid_local), _, _ = batch z_global, z_global_local_features = self.forward(view_global) z_local, z_local_local_features = self.forward(view_local) loss = self.criterion( - z_global=z_global, - z_local=z_local, - z_global_local_features=z_global_local_features, - z_local_local_features=z_local_local_features, - grid_global=grid_global, - grid_local=grid_local + z_global=z_global, + z_local=z_local, + z_global_local_features=z_global_local_features, + z_local_local_features=z_local_local_features, + grid_global=grid_global, + grid_local=grid_local, ) return loss diff --git a/examples/pytorch_lightning_distributed/barlowtwins.py b/examples/pytorch_lightning_distributed/barlowtwins.py index cc6c0cfd7..9f7b5e152 100644 --- a/examples/pytorch_lightning_distributed/barlowtwins.py +++ b/examples/pytorch_lightning_distributed/barlowtwins.py @@ -1,17 +1,17 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
+import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import ImageCollateFunction +from lightly.data import ImageCollateFunction, LightlyDataset from lightly.loss import BarlowTwinsLoss from lightly.models.modules import BarlowTwinsProjectionHead + class BarlowTwins(pl.LightningModule): def __init__(self): super().__init__() @@ -19,7 +19,7 @@ def __init__(self): self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048) - # enable gather_distributed to gather features from all gpus + # enable gather_distributed to gather features from all gpus # before calculating the loss self.criterion = BarlowTwinsLoss(gather_distributed=True) @@ -63,9 +63,9 @@ def configure_optimizers(self): # train with DDP and use Synchronized Batch Norm for a more accurate batch norm # calculation trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", sync_batchnorm=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/byol.py b/examples/pytorch_lightning_distributed/byol.py index ec099862d..d4b39f5ff 100644 --- a/examples/pytorch_lightning_distributed/byol.py +++ b/examples/pytorch_lightning_distributed/byol.py @@ -2,20 +2,19 @@ # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. -import torch -from torch import nn -import torchvision import copy + import pytorch_lightning as pl +import torch +import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NegativeCosineSimilarity from lightly.models.modules import BYOLProjectionHead from lightly.models.modules.heads import BYOLPredictionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum class BYOL(pl.LightningModule): diff --git a/examples/pytorch_lightning_distributed/dcl.py b/examples/pytorch_lightning_distributed/dcl.py index 814380628..767993d16 100644 --- a/examples/pytorch_lightning_distributed/dcl.py +++ b/examples/pytorch_lightning_distributed/dcl.py @@ -1,17 +1,17 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
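The distributed examples pass `gather_distributed=True` to the loss so it is computed over the features from all GPUs rather than each process's local shard. Conceptually, this is an all-gather before the loss, sketched below; real implementations use an autograd-aware gather so gradients flow back to every process, which this illustration omits:

```python
import torch
import torch.distributed as dist

def gather_features_sketch(z: torch.Tensor) -> torch.Tensor:
    # collect the per-process feature batches into one global batch
    if not (dist.is_available() and dist.is_initialized()):
        return z  # single-process fallback
    gathered = [torch.zeros_like(z) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, z)
    return torch.cat(gathered, dim=0)
```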
+import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import DCLLoss, DCLWLoss from lightly.models.modules import SimCLRProjectionHead + class DCL(pl.LightningModule): def __init__(self): super().__init__() @@ -19,11 +19,11 @@ def __init__(self): self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = SimCLRProjectionHead(512, 2048, 2048) - # enable gather_distributed to gather features from all gpus + # enable gather_distributed to gather features from all gpus # before calculating the loss self.criterion = DCLLoss(gather_distributed=True) # or use the weighted DCLW loss: - # self.criterion = DCLWLoss(gather_distributed=True) + # self.criterion = DCLWLoss(gather_distributed=True) def forward(self, x): x = self.backbone(x).flatten(start_dim=1) @@ -51,7 +51,7 @@ def configure_optimizers(self): collate_fn = SimCLRCollateFunction( input_size=32, - gaussian_blur=0., + gaussian_blur=0.0, ) dataloader = torch.utils.data.DataLoader( @@ -68,9 +68,9 @@ def configure_optimizers(self): # train with DDP and use Synchronized Batch Norm for a more accurate batch norm # calculation trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", sync_batchnorm=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/dino.py b/examples/pytorch_lightning_distributed/dino.py index c7a1f87a1..2a4ee37dd 100644 --- a/examples/pytorch_lightning_distributed/dino.py +++ b/examples/pytorch_lightning_distributed/dino.py @@ -1,21 +1,19 @@ import copy -# Note: The model and training settings do not follow the reference settings -# from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. - +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import DINOCollateFunction +from lightly.data import DINOCollateFunction, LightlyDataset from lightly.loss import DINOLoss from lightly.models.modules import DINOProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum + +# Note: The model and training settings do not follow the reference settings +# from the paper. The settings are chosen such that the example can easily be +# run on a small dataset with a single GPU. class DINO(pl.LightningModule): diff --git a/examples/pytorch_lightning_distributed/mae.py b/examples/pytorch_lightning_distributed/mae.py index 680ab7f3e..e52fd0230 100644 --- a/examples/pytorch_lightning_distributed/mae.py +++ b/examples/pytorch_lightning_distributed/mae.py @@ -1,11 +1,11 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
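The BYOL and DINO hunks above import update_momentum, deactivate_requires_grad, and cosine_schedule to maintain a momentum (EMA) copy of the student network. The core of the pattern fits in a few lines; this is an illustration under the usual BYOL schedule (0.996 ramped to 1.0), not the lightly helpers themselves:

```python
import copy
import math

import torch
from torch import nn

student = nn.Linear(4, 4)
teacher = copy.deepcopy(student)
for p in teacher.parameters():
    p.requires_grad = False  # the teacher is only ever updated via EMA


@torch.no_grad()
def ema_update(student: nn.Module, teacher: nn.Module, m: float) -> None:
    # Exponential moving average of the student weights.
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(m).add_((1.0 - m) * p_s)


# Cosine ramp of the momentum over training, as popularized by BYOL.
step, max_steps = 100, 1000
m = 1.0 - (1.0 - 0.996) * (math.cos(math.pi * step / max_steps) + 1) / 2
ema_update(student, teacher, m)
```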
+import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import MAECollateFunction @@ -31,7 +31,7 @@ def __init__(self): embed_input_dim=vit.hidden_dim, hidden_dim=decoder_dim, mlp_dim=decoder_dim * 4, - out_dim=vit.patch_size ** 2 * 3, + out_dim=vit.patch_size**2 * 3, dropout=0, attention_dropout=0, ) @@ -44,7 +44,9 @@ def forward_decoder(self, x_encoded, idx_keep, idx_mask): # build decoder input batch_size = x_encoded.shape[0] x_decode = self.decoder.embed(x_encoded) - x_masked = utils.repeat_token(self.mask_token, (batch_size, self.sequence_length)) + x_masked = utils.repeat_token( + self.mask_token, (batch_size, self.sequence_length) + ) x_masked = utils.set_at_index(x_masked, idx_keep, x_decode) # decoder forward pass @@ -57,7 +59,7 @@ def forward_decoder(self, x_encoded, idx_keep, idx_mask): def training_step(self, batch, batch_idx): images, _, _ = batch - + batch_size = images.shape[0] idx_keep, idx_mask = utils.random_token_mask( size=(batch_size, self.sequence_length), @@ -71,7 +73,7 @@ def training_step(self, batch, batch_idx): patches = utils.patchify(images, self.patch_size) # must adjust idx_mask for missing class token target = utils.get_at_index(patches, idx_mask - 1) - + loss = self.criterion(x_pred, target) return loss @@ -103,12 +105,12 @@ def configure_optimizers(self): gpus = torch.cuda.device_count() -# Train with DDP on multiple gpus. Distributed sampling is also enabled with +# Train with DDP on multiple gpus. Distributed sampling is also enabled with # replace_sampler_ddp=True. trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", replace_sampler_ddp=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/moco.py b/examples/pytorch_lightning_distributed/moco.py index 4055b0529..2d13071ea 100644 --- a/examples/pytorch_lightning_distributed/moco.py +++ b/examples/pytorch_lightning_distributed/moco.py @@ -2,19 +2,18 @@ # from the paper. The settings are chosen such that the example can easily be # run on a small dataset with a single GPU. -import torch -from torch import nn -import torchvision import copy + import pytorch_lightning as pl +import torch +import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import MoCoCollateFunction +from lightly.data import LightlyDataset, MoCoCollateFunction from lightly.loss import NTXentLoss from lightly.models.modules import MoCoProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum class MoCo(pl.LightningModule): diff --git a/examples/pytorch_lightning_distributed/msn.py b/examples/pytorch_lightning_distributed/msn.py index 325960d91..33a2d2164 100644 --- a/examples/pytorch_lightning_distributed/msn.py +++ b/examples/pytorch_lightning_distributed/msn.py @@ -1,18 +1,18 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
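The `idx_mask - 1` in the MAE example above hides an off-by-one that trips people up: token indices count the prepended class token at position 0, while the patchified pixel targets do not, so indices must be shifted before gathering targets. A self-contained sketch with plain torch.gather; the shapes are assumptions, and lightly's utils.random_token_mask / utils.get_at_index helpers are not reproduced here:

```python
import torch

batch_size, n_patches, patch_dim = 2, 49, 32 * 32 * 3
n_masked = int(0.75 * n_patches)

# Sample masked token positions in [1, n_patches], sparing the class token at 0.
perm = torch.randperm(n_patches) + 1
idx_mask = perm[:n_masked].unsqueeze(0).expand(batch_size, -1)  # (B, n_masked)

patches = torch.randn(batch_size, n_patches, patch_dim)  # pixel targets
# The patch tensor carries no class token, hence the -1 shift before gathering:
index = (idx_mask - 1).unsqueeze(-1).expand(-1, -1, patch_dim)
target = torch.gather(patches, dim=1, index=index)  # (B, n_masked, patch_dim)
```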
import copy +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import MSNCollateFunction from lightly.loss import MSNLoss -from lightly.models.modules.heads import MSNProjectionHead from lightly.models import utils +from lightly.models.modules.heads import MSNProjectionHead from lightly.models.modules.masked_autoencoder import MAEBackbone @@ -20,7 +20,7 @@ class MSN(pl.LightningModule): def __init__(self): super().__init__() - # ViT small configuration (ViT-S/16) + # ViT small configuration (ViT-S/16) self.mask_ratio = 0.15 self.backbone = MAEBackbone( image_size=224, @@ -30,7 +30,7 @@ def __init__(self): hidden_dim=384, mlp_dim=384 * 4, ) - # or use a torchvision ViT backbone: + # or use a torchvision ViT backbone: # vit = torchvision.models.vit_b_32(pretrained=False) # self.backbone = MAEBackbone.from_vit(vit) self.projection_head = MSNProjectionHead(384) @@ -43,7 +43,7 @@ def __init__(self): self.prototypes = nn.Linear(256, 1024, bias=False).weight - # set gather_distributed to True for distributed training + # set gather_distributed to True for distributed training self.criterion = MSNLoss(gather_distributed=True) def training_step(self, batch, batch_idx): @@ -109,12 +109,12 @@ def configure_optimizers(self): gpus = torch.cuda.device_count() -# Train with DDP on multiple gpus. Distributed sampling is also enabled with +# Train with DDP on multiple gpus. Distributed sampling is also enabled with # replace_sampler_ddp=True. trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", replace_sampler_ddp=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/nnclr.py b/examples/pytorch_lightning_distributed/nnclr.py index 109c4255e..9d74f30ad 100644 --- a/examples/pytorch_lightning_distributed/nnclr.py +++ b/examples/pytorch_lightning_distributed/nnclr.py @@ -1,18 +1,19 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
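The MSN hunk above defines its prototypes as `nn.Linear(256, 1024, bias=False).weight`, i.e. a learnable (1024, 256) matrix holding 1024 prototype vectors, registered as a parameter without the layer ever being called. A sketch of how such a prototype bank is typically used; the 0.1 temperature is an assumption for illustration, not MSNLoss's internals:

```python
import torch
import torch.nn.functional as F
from torch import nn

prototypes = nn.Linear(256, 1024, bias=False).weight  # (1024, 256), learnable

z = F.normalize(torch.randn(8, 256), dim=1)  # projected embeddings
protos = F.normalize(prototypes, dim=1)      # unit-norm prototypes
logits = z @ protos.t() / 0.1                # cosine similarity / temperature
assignments = F.softmax(logits, dim=1)       # soft prototype assignments
print(assignments.shape)                     # torch.Size([8, 1024])
```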
+import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NTXentLoss -from lightly.models.modules import NNCLRProjectionHead -from lightly.models.modules import NNCLRPredictionHead -from lightly.models.modules import NNMemoryBankModule +from lightly.models.modules import ( + NNCLRPredictionHead, + NNCLRProjectionHead, + NNMemoryBankModule, +) class NNCLR(pl.LightningModule): @@ -70,9 +71,9 @@ def configure_optimizers(self): # train with DDP and use Synchronized Batch Norm for a more accurate batch norm # calculation trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", sync_batchnorm=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/simclr.py b/examples/pytorch_lightning_distributed/simclr.py index 52a86f0d0..26d381bd6 100644 --- a/examples/pytorch_lightning_distributed/simclr.py +++ b/examples/pytorch_lightning_distributed/simclr.py @@ -1,13 +1,13 @@ +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NTXentLoss from lightly.models.modules import SimCLRProjectionHead + class SimCLR(pl.LightningModule): def __init__(self): super().__init__() @@ -15,7 +15,7 @@ def __init__(self): self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = SimCLRProjectionHead(512, 2048, 2048) - # enable gather_distributed to gather features from all gpus + # enable gather_distributed to gather features from all gpus # before calculating the loss self.criterion = NTXentLoss(gather_distributed=True) @@ -45,7 +45,7 @@ def configure_optimizers(self): collate_fn = SimCLRCollateFunction( input_size=32, - gaussian_blur=0., + gaussian_blur=0.0, ) dataloader = torch.utils.data.DataLoader( @@ -62,9 +62,9 @@ def configure_optimizers(self): # train with DDP and use Synchronized Batch Norm for a more accurate batch norm # calculation trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", sync_batchnorm=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/simmim.py b/examples/pytorch_lightning_distributed/simmim.py index 36f1b3345..9a2ad0708 100644 --- a/examples/pytorch_lightning_distributed/simmim.py +++ b/examples/pytorch_lightning_distributed/simmim.py @@ -1,10 +1,10 @@ +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn from lightly.data import LightlyDataset -from lightly.data.collate import MAECollateFunction # Same collate as MAE +from lightly.data.collate import MAECollateFunction  # Same collate as MAE from lightly.models import utils from lightly.models.modules import masked_autoencoder @@ -12,7 +12,7 @@ class SimMIM(pl.LightningModule): def __init__(self): super().__init__() - + vit = torchvision.models.vit_b_32(pretrained=False) decoder_dim = vit.hidden_dim self.mask_ratio = 0.75 @@ -21,19 +21,18 @@ self.mask_token = nn.Parameter(torch.zeros(1,
1, decoder_dim)) # same backbone as MAE - self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) + self.backbone = masked_autoencoder.MAEBackbone.from_vit(vit) # the decoder is a simple linear layer - self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size ** 2 * 3) + self.decoder = nn.Linear(vit.hidden_dim, vit.patch_size**2 * 3) # L1 loss as paper suggestion self.criterion = nn.L1Loss() - def forward_encoder(self, images, batch_size, idx_mask): # pass all the tokens to the encoder, both masked and non masked ones tokens = self.backbone.images_to_tokens(images, prepend_class_token=True) - tokens_masked = utils.mask_at_index(tokens, idx_mask , self.mask_token) + tokens_masked = utils.mask_at_index(tokens, idx_mask, self.mask_token) return self.backbone.encoder(tokens_masked) def forward_decoder(self, x_encoded): @@ -48,7 +47,7 @@ def training_step(self, batch, batch_idx): mask_ratio=self.mask_ratio, device=images.device, ) - + # Encoding... x_encoded = self.forward_encoder(images, batch_size, idx_mask) x_encoded_masked = utils.get_at_index(x_encoded, idx_mask) @@ -58,13 +57,13 @@ def training_step(self, batch, batch_idx): # get image patches for masked tokens patches = utils.patchify(images, self.patch_size) - + # must adjust idx_mask for missing class token target = utils.get_at_index(patches, idx_mask - 1) loss = self.criterion(x_out, target) return loss - + def configure_optimizers(self): optim = torch.optim.AdamW(self.parameters(), lr=1.5e-4) return optim @@ -96,12 +95,12 @@ def configure_optimizers(self): gpus = torch.cuda.device_count() -# Train with DDP on multiple gpus. Distributed sampling is also enabled with +# Train with DDP on multiple gpus. Distributed sampling is also enabled with # replace_sampler_ddp=True. trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", replace_sampler_ddp=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/simsiam.py b/examples/pytorch_lightning_distributed/simsiam.py index b7c873da3..fa1ef8896 100644 --- a/examples/pytorch_lightning_distributed/simsiam.py +++ b/examples/pytorch_lightning_distributed/simsiam.py @@ -1,17 +1,15 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
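Unlike MAE, which drops masked tokens before the encoder, the SimMIM example above keeps the full token sequence and substitutes a learned mask token at the masked positions (the utils.mask_at_index call). The substitution itself can be sketched with broadcasting; the shapes and the convention of never masking the class token at index 0 are assumptions:

```python
import torch

tokens = torch.randn(2, 50, 768)          # (B, seq, dim), class token at index 0
mask_token = torch.zeros(1, 1, 768)       # an nn.Parameter in the real example
idx_mask = torch.randint(1, 50, (2, 36))  # masked positions, class token spared

mask = torch.zeros(2, 50, dtype=torch.bool)
mask.scatter_(1, idx_mask, True)
# Broadcast the single mask token over every masked position:
tokens_masked = torch.where(mask.unsqueeze(-1), mask_token, tokens)
```

The full (masked) sequence then goes through the encoder, which is why SimMIM pairs naturally with the lightweight linear decoder seen in the hunk.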
+import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss import NegativeCosineSimilarity -from lightly.models.modules import SimSiamProjectionHead -from lightly.models.modules import SimSiamPredictionHead +from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead class SimSiam(pl.LightningModule): @@ -67,9 +65,9 @@ def configure_optimizers(self): # train with DDP and use Synchronized Batch Norm for a more accurate batch norm # calculation trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", sync_batchnorm=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/swav.py b/examples/pytorch_lightning_distributed/swav.py index a2f28c10b..f173811b9 100644 --- a/examples/pytorch_lightning_distributed/swav.py +++ b/examples/pytorch_lightning_distributed/swav.py @@ -1,17 +1,15 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SwaVCollateFunction +from lightly.data import LightlyDataset, SwaVCollateFunction from lightly.loss import SwaVLoss -from lightly.models.modules import SwaVProjectionHead -from lightly.models.modules import SwaVPrototypes +from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes class SwaV(pl.LightningModule): @@ -73,9 +71,9 @@ def configure_optimizers(self): # train with DDP and use Synchronized Batch Norm for a more accurate batch norm # calculation trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", sync_batchnorm=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/tico.py b/examples/pytorch_lightning_distributed/tico.py index 5fb56e2e1..1f257e696 100644 --- a/examples/pytorch_lightning_distributed/tico.py +++ b/examples/pytorch_lightning_distributed/tico.py @@ -1,16 +1,15 @@ -import torch -from torch import nn -import torchvision import copy + import pytorch_lightning as pl +import torch +import torchvision +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import SimCLRCollateFunction +from lightly.data import LightlyDataset, SimCLRCollateFunction from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead +from lightly.models.utils import deactivate_requires_grad, update_momentum from lightly.utils.scheduler import cosine_schedule -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum class TiCo(pl.LightningModule): diff --git a/examples/pytorch_lightning_distributed/vicreg.py b/examples/pytorch_lightning_distributed/vicreg.py index e840668ea..f1ca6eb6b 100644 --- a/examples/pytorch_lightning_distributed/vicreg.py +++ b/examples/pytorch_lightning_distributed/vicreg.py 
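The SimSiam example above trains a projection head plus a prediction head with NegativeCosineSimilarity; the essential ingredient is the stop-gradient on the projection branch. A minimal sketch of the symmetric objective, with assumed embedding dimensions:

```python
import torch
import torch.nn.functional as F


def d(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Negative cosine similarity with stop-gradient on z, as in SimSiam.
    return -F.cosine_similarity(p, z.detach(), dim=1).mean()


z0, z1 = torch.randn(8, 128), torch.randn(8, 128)  # projection-head outputs
p0, p1 = torch.randn(8, 128), torch.randn(8, 128)  # prediction-head outputs
loss = 0.5 * (d(p0, z1) + d(p1, z0))
```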
@@ -1,19 +1,19 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. +import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn -from lightly.data import LightlyDataset -from lightly.data import VICRegCollateFunction +from lightly.data import LightlyDataset, VICRegCollateFunction from lightly.loss import VICRegLoss ## The projection head is the same as the Barlow Twins one from lightly.models.modules import BarlowTwinsProjectionHead + class VICReg(pl.LightningModule): def __init__(self): super().__init__() @@ -21,7 +21,7 @@ def __init__(self): self.backbone = nn.Sequential(*list(resnet.children())[:-1]) self.projection_head = BarlowTwinsProjectionHead(512, 2048, 2048) - # enable gather_distributed to gather features from all gpus + # enable gather_distributed to gather features from all gpus # before calculating the loss self.criterion = VICRegLoss(gather_distributed=True) @@ -65,9 +65,9 @@ def configure_optimizers(self): # train with DDP and use Synchronized Batch Norm for a more accurate batch norm # calculation trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", sync_batchnorm=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/examples/pytorch_lightning_distributed/vicregl.py b/examples/pytorch_lightning_distributed/vicregl.py index 335ca5369..5f9321d70 100644 --- a/examples/pytorch_lightning_distributed/vicregl.py +++ b/examples/pytorch_lightning_distributed/vicregl.py @@ -1,19 +1,19 @@ -# Note: The model and training settings do not follow the reference settings +# Note: The model and training settings do not follow the reference settings # from the paper. The settings are chosen such that the example can easily be -# run on a small dataset with a single GPU. +# run on a small dataset with a single GPU. 
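For readers unfamiliar with the VICRegLoss used above, its three terms are: an invariance (MSE) term between the two views, a variance hinge keeping each embedding dimension's standard deviation above 1, and a covariance penalty on off-diagonal entries. A compact sketch with the paper's default 25/25/1 coefficients; lightly's implementation also applies the covariance term to both views and can gather features across GPUs first:

```python
import torch
import torch.nn.functional as F


def vicreg_sketch(z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
    inv = F.mse_loss(z_a, z_b)  # invariance
    std_a = torch.sqrt(z_a.var(dim=0) + 1e-4)
    std_b = torch.sqrt(z_b.var(dim=0) + 1e-4)
    var = torch.relu(1 - std_a).mean() + torch.relu(1 - std_b).mean()  # variance
    zc = z_a - z_a.mean(dim=0)  # covariance (view a only, for brevity)
    n, dim = zc.shape
    cov = ((zc.t() @ zc) / (n - 1)).fill_diagonal_(0).pow(2).sum() / dim
    return 25.0 * inv + 25.0 * var + 1.0 * cov


loss = vicreg_sketch(torch.randn(64, 256), torch.randn(64, 256))
```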
+import pytorch_lightning as pl import torch -from torch import nn import torchvision -import pytorch_lightning as pl +from torch import nn from lightly.data import LightlyDataset from lightly.data.collate import VICRegLCollateFunction +from lightly.loss import VICRegLLoss + ## The global projection head is the same as the Barlow Twins one from lightly.models.modules import BarlowTwinsProjectionHead from lightly.models.modules.heads import VicRegLLocalProjectionHead -from lightly.loss import VICRegLLoss - class VICRegL(pl.LightningModule): @@ -30,21 +30,21 @@ def forward(self, x): x = self.backbone(x) y = self.average_pool(x).flatten(start_dim=1) z = self.projection_head(y) - y_local = x.permute(0, 2, 3, 1) # (B, D, W, H) to (B, W, H, D) - z_local = self.local_projection_head(y_local) + y_local = x.permute(0, 2, 3, 1) # (B, D, W, H) to (B, W, H, D) + z_local = self.local_projection_head(y_local) return z, z_local - + def training_step(self, batch, batch_index): (view_global, view_local, grid_global, grid_local), _, _ = batch z_global, z_global_local_features = self.forward(view_global) z_local, z_local_local_features = self.forward(view_local) loss = self.criterion( - z_global=z_global, - z_local=z_local, - z_global_local_features=z_global_local_features, - z_local_local_features=z_local_local_features, - grid_global=grid_global, - grid_local=grid_local + z_global=z_global, + z_local=z_local, + z_global_local_features=z_global_local_features, + z_local_local_features=z_local_local_features, + grid_global=grid_global, + grid_local=grid_local, ) return loss @@ -76,9 +76,9 @@ def configure_optimizers(self): # train with DDP and use Synchronized Batch Norm for a more accurate batch norm # calculation trainer = pl.Trainer( - max_epochs=10, + max_epochs=10, gpus=gpus, - strategy='ddp', + strategy="ddp", sync_batchnorm=True, ) trainer.fit(model=model, train_dataloaders=dataloader) diff --git a/lightly/__init__.py b/lightly/__init__.py index 20207129e..035f54c48 100644 --- a/lightly/__init__.py +++ b/lightly/__init__.py @@ -74,8 +74,8 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved -__name__ = 'lightly' -__version__ = '1.3.0' +__name__ = "lightly" +__version__ = "1.3.0" import os @@ -91,7 +91,7 @@ if __LIGHTLY_SETUP__: # setting up lightly - msg = f'Partial import of {__name__}=={__version__} during build process.' + msg = f"Partial import of {__name__}=={__version__} during build process." 
print(msg) else: # see if prefetch_generator is available @@ -108,21 +108,32 @@ def _is_prefetch_generator_available(): # see if torchvision vision transformer is available try: import torchvision.models.vision_transformer + _torchvision_vit_available = True except ( RuntimeError, # Different CUDA versions for torch and torchvision - OSError, # Different CUDA versions for torch and torchvision (old) - ImportError, # No installation or old version of torchvision + OSError, # Different CUDA versions for torch and torchvision (old) + ImportError, # No installation or old version of torchvision ): _torchvision_vit_available = False - - if os.getenv('LIGHTLY_DID_VERSION_CHECK', 'False') == 'False': - os.environ['LIGHTLY_DID_VERSION_CHECK'] = 'True' + + if os.getenv("LIGHTLY_DID_VERSION_CHECK", "False") == "False": + os.environ["LIGHTLY_DID_VERSION_CHECK"] = "True" from multiprocessing import current_process - if current_process().name == 'MainProcess': - from lightly.api.version_checking import is_latest_version, LightlyAPITimeoutException + + if current_process().name == "MainProcess": + from lightly.api.version_checking import ( + LightlyAPITimeoutException, + is_latest_version, + ) from lightly.openapi_generated.swagger_client.rest import ApiException + try: is_latest_version(current_version=__version__) - except (ValueError, ApiException, LightlyAPITimeoutException, AttributeError): + except ( + ValueError, + ApiException, + LightlyAPITimeoutException, + AttributeError, + ): pass diff --git a/lightly/active_learning/agents/agent.py b/lightly/active_learning/agents/agent.py index 6ba45ec0e..0a627ecd5 100644 --- a/lightly/active_learning/agents/agent.py +++ b/lightly/active_learning/agents/agent.py @@ -1,5 +1,5 @@ -from typing import * import warnings +from typing import * from lightly.active_learning.config.selection_config import SelectionConfig from lightly.active_learning.scorers.scorer import Scorer @@ -24,9 +24,9 @@ class ActiveLearningAgent: Set of filenames corresponding to samples which are in the query set but not in the labeled set. added_set: - Set of filenames corresponding to samples which were added to the + Set of filenames corresponding to samples which were added to the labeled set in the last query. - + Raises: RuntimeError: If executed before a query. @@ -60,11 +60,12 @@ class ActiveLearningAgent: """ - def __init__(self, - api_workflow_client: ApiWorkflowClient, - query_tag_name: str = 'initial-tag', - preselected_tag_name: str = None): - + def __init__( + self, + api_workflow_client: ApiWorkflowClient, + query_tag_name: str = "initial-tag", + preselected_tag_name: str = None, + ): self.api_workflow_client = api_workflow_client # set the query_tag_id and preselected_tag_id @@ -86,60 +87,49 @@ def __init__(self, # keep track of the last preselected tag to compute added samples self._old_preselected_tag_bitmask = None - def _get_query_tag_bitmask(self): - """Initializes the query tag bitmask. - - """ + """Initializes the query tag bitmask.""" # get query tag from api and set bitmask accordingly query_tag_data = self.api_workflow_client._tags_api.get_tag_by_tag_id( - self.api_workflow_client.dataset_id, - tag_id=self._query_tag_id + self.api_workflow_client.dataset_id, tag_id=self._query_tag_id ) query_tag_bitmask = BitMask.from_hex(query_tag_data.bit_mask_data) return query_tag_bitmask def _get_preselected_tag_bitmask(self): - """Initializes the preselected tag bitmask. 
- - """ + """Initializes the preselected tag bitmask.""" if self._preselected_tag_id is None: # if not specified, no samples belong to the preselected tag - preselected_tag_bitmask = BitMask.from_hex('0x0') + preselected_tag_bitmask = BitMask.from_hex("0x0") else: # get preselected tag from api and set bitmask accordingly preselected_tag_data = self.api_workflow_client._tags_api.get_tag_by_tag_id( - self.api_workflow_client.dataset_id, - tag_id=self._preselected_tag_id + self.api_workflow_client.dataset_id, tag_id=self._preselected_tag_id + ) + preselected_tag_bitmask = BitMask.from_hex( + preselected_tag_data.bit_mask_data ) - preselected_tag_bitmask = BitMask.from_hex(preselected_tag_data.bit_mask_data) return preselected_tag_bitmask @property def query_set(self): - """List of filenames for which to calculate active learning scores. - - """ + """List of filenames for which to calculate active learning scores.""" return self._query_tag_bitmask.masked_select_from_list( self.api_workflow_client.get_filenames() ) @property def labeled_set(self): - """List of filenames indicating selected samples. - - """ + """List of filenames indicating selected samples.""" return self._preselected_tag_bitmask.masked_select_from_list( self.api_workflow_client.get_filenames() ) @property def unlabeled_set(self): - """List of filenames which belong to the query set but are not selected. - - """ + """List of filenames which belong to the query set but are not selected.""" # unlabeled set is the query set minus the preselected set unlabeled_tag_bitmask = self._query_tag_bitmask - self._preselected_tag_bitmask return unlabeled_tag_bitmask.masked_select_from_list( @@ -156,14 +146,15 @@ def added_set(self): """ # the added set only exists after a query if self._old_preselected_tag_bitmask is None: - raise RuntimeError('Cannot compute \"added set\" before querying.') + raise RuntimeError('Cannot compute "added set" before querying.') # added set is new preselected set minus the old one - added_tag_bitmask = self._preselected_tag_bitmask - self._old_preselected_tag_bitmask + added_tag_bitmask = ( + self._preselected_tag_bitmask - self._old_preselected_tag_bitmask + ) return added_tag_bitmask.masked_select_from_list( self.api_workflow_client.get_filenames() ) - def upload_scores(self, al_scorer: Scorer): """Computes and uploads active learning scores to the Lightly webapp. @@ -177,10 +168,10 @@ def upload_scores(self, al_scorer: Scorer): if al_scores_dict == {}: raise ValueError( - 'No scores found when calling `.calculate_scores()` of the ' - 'Scorer! If you use a generator, please make sure it is freshly ' - ' initialized.' - ) + "No scores found when calling `.calculate_scores()` of the " + "Scorer! If you use a generator, please make sure it is freshly " + " initialized." + ) # Check if the length of the query_set and each of the scores are the same no_query_samples = len(self.query_set) @@ -188,15 +179,12 @@ def upload_scores(self, al_scorer: Scorer): no_query_samples_with_scores = len(score) if no_query_samples != no_query_samples_with_scores: raise ValueError( - f'Number of query samples ({no_query_samples}) must match ' - f'the number of predictions ({no_query_samples_with_scores})!' + f"Number of query samples ({no_query_samples}) must match " + f"the number of predictions ({no_query_samples_with_scores})!" 
) self.api_workflow_client.upload_scores(al_scores_dict, self._query_tag_id) - - def query(self, - selection_config: SelectionConfig, - al_scorer: Scorer = None): + def query(self, selection_config: SelectionConfig, al_scorer: Scorer = None): """Performs an active learning query. First the active learning scores are computed and uploaded, @@ -216,9 +204,9 @@ def query(self, # handle illogical stopping condition if selection_config.n_samples < len(self.labeled_set): warnings.warn( - f'ActiveLearningAgent.query: The number of samples ({selection_config.n_samples}) is ' - f'smaller than the number of preselected samples ({len(self.labeled_set)}).' - 'Skipping the active learning query.' + f"ActiveLearningAgent.query: The number of samples ({selection_config.n_samples}) is " + f"smaller than the number of preselected samples ({len(self.labeled_set)})." + "Skipping the active learning query." ) return @@ -229,7 +217,7 @@ def query(self, new_tag_data = self.api_workflow_client.selection( selection_config=selection_config, preselected_tag_id=self._preselected_tag_id, - query_tag_id=self._query_tag_id + query_tag_id=self._query_tag_id, ) # update the old preselected_tag diff --git a/lightly/active_learning/config/selection_config.py b/lightly/active_learning/config/selection_config.py index 426ff9a5f..b5f7545cc 100644 --- a/lightly/active_learning/config/selection_config.py +++ b/lightly/active_learning/config/selection_config.py @@ -1,8 +1,9 @@ import warnings from datetime import datetime -from lightly.openapi_generated.swagger_client.models.sampling_method import SamplingMethod - +from lightly.openapi_generated.swagger_client.models.sampling_method import ( + SamplingMethod, +) class SelectionConfig: @@ -34,9 +35,14 @@ class SelectionConfig: >>> config = SelectionConfig(method=SamplingMethod.CORESET, n_samples=-1, min_distance=0.1) """ - def __init__(self, method: SamplingMethod = SamplingMethod.CORESET, n_samples: int = 32, min_distance: float = -1, - name: str = None): + def __init__( + self, + method: SamplingMethod = SamplingMethod.CORESET, + n_samples: int = 32, + min_distance: float = -1, + name: str = None, + ): self.method = method self.n_samples = n_samples self.min_distance = min_distance @@ -47,12 +53,12 @@ def __init__(self, method: SamplingMethod = SamplingMethod.CORESET, n_samples: i class SamplingConfig(SelectionConfig): - def __init__(self, *args, **kwargs): - warnings.warn(PendingDeprecationWarning( - "SamplingConfig() is deprecated " - "in favour of SelectionConfig() " - "and will be removed in the future." - ), ) + warnings.warn( + PendingDeprecationWarning( + "SamplingConfig() is deprecated " + "in favour of SelectionConfig() " + "and will be removed in the future." + ), + ) SelectionConfig.__init__(self, *args, **kwargs) - diff --git a/lightly/active_learning/scorers/__init__.py b/lightly/active_learning/scorers/__init__.py index ebf9d76d9..2d28f62f5 100644 --- a/lightly/active_learning/scorers/__init__.py +++ b/lightly/active_learning/scorers/__init__.py @@ -3,7 +3,9 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved -from lightly.active_learning.scorers.scorer import Scorer from lightly.active_learning.scorers.classification import ScorerClassification from lightly.active_learning.scorers.detection import ScorerObjectDetection -from lightly.active_learning.scorers.semantic_segmentation import ScorerSemanticSegmentation +from lightly.active_learning.scorers.scorer import Scorer +from lightly.active_learning.scorers.semantic_segmentation import ( + ScorerSemanticSegmentation, +) diff --git a/lightly/active_learning/scorers/classification.py b/lightly/active_learning/scorers/classification.py index 112020c15..9923e8404 100644 --- a/lightly/active_learning/scorers/classification.py +++ b/lightly/active_learning/scorers/classification.py @@ -26,19 +26,20 @@ def _entropy(probs: np.ndarray, axis: int = 1) -> np.ndarray: entropies = -1 * np.sum(probs * log_probs, axis=axis) return entropies + def _margin_largest_secondlargest(probs: np.ndarray) -> np.ndarray: """Computes the margin of a probability matrix - Args: - probs: - A probability matrix of shape (N, M) + Args: + probs: + A probability matrix of shape (N, M) - Example: - if probs.shape = (N, C) then margins.shape = (N, ) + Example: + if probs.shape = (N, C) then margins.shape = (N, ) - Returns: - The margin of the prediction vectors - """ + Returns: + The margin of the prediction vectors + """ sorted_probs = np.partition(probs, -2, axis=1) margins = sorted_probs[:, -1] - sorted_probs[:, -2] return margins @@ -87,13 +88,14 @@ class ScorerClassification(Scorer): >>> [0.1, 0.9], # predictions for img0.jpg >>> [0.3, 0.7], # predictions for img1.jpg >>> [0.8, 0.2], # predictions for img2.jpg - >>> ] + >>> ] >>> ) >>> np.sum(predictions, axis=1) >>> > array([1., 1., 1.]) >>> scorer = ScorerClassification(predictions) """ + def __init__(self, model_output: Union[np.ndarray, List[List[float]]]): if not isinstance(model_output, np.ndarray): model_output = np.array(model_output) @@ -104,19 +106,22 @@ def ensure_valid_model_output(self, model_output: np.ndarray) -> np.ndarray: if len(model_output) == 0: return model_output if len(model_output.shape) != 2: - raise ValueError("ScorerClassification model_output must be a 2-dimensional array") + raise ValueError( + "ScorerClassification model_output must be a 2-dimensional array" + ) if model_output.shape[1] == 0: - raise ValueError("ScorerClassification model_output must not have an empty dimension 1") + raise ValueError( + "ScorerClassification model_output must not have an empty dimension 1" + ) if model_output.shape[1] == 1: # assuming a binary classification problem with # the model_output denoting the probability of the first class - model_output = np.concatenate([model_output, 1-model_output], axis=1) + model_output = np.concatenate([model_output, 1 - model_output], axis=1) return model_output @classmethod def score_names(cls) -> List[str]: - """Returns the names of the calculated active learning scores - """ + """Returns the names of the calculated active learning scores""" score_names = list(cls(model_output=[[0.5, 0.5]]).calculate_scores().keys()) return score_names @@ -138,7 +143,7 @@ def calculate_scores(self, normalize_to_0_1: bool = True) -> Dict[str, np.ndarra scores_with_names = [ self._get_scores_uncertainty_least_confidence(), self._get_scores_uncertainty_margin(), - self._get_scores_uncertainty_entropy() + self._get_scores_uncertainty_entropy(), ] scores = dict() @@ -151,17 +156,26 @@ def calculate_scores(self, normalize_to_0_1: bool = True) -> Dict[str, np.ndarra return
scores - def normalize_scores_0_1(self, scores: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + def normalize_scores_0_1( + self, scores: Dict[str, np.ndarray] + ) -> Dict[str, np.ndarray]: num_classes = self.model_output.shape[1] - model_output_very_sure = np.zeros(shape=(1,num_classes)) + model_output_very_sure = np.zeros(shape=(1, num_classes)) model_output_very_sure[0, 0] = 1 - model_output_very_unsure = np.ones_like(model_output_very_sure)/num_classes + model_output_very_unsure = np.ones_like(model_output_very_sure) / num_classes - scores_minimum = ScorerClassification(model_output_very_sure).calculate_scores(normalize_to_0_1=False) - scores_maximum = ScorerClassification(model_output_very_unsure).calculate_scores(normalize_to_0_1=False) + scores_minimum = ScorerClassification(model_output_very_sure).calculate_scores( + normalize_to_0_1=False + ) + scores_maximum = ScorerClassification( + model_output_very_unsure + ).calculate_scores(normalize_to_0_1=False) for score_name in scores.keys(): - interp_xp = [float(scores_minimum[score_name]), float(scores_maximum[score_name])] + interp_xp = [ + float(scores_minimum[score_name]), + float(scores_maximum[score_name]), + ] interp_fp = [0, 1] scores[score_name] = np.interp(scores[score_name], interp_xp, interp_fp) diff --git a/lightly/active_learning/scorers/detection.py b/lightly/active_learning/scorers/detection.py index 75fdfacd2..93c32a672 100644 --- a/lightly/active_learning/scorers/detection.py +++ b/lightly/active_learning/scorers/detection.py @@ -1,7 +1,4 @@ -from typing import Callable -from typing import Dict -from typing import List - +from typing import Callable, Dict, List import numpy as np @@ -10,9 +7,11 @@ from lightly.active_learning.utils.object_detection_output import ObjectDetectionOutput -def _object_frequency(model_output: List[ObjectDetectionOutput], - frequency_penalty: float, - min_score: float) -> np.ndarray: +def _object_frequency( + model_output: List[ObjectDetectionOutput], + frequency_penalty: float, + min_score: float, +) -> np.ndarray: """Score which prefers samples with many and diverse objects. Args: @@ -48,7 +47,9 @@ def _object_frequency(model_output: List[ObjectDetectionOutput], return np.asarray(scores) -def _objectness_least_confidence(model_output: List[ObjectDetectionOutput]) -> np.ndarray: +def _objectness_least_confidence( + model_output: List[ObjectDetectionOutput], +) -> np.ndarray: """Score which prefers samples with low max(class prob) * objectness. Args: @@ -64,19 +65,19 @@ def _objectness_least_confidence(model_output: List[ObjectDetectionOutput]) -> n if len(output.scores) > 0: # prediction margin is 1 - max(class probs), therefore the mean margin # is mean(1 - max(class probs)) which is 1 - mean(max(class probs)) - score = 1. - np.mean(output.scores) + score = 1.0 - np.mean(output.scores) else: # set the score to 0 if there was no bounding box detected - score = 0. 
+ score = 0.0 scores.append(score) return np.asarray(scores) def _reduce_classification_scores_over_boxes( - model_output: List[ObjectDetectionOutput], - reduce_fn_over_bounding_boxes: Callable[[np.ndarray], float] = np.max, - default_value_no_bounding_box: float = 0 + model_output: List[ObjectDetectionOutput], + reduce_fn_over_bounding_boxes: Callable[[np.ndarray], float] = np.max, + default_value_no_bounding_box: float = 0, ) -> Dict[str, List[float]]: """Calculates classification scores over the mean of all found objects @@ -97,7 +98,9 @@ def _reduce_classification_scores_over_boxes( scores_dict_list: List[Dict[str, np.ndarray]] = [] for index_sample, output in enumerate(model_output): probs = np.array(output.class_probabilities) - scores_dict_this_sample = ScorerClassification(model_output=probs).calculate_scores() + scores_dict_this_sample = ScorerClassification( + model_output=probs + ).calculate_scores() scores_dict_list.append(scores_dict_this_sample) score_names = ScorerClassification.score_names() @@ -194,18 +197,13 @@ class only 50%. Lowering this value results in a more balanced """ - def __init__(self, - model_output: List[ObjectDetectionOutput], - config: Dict = None): + def __init__(self, model_output: List[ObjectDetectionOutput], config: Dict = None): self.model_output = model_output self.config = config self._check_config() def _check_config(self): - default_conf = { - 'frequency_penalty': 0.25, - 'min_score': 0.9 - } + default_conf = {"frequency_penalty": 0.25, "min_score": 0.9} # Check if we have a config dictionary passed in constructor if self.config is not None and isinstance(self.config, dict): @@ -213,22 +211,22 @@ def _check_config(self): for k in self.config.keys(): if k not in default_conf.keys(): raise KeyError( - f'Scorer config parameter {k} is not a valid key. ' - f'Use one of: {default_conf.keys()}' + f"Scorer config parameter {k} is not a valid key. " + f"Use one of: {default_conf.keys()}" ) # for now all values in config should be between 0.0 and 1.0 and numbers for k, v in self.config.items(): if not (isinstance(v, float) or isinstance(v, int)): raise ValueError( - f'Scorer config values must be numbers. However, ' - f'{k} has a value of type {type(v)}.' + f"Scorer config values must be numbers. However, " + f"{k} has a value of type {type(v)}." ) if v < 0.0 or v > 1.0: raise ValueError( - f'Scorer config parameter {k} value ({v}) out of range. ' - f'Should be between 0.0 and 1.0.' + f"Scorer config parameter {k} value ({v}) out of range. " + f"Should be between 0.0 and 1.0." 
) # use default config if not specified in config @@ -239,8 +237,7 @@ def _check_config(self): @classmethod def score_names(cls) -> List[str]: - """Returns the names of the calculated active learning scores - """ + """Returns the names of the calculated active learning scores""" scorer = cls(model_output=[ObjectDetectionOutput([], [], [])]) score_names = list(scorer.calculate_scores().keys()) return score_names @@ -255,25 +252,33 @@ def calculate_scores(self) -> Dict[str, np.ndarray]: scores = dict() scores_with_names = [ self._get_object_frequency(), - self._get_objectness_least_confidence() + self._get_objectness_least_confidence(), ] for score, score_name in scores_with_names: score = np.nan_to_num(score) scores[score_name] = score # add classification scores - scores_dict_classification = \ - _reduce_classification_scores_over_boxes(model_output=self.model_output) + scores_dict_classification = _reduce_classification_scores_over_boxes( + model_output=self.model_output + ) for score_name, score in scores_dict_classification.items(): scores[score_name] = np.array(score) return scores def _get_object_frequency(self): - return _object_frequency( - self.model_output, - self.config['frequency_penalty'], - self.config['min_score']), "object_frequency" + return ( + _object_frequency( + self.model_output, + self.config["frequency_penalty"], + self.config["min_score"], + ), + "object_frequency", + ) def _get_objectness_least_confidence(self): - return _objectness_least_confidence(self.model_output), "objectness_least_confidence" + return ( + _objectness_least_confidence(self.model_output), + "objectness_least_confidence", + ) diff --git a/lightly/active_learning/scorers/keypoint_detection.py b/lightly/active_learning/scorers/keypoint_detection.py index 14dbd19db..cab737a7c 100644 --- a/lightly/active_learning/scorers/keypoint_detection.py +++ b/lightly/active_learning/scorers/keypoint_detection.py @@ -1,16 +1,14 @@ -from typing import List, Dict +from typing import Dict, List import numpy as np from lightly.active_learning.scorers import Scorer -from lightly.active_learning.utils.keypoint_predictions import \ - KeypointPrediction +from lightly.active_learning.utils.keypoint_predictions import KeypointPrediction -def _mean_uncertainty( - model_output: List[KeypointPrediction]) -> np.ndarray: +def _mean_uncertainty(model_output: List[KeypointPrediction]) -> np.ndarray: """Score which prefers samples with low confidence score. - + The uncertainty score per image is 1 minus the mean confidence score of all its keypoints. @@ -25,13 +23,15 @@ def _mean_uncertainty( scores = [] for keypoint_prediction in model_output: confidences_image = [] - for keypoint_instance_prediction in keypoint_prediction.keypoint_instance_predictions: + for ( + keypoint_instance_prediction + ) in keypoint_prediction.keypoint_instance_predictions: confidences_instance = keypoint_instance_prediction.get_confidences() if len(confidences_instance) > 0: conf = np.mean(confidences_instance) confidences_image.append(conf) if len(confidences_image) > 0: - score = 1. - np.mean(confidences_image) + score = 1.0 - np.mean(confidences_image) scores.append(score) else: scores.append(0) @@ -81,17 +81,15 @@ def __init__(self, model_output: List[KeypointPrediction]): self.model_output = model_output def calculate_scores(self) -> Dict[str, np.ndarray]: - """Calculates and returns active learning scores in a dictionary. 
- """ + """Calculates and returns active learning scores in a dictionary.""" # add classification scores scores = dict() - scores['mean_uncertainty'] = _mean_uncertainty(self.model_output) + scores["mean_uncertainty"] = _mean_uncertainty(self.model_output) return scores @classmethod def score_names(cls) -> List[str]: - """Returns the names of the calculated active learning scores - """ + """Returns the names of the calculated active learning scores""" scorer = cls(model_output=[]) score_names = list(scorer.calculate_scores().keys()) return score_names diff --git a/lightly/active_learning/scorers/semantic_segmentation.py b/lightly/active_learning/scorers/semantic_segmentation.py index 0396ddeff..3cc3d8c1c 100644 --- a/lightly/active_learning/scorers/semantic_segmentation.py +++ b/lightly/active_learning/scorers/semantic_segmentation.py @@ -4,16 +4,17 @@ # All Rights Reserved -from typing import Callable, Union, Generator, List, Dict +from typing import Callable, Dict, Generator, List, Union import numpy as np -from lightly.active_learning.scorers.scorer import Scorer from lightly.active_learning.scorers import ScorerClassification +from lightly.active_learning.scorers.scorer import Scorer -def _reduce_classification_scores_over_pixels(scores: np.ndarray, - reduce_fn_over_pixels: Callable[[np.ndarray], float] = np.mean): +def _reduce_classification_scores_over_pixels( + scores: np.ndarray, reduce_fn_over_pixels: Callable[[np.ndarray], float] = np.mean +): """Reduces classification scores to a single floating point number. Args: @@ -43,8 +44,8 @@ def _calculate_scores_for_single_prediction(prediction: np.ndarray): """ if len(prediction.shape) != 3: raise ValueError( - 'Invalid shape for semantic segmentation prediction! Expected ' - f'input of shape W x H x C but got {prediction.shape}.' + "Invalid shape for semantic segmentation prediction! Expected " + f"input of shape W x H x C but got {prediction.shape}." ) # reshape the W x H x C prediction into (W x H) x C @@ -59,8 +60,7 @@ def _calculate_scores_for_single_prediction(prediction: np.ndarray): # reduce over pixels for score_name, scores in classification_scorer.calculate_scores().items(): - scores_dict[score_name] = \ - _reduce_classification_scores_over_pixels(scores) + scores_dict[score_name] = _reduce_classification_scores_over_pixels(scores) return scores_dict @@ -73,14 +73,14 @@ class ScorerSemanticSegmentation(Scorer): Currently supports the following scores: `uncertainty scores`: - These scores are calculated by treating each pixel as its own + These scores are calculated by treating each pixel as its own classification task and taking the average of the classification uncertainty scores. Attributes: model_output: List or generator of semantic segmentation predictions. Each - prediction should be of shape W x H x C, where C is the number + prediction should be of shape W x H x C, where C is the number of classes (e.g. C=2 for two classes foreground and background). 
Examples: @@ -88,7 +88,7 @@ class ScorerSemanticSegmentation(Scorer): >>> def generator(filenames: List[string]): >>> for filename in filenames: >>> path = os.path.join(ROOT_PATH, filename) - >>> img_tensor = prepare_img(path).to('cuda') + >>> img_tensor = prepare_img(path).to('cuda') >>> with torch.no_grad(): >>> out = model(img_tensor) >>> out = torch.softmax(out, axis=1) @@ -103,8 +103,9 @@ class ScorerSemanticSegmentation(Scorer): """ - def __init__(self, - model_output: Union[List[np.ndarray], Generator[np.ndarray, None, None]]): + def __init__( + self, model_output: Union[List[np.ndarray], Generator[np.ndarray, None, None]] + ): self.model_output = model_output def calculate_scores(self) -> Dict[str, np.ndarray]: @@ -119,9 +120,8 @@ def calculate_scores(self) -> Dict[str, np.ndarray]: # iterate over list or generator of model outputs # careful! we can only iterate once if it's a generator for prediction in self.model_output: - # get all active learning scores for this prediction - # scores_ is a dictionary where each key is a score name and each + # scores_ is a dictionary where each key is a score name and each # item is a floating point number indicating the score scores_ = _calculate_scores_for_single_prediction(prediction) diff --git a/lightly/active_learning/utils/__init__.py b/lightly/active_learning/utils/__init__.py index 81ba8feb5..d91e548cb 100644 --- a/lightly/active_learning/utils/__init__.py +++ b/lightly/active_learning/utils/__init__.py @@ -4,4 +4,4 @@ # All Rights Reserved from lightly.active_learning.utils.bounding_box import BoundingBox -from lightly.active_learning.utils.object_detection_output import ObjectDetectionOutput \ No newline at end of file +from lightly.active_learning.utils.object_detection_output import ObjectDetectionOutput diff --git a/lightly/active_learning/utils/bounding_box.py b/lightly/active_learning/utils/bounding_box.py index 931ad63e5..7f22657a0 100644 --- a/lightly/active_learning/utils/bounding_box.py +++ b/lightly/active_learning/utils/bounding_box.py @@ -33,13 +33,16 @@ class BoundingBox: """ - def __init__(self, x0: float, y0: float, x1: float, y1: float, clip_values: bool = True): + def __init__( + self, x0: float, y0: float, x1: float, y1: float, clip_values: bool = True + ): """ - clip_values: - Set to true to clip the values into [0, 1] instead of raising an error if they lie outside. + clip_values: + Set to true to clip the values into [0, 1] instead of raising an error if they lie outside. """ if clip_values: + def clip_to_0_1(value): return min(1, max(0, value)) @@ -48,23 +51,22 @@ def clip_to_0_1(value): x1 = clip_to_0_1(x1) y1 = clip_to_0_1(y1) - if x0 > 1 or x1 > 1 or y0 > 1 or y1 > 1 or \ - x0 < 0 or x1 < 0 or y0 < 0 or y1 < 0: + if x0 > 1 or x1 > 1 or y0 > 1 or y1 > 1 or x0 < 0 or x1 < 0 or y0 < 0 or y1 < 0: raise ValueError( - f'Bounding Box Coordinates must be relative to ' - f'image width and height but are ({x0}, {y0}, {x1}, {y1}).' + f"Bounding Box Coordinates must be relative to " + f"image width and height but are ({x0}, {y0}, {x1}, {y1})." 
) if x0 >= x1: raise ValueError( - f'x0 must be smaller than x1 for bounding box ' - f'[{x0}, {y0}, {x1}, {y1}]' + f"x0 must be smaller than x1 for bounding box " + f"[{x0}, {y0}, {x1}, {y1}]" ) if y0 >= y1: raise ValueError( - 'y0 must be smaller than y1 for bounding box ' - f'[{x0}, {y0}, {x1}, {y1}]' + "y0 must be smaller than y1 for bounding box " + f"[{x0}, {y0}, {x1}, {y1}]" ) self.x0 = x0 @@ -91,25 +93,25 @@ def from_yolo_label(cls, x_center: float, y_center: float, w: float, h: float): >>> bbox = BoundingBox.from_yolo(0.5, 0.4, 0.2, 0.3) """ - return cls(x_center - w / 2, y_center - h / 2, x_center + w / 2, y_center + h / 2, clip_values=True) + return cls( + x_center - w / 2, + y_center - h / 2, + x_center + w / 2, + y_center + h / 2, + clip_values=True, + ) @property def width(self): - """Returns the width of the bounding box relative to the image size. - - """ + """Returns the width of the bounding box relative to the image size.""" return self.x1 - self.x0 @property def height(self): - """Returns the height of the bounding box relative to the image size. - - """ + """Returns the height of the bounding box relative to the image size.""" return self.y1 - self.y0 @property def area(self): - """Returns the area of the bounding box relative to the area of the image. - - """ + """Returns the area of the bounding box relative to the area of the image.""" return self.width * self.height diff --git a/lightly/active_learning/utils/keypoint_predictions.py b/lightly/active_learning/utils/keypoint_predictions.py index e5bd2343b..eda388f16 100644 --- a/lightly/active_learning/utils/keypoint_predictions.py +++ b/lightly/active_learning/utils/keypoint_predictions.py @@ -1,6 +1,6 @@ """ Keypoint """ import json -from typing import Union, Tuple, List, Dict +from typing import Dict, List, Tuple, Union import numpy as np @@ -26,8 +26,9 @@ class KeypointInstancePrediction: """ - def __init__(self, keypoints: List[float], category_id: int = -1, - score: float = -1.): + def __init__( + self, keypoints: List[float], category_id: int = -1, score: float = -1.0 + ): self.keypoints = keypoints self.category_id = category_id self.score = score @@ -45,24 +46,24 @@ def from_dict(cls, dict_: Dict[str, Union[int, List[float], float]]): Returns: """ - category_id = dict_['category_id'] - keypoints = dict_['keypoints'] - score = dict_['score'] + category_id = dict_["category_id"] + keypoints = dict_["keypoints"] + score = dict_["score"] return cls(keypoints=keypoints, category_id=category_id, score=score) def _format_check(self): - """Raises a ValueError if the format is not as required. - """ + """Raises a ValueError if the format is not as required.""" if not isinstance(self.category_id, int): raise ValueError( - f"Category_id must be an int, but is a {type(self.category_id)}") + f"Category_id must be an int, but is a {type(self.category_id)}" + ) if not isinstance(self.score, float): - raise ValueError( - f"Score must be a float, but is a {type(self.score)}") + raise ValueError(f"Score must be a float, but is a {type(self.score)}") if len(self.keypoints) % 3 != 0: - raise ValueError("Keypoints must be in the format of " - "[x0, y0, c0, ... xk, yk, ck].") + raise ValueError( + "Keypoints must be in the format of " "[x0, y0, c0, ... xk, yk, ck]." 
+ ) confidences = self.get_confidences() if any(c < 0 for c in confidences): raise ValueError("Confidences contain values < 0.") @@ -70,9 +71,7 @@ def _format_check(self): raise ValueError("Confidences contain values > 1.") def get_confidences(self) -> List[float]: - """Returns the confidence of each keypoint - - """ + """Returns the confidence of each keypoint""" confidences = self.keypoints[2::3] return confidences @@ -80,25 +79,19 @@ def get_confidences(self) -> List[float]: class KeypointPrediction: """Class which represents all keypoint instance detections in one image. - Attributes: - keypoint_instance_predictions: - One KeypointInstancePrediction for each instance having keypoints - detected in the image. + Attributes: + keypoint_instance_predictions: + One KeypointInstancePrediction for each instance having keypoints + detected in the image. """ - def __init__( - self, - keypoint_instance_predictions: List[KeypointInstancePrediction] - ): + def __init__(self, keypoint_instance_predictions: List[KeypointInstancePrediction]): self.keypoint_instance_predictions = keypoint_instance_predictions @classmethod - def from_dicts( - cls, - dicts: List[Dict[str, Union[int, List[float], float]]] - ): - """ Creates a KeypointPrediction from predictions for each instance. + def from_dicts(cls, dicts: List[Dict[str, Union[int, List[float], float]]]): + """Creates a KeypointPrediction from predictions for each instance. Args: dicts: @@ -114,7 +107,7 @@ def from_dicts( @classmethod def from_json_string(cls, json_string: str): - """ Creates a KeypointPrediction from predictions for each instance. + """Creates a KeypointPrediction from predictions for each instance. Args: json_string: diff --git a/lightly/active_learning/utils/object_detection_output.py b/lightly/active_learning/utils/object_detection_output.py index 349566396..c4eae0d21 100644 --- a/lightly/active_learning/utils/object_detection_output.py +++ b/lightly/active_learning/utils/object_detection_output.py @@ -23,7 +23,7 @@ class ObjectDetectionOutput: are not passed on initialisation. Scores are by default set to `max(class prob) * objectness` for each bounding box. labels: - List of labels (i.e. argmax(class prob)). Are automatically inferred from + List of labels (i.e. argmax(class prob)). Are automatically inferred from the class probabilities. 
Examples: @@ -55,18 +55,19 @@ def __init__( class_probabilities: List[List[float]], scores: Optional[List[float]] = None, ): - if len(boxes) != len(object_probabilities) or \ - len(object_probabilities) != len(class_probabilities): + if len(boxes) != len(object_probabilities) or len(object_probabilities) != len( + class_probabilities + ): raise ValueError( - 'Boxes, object and class probabilities must be of same length but are ' - f'{len(boxes)}, {len(object_probabilities)}, and ' - f'{len(class_probabilities)}' + "Boxes, object and class probabilities must be of same length but are " + f"{len(boxes)}, {len(object_probabilities)}, and " + f"{len(class_probabilities)}" ) if scores is not None and len(scores) != len(boxes): raise ValueError( - f'Boxes and scores must be of same length but are {len(boxes)} and ' - f'{len(scores)}' + f"Boxes and scores must be of same length but are {len(boxes)} and " + f"{len(scores)}" ) self.boxes = boxes @@ -76,16 +77,17 @@ def __init__( if scores is None: # calculate the score as the object probability times the maximum # of the class probabilities - self.scores = [o * max(c) for o, c in zip(object_probabilities, class_probabilities)] + self.scores = [ + o * max(c) for o, c in zip(object_probabilities, class_probabilities) + ] else: self.scores = scores - @classmethod - def from_scores(cls, - boxes: List[BoundingBox], - scores: List[float], - labels: List[int]): + @classmethod + def from_scores( + cls, boxes: List[BoundingBox], scores: List[float], labels: List[int] + ): """Helper to convert from output format with scores. We advise not using this method if you want to use the uncertainty @@ -124,13 +125,13 @@ def from_scores(cls, """ if any([score > 1 for score in scores]): - raise ValueError('Scores must be smaller than or equal to one!') + raise ValueError("Scores must be smaller than or equal to one!") if any([score < 0 for score in scores]): - raise ValueError('Scores must be larger than or equal to zero!') + raise ValueError("Scores must be larger than or equal to zero!") if not all([isinstance(label, int) for label in labels]): - raise ValueError('Labels must be list of integers.') + raise ValueError("Labels must be a list of integers.") # create fake object probabilities object_probabilities = [s for s in scores] @@ -148,4 +149,4 @@ def from_scores(cls, output = cls(boxes, object_probabilities, class_probabilities) output.scores = scores output.labels = labels - return output \ No newline at end of file + return output diff --git a/lightly/api/__init__.py b/lightly/api/__init__.py index 0e43641b7..7a3bd12a8 100644 --- a/lightly/api/__init__.py +++ b/lightly/api/__init__.py @@ -3,10 +3,9 @@ # Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved -from lightly.api.api_workflow_client import ApiWorkflowClient from lightly.api.api_workflow_artifacts import ArtifactNotExist +from lightly.api.api_workflow_client import ApiWorkflowClient from lightly.api.patch_rest_client import patch_rest_client - from lightly.openapi_generated.swagger_client.rest import RESTClientObject # Needed to handle list of arguments correctly diff --git a/lightly/api/api_workflow_artifacts.py b/lightly/api/api_workflow_artifacts.py index 504aefae7..1565d0e7f 100644 --- a/lightly/api/api_workflow_artifacts.py +++ b/lightly/api/api_workflow_artifacts.py @@ -1,8 +1,9 @@ import os + from lightly.api import download from lightly.openapi_generated.swagger_client import ( - DockerRunData, DockerRunArtifactData, + DockerRunData, ) from lightly.openapi_generated.swagger_client.models.docker_run_artifact_type import ( DockerRunArtifactType, diff --git a/lightly/api/api_workflow_client.py b/lightly/api/api_workflow_client.py index aff7534c1..26dbd45b2 100644 --- a/lightly/api/api_workflow_client.py +++ b/lightly/api/api_workflow_client.py @@ -1,28 +1,34 @@ +import os +import platform import warnings from io import IOBase from typing import * -import platform -import os import requests - -from lightly.api.api_workflow_artifacts import _ArtifactsMixin -from lightly.api.api_workflow_predictions import _PredictionsMixin -from lightly.api.api_workflow_tags import _TagsMixin from requests import Response from lightly.__init__ import __version__ +from lightly.api.api_workflow_artifacts import _ArtifactsMixin from lightly.api.api_workflow_collaboration import _CollaborationMixin from lightly.api.api_workflow_compute_worker import _ComputeWorkerMixin from lightly.api.api_workflow_datasets import _DatasetsMixin from lightly.api.api_workflow_datasources import _DatasourcesMixin from lightly.api.api_workflow_download_dataset import _DownloadDatasetMixin +from lightly.api.api_workflow_predictions import _PredictionsMixin from lightly.api.api_workflow_selection import _SelectionMixin +from lightly.api.api_workflow_tags import _TagsMixin from lightly.api.api_workflow_upload_dataset import _UploadDatasetMixin from lightly.api.api_workflow_upload_embeddings import _UploadEmbeddingsMixin from lightly.api.api_workflow_upload_metadata import _UploadCustomMetadataMixin -from lightly.api.utils import DatasourceType, get_signed_url_destination, get_api_client_configuration -from lightly.api.version_checking import is_compatible_version, LightlyAPITimeoutException +from lightly.api.utils import ( + DatasourceType, + get_api_client_configuration, + get_signed_url_destination, +) +from lightly.api.version_checking import ( + LightlyAPITimeoutException, + is_compatible_version, +) from lightly.openapi_generated.swagger_client import ( ApiClient, CollaborationApi, @@ -43,25 +49,26 @@ ScoresApi, TagsApi, ) - from lightly.utils.reordering import sort_items_by_keys # Env variable for server side encryption on S3 -LIGHTLY_S3_SSE_KMS_KEY = 'LIGHTLY_S3_SSE_KMS_KEY' - -class ApiWorkflowClient(_UploadEmbeddingsMixin, - _SelectionMixin, - _UploadDatasetMixin, - _DownloadDatasetMixin, - _DatasetsMixin, - _UploadCustomMetadataMixin, - _TagsMixin, - _DatasourcesMixin, - _ComputeWorkerMixin, - _CollaborationMixin, - _PredictionsMixin, - _ArtifactsMixin, - ): +LIGHTLY_S3_SSE_KMS_KEY = "LIGHTLY_S3_SSE_KMS_KEY" + + +class ApiWorkflowClient( + _UploadEmbeddingsMixin, + _SelectionMixin, + _UploadDatasetMixin, + _DownloadDatasetMixin, + _DatasetsMixin, + _UploadCustomMetadataMixin, + _TagsMixin, 
+ _DatasourcesMixin, + _ComputeWorkerMixin, + _CollaborationMixin, + _PredictionsMixin, + _ArtifactsMixin, +): """Provides a uniform interface to communicate with the Lightly API. The APIWorkflowClient is used to communicate with the Lightly API. The client @@ -90,13 +97,15 @@ def __init__( embedding_id: Optional[str] = None, creator: str = Creator.USER_PIP, ): - try: if not is_compatible_version(__version__): warnings.warn( - UserWarning((f"Incompatible version of lightly pip package. " - f"Please upgrade to the latest version " - f"to be able to access the api.") + UserWarning( + ( + f"Incompatible version of lightly pip package. " + f"Please upgrade to the latest version " + f"to be able to access the api." + ) ) ) except LightlyAPITimeoutException: @@ -126,32 +135,39 @@ def __init__( self._scores_api = ScoresApi(api_client=self.api_client) self._samples_api = SamplesApi(api_client=self.api_client) self._quota_api = QuotaApi(api_client=self.api_client) - self._metadata_configurations_api = \ - MetaDataConfigurationsApi(api_client=self.api_client) + self._metadata_configurations_api = MetaDataConfigurationsApi( + api_client=self.api_client + ) self._predictions_api = PredictionsApi(api_client=self.api_client) @property def dataset_id(self) -> str: - '''The current dataset_id. + """The current dataset_id. If the dataset_id is set, it is returned. If it is not set, then the dataset_id of the last modified dataset is selected. - ''' + """ try: return self._dataset_id except AttributeError: all_datasets: List[DatasetData] = self.get_datasets() - datasets_sorted = sorted(all_datasets, key=lambda dataset: dataset.last_modified_at) + datasets_sorted = sorted( + all_datasets, key=lambda dataset: dataset.last_modified_at + ) last_modified_dataset = datasets_sorted[-1] self._dataset_id = last_modified_dataset.id - warnings.warn(UserWarning(f"Dataset has not been specified, " - f"taking the last modified dataset {last_modified_dataset.name} as default dataset.")) + warnings.warn( + UserWarning( + f"Dataset has not been specified, " + f"taking the last modified dataset {last_modified_dataset.name} as default dataset." + ) + ) return self._dataset_id @dataset_id.setter def dataset_id(self, dataset_id: str): """Sets the current dataset id for the client. - + Args: dataset_id: The new dataset id. @@ -165,11 +181,9 @@ def dataset_id(self, dataset_id: str): f"platform." ) self._dataset_id = dataset_id - def _order_list_by_filenames( - self, filenames_for_list: List[str], - list_to_order: List[object] + self, filenames_for_list: List[str], list_to_order: List[object] ) -> List[object]: """Orders a list such that it is in the order of the filenames specified on the server. @@ -198,8 +212,9 @@ def get_filenames(self) -> List[str]: This is an expensive operation, especially for large datasets. """ - filenames_on_server = self._mappings_api. 
\ - get_sample_mappings_by_dataset_id(dataset_id=self.dataset_id, field="fileName") + filenames_on_server = self._mappings_api.get_sample_mappings_by_dataset_id( + dataset_id=self.dataset_id, field="fileName" + ) return filenames_on_server def upload_file_with_signed_url( @@ -230,20 +245,25 @@ def upload_file_with_signed_url( # check to see if server side encryption for S3 is desired # see https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingServerSideEncryption.html # see https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingKMSEncryption.html - lightly_s3_sse_kms_key = os.environ.get(LIGHTLY_S3_SSE_KMS_KEY, '').strip() + lightly_s3_sse_kms_key = os.environ.get(LIGHTLY_S3_SSE_KMS_KEY, "").strip() # Only set s3 related headers when we are talking with s3 - if get_signed_url_destination(signed_write_url)==DatasourceType.S3 and lightly_s3_sse_kms_key: + if ( + get_signed_url_destination(signed_write_url) == DatasourceType.S3 + and lightly_s3_sse_kms_key + ): if headers is None: headers = {} # don't override previously set SSE - if 'x-amz-server-side-encryption' not in headers: - if lightly_s3_sse_kms_key.lower() == 'true': + if "x-amz-server-side-encryption" not in headers: + if lightly_s3_sse_kms_key.lower() == "true": # enable SSE with the key of amazon - headers['x-amz-server-side-encryption'] = 'AES256' + headers["x-amz-server-side-encryption"] = "AES256" else: # enable SSE with specific customer KMS key - headers['x-amz-server-side-encryption'] = 'aws:kms' - headers['x-amz-server-side-encryption-aws-kms-key-id'] = lightly_s3_sse_kms_key + headers["x-amz-server-side-encryption"] = "aws:kms" + headers[ + "x-amz-server-side-encryption-aws-kms-key-id" + ] = lightly_s3_sse_kms_key # start requests session and make put request sess = session or requests @@ -260,23 +280,24 @@ def set_request_timeout(self, timeout: Union[int, Tuple[int, int]]): Args: timeout: Timeout in seconds. Is either a single total_timeout value or a - (connect_timeout, read_timeout) tuple. + (connect_timeout, read_timeout) tuple. See https://urllib3.readthedocs.io/en/stable/reference/urllib3.util.html?highlight=timeout#urllib3.util.Timeout for details on the different values. """ set_api_client_request_timeout(client=self.api_client, timeout=timeout) -DEFAULT_API_TIMEOUT = 60 * 3 # seconds +DEFAULT_API_TIMEOUT = 60 * 3 # seconds + def set_api_client_request_timeout( - client: ApiClient, + client: ApiClient, timeout: Union[int, Tuple[int, int]] = DEFAULT_API_TIMEOUT, ): """Sets a default timeout for all requests with the client. - This function patches the request method of the api client. This is - necessary because the swagger api client does not respect any timeouts + This function patches the request method of the api client. This is + necessary because the swagger api client does not respect any timeouts configured by urllib3. Instead it expects a timeout to be passed with every request. Code here: https://github.com/lightly-ai/lightly/blob/ffbd32fe82f76b37c8ac497640355314474bfc3b/lightly/openapi_generated/swagger_client/rest.py#L141-L148 @@ -285,7 +306,7 @@ def set_api_client_request_timeout( Api client on which the timeout is applied. timeout: Timeout in seconds. Is either a single total_timeout value or a - (connect_timeout, read_timeout) tuple. + (connect_timeout, read_timeout) tuple. See https://urllib3.readthedocs.io/en/stable/reference/urllib3.util.html?highlight=timeout#urllib3.util.Timeout for details on the different values. 
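As the docstring above notes, `set_request_timeout` accepts either a single total timeout in seconds or a `(connect_timeout, read_timeout)` tuple and forwards it to `set_api_client_request_timeout`, which patches the generated REST client so every request carries a default timeout. A minimal usage sketch (the token value is a placeholder):

    from lightly.api import ApiWorkflowClient

    client = ApiWorkflowClient(token="MY_AWESOME_TOKEN")

    # a single value sets the total timeout in seconds per request ...
    client.set_request_timeout(timeout=60)

    # ... while a tuple sets connect and read timeouts separately
    client.set_request_timeout(timeout=(5, 120))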
@@ -293,9 +314,9 @@ def set_api_client_request_timeout( request_fn = client.rest_client.request def new_request_fn(*args, **kwargs): - request_timeout = kwargs['_request_timeout'] + request_timeout = kwargs["_request_timeout"] if request_timeout is None: - kwargs['_request_timeout'] = timeout + kwargs["_request_timeout"] = timeout return request_fn(*args, **kwargs) client.rest_client.request = new_request_fn diff --git a/lightly/api/api_workflow_collaboration.py b/lightly/api/api_workflow_collaboration.py index 1f4ae5f14..2c843412d 100644 --- a/lightly/api/api_workflow_collaboration.py +++ b/lightly/api/api_workflow_collaboration.py @@ -1,12 +1,17 @@ from typing import List -from lightly.openapi_generated.swagger_client.models.shared_access_config_create_request import SharedAccessConfigCreateRequest -from lightly.openapi_generated.swagger_client.models.shared_access_config_data import SharedAccessConfigData -from lightly.openapi_generated.swagger_client.models.shared_access_type import SharedAccessType +from lightly.openapi_generated.swagger_client.models.shared_access_config_create_request import ( + SharedAccessConfigCreateRequest, +) +from lightly.openapi_generated.swagger_client.models.shared_access_config_data import ( + SharedAccessConfigData, +) +from lightly.openapi_generated.swagger_client.models.shared_access_type import ( + SharedAccessType, +) class _CollaborationMixin: - def share_dataset_only_with(self, dataset_id: str, user_emails: List[str]): """Shares dataset with a list of users @@ -17,7 +22,7 @@ def share_dataset_only_with(self, dataset_id: str, user_emails: List[str]): Args: dataset_id: Identifier of dataset - user_emails: + user_emails: List of email addresses of users to grant write permission Examples: @@ -36,38 +41,41 @@ def share_dataset_only_with(self, dataset_id: str, user_emails: List[str]): >>> client.share_dataset_only_with(dataset_id="MY_DATASET_ID", user_emails=[]) """ body = SharedAccessConfigCreateRequest( - access_type=SharedAccessType.WRITE, - users=user_emails, - creator=self._creator + access_type=SharedAccessType.WRITE, users=user_emails, creator=self._creator + ) + self._collaboration_api.create_or_update_shared_access_config_by_dataset_id( + body=body, dataset_id=dataset_id ) - self._collaboration_api.create_or_update_shared_access_config_by_dataset_id(body=body, dataset_id=dataset_id) - def get_shared_users(self, dataset_id: str) -> List[str]: - """Get list of users that have access to the dataset - - Args: - dataset_id: - Identifier of dataset - - Returns: - List of email addresses of users that have write access to the dataset + """Get list of users that have access to the dataset - Examples: - >>> client = ApiWorkflowClient(token="MY_AWESOME_TOKEN") - >>> client.get_shared_users(dataset_id="MY_DATASET_ID") - >>> ["user@something.com"] - """ + Args: + dataset_id: + Identifier of dataset + + Returns: + List of email addresses of users that have write access to the dataset + + Examples: + >>> client = ApiWorkflowClient(token="MY_AWESOME_TOKEN") + >>> client.get_shared_users(dataset_id="MY_DATASET_ID") + >>> ["user@something.com"] + """ - access_configs: List[SharedAccessConfigData] = self._collaboration_api.get_shared_access_configs_by_dataset_id(dataset_id=dataset_id) - user_emails = [] + access_configs: List[ + SharedAccessConfigData + ] = self._collaboration_api.get_shared_access_configs_by_dataset_id( + dataset_id=dataset_id + ) + user_emails = [] - # iterate through configs and find first WRITE config - # we use the same hard rule in the 
frontend to communicate with the API - # as we currently only support WRITE access - for access_config in access_configs: - if access_config.access_type == SharedAccessType.WRITE: - user_emails.extend(access_config.users) - break + # iterate through configs and find first WRITE config + # we use the same hard rule in the frontend to communicate with the API + # as we currently only support WRITE access + for access_config in access_configs: + if access_config.access_type == SharedAccessType.WRITE: + user_emails.extend(access_config.users) + break - return user_emails + return user_emails diff --git a/lightly/api/api_workflow_compute_worker.py b/lightly/api/api_workflow_compute_worker.py index 4127d3812..d8784874d 100644 --- a/lightly/api/api_workflow_compute_worker.py +++ b/lightly/api/api_workflow_compute_worker.py @@ -1,12 +1,12 @@ import copy import dataclasses -from functools import partial -import time import difflib -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, Iterator +import time +from functools import partial +from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, Union -from lightly.api.utils import retry from lightly.api import utils +from lightly.api.utils import retry from lightly.openapi_generated.swagger_client import ( ApiClient, CreateDockerWorkerRegistryEntryRequest, @@ -17,9 +17,9 @@ DockerRunScheduledState, DockerRunState, DockerWorkerConfigV2, + DockerWorkerConfigV2CreateRequest, DockerWorkerConfigV2Docker, DockerWorkerConfigV2Lightly, - DockerWorkerConfigV2CreateRequest, DockerWorkerType, SelectionConfig, SelectionConfigEntry, @@ -27,7 +27,6 @@ SelectionConfigEntryStrategy, TagData, ) - from lightly.openapi_generated.swagger_client.rest import ApiException STATE_SCHEDULED_ID_NOT_FOUND = "CANCELED_OR_NOT_EXISTING" @@ -75,8 +74,10 @@ def ended_successfully(self) -> bool: class _ComputeWorkerMixin: - def register_compute_worker(self, name: str = "Default", labels: Optional[List[str]] = None) -> str: - """Registers a new compute worker. + def register_compute_worker( + self, name: str = "Default", labels: Optional[List[str]] = None + ) -> str: + """Registers a new compute worker. If a worker with the same name already exists, the worker id of the existing worker is returned instead of registering a new worker. @@ -96,9 +97,9 @@ def register_compute_worker(self, name: str = "Default", labels: Optional[List[s if labels is None: labels = [] request = CreateDockerWorkerRegistryEntryRequest( - name=name, - worker_type=DockerWorkerType.FULL, - labels=labels, + name=name, + worker_type=DockerWorkerType.FULL, + labels=labels, creator=self._creator, ) response = self._compute_worker_api.register_docker_worker(request) @@ -172,7 +173,9 @@ def create_compute_worker_config( lightly=lightly, selection=selection, ) - request = DockerWorkerConfigV2CreateRequest(config=config, creator=self._creator) + request = DockerWorkerConfigV2CreateRequest( + config=config, creator=self._creator + ) response = self._compute_worker_api.create_docker_worker_config_v2(request) return response.id @@ -182,7 +185,7 @@ def schedule_compute_worker_run( lightly_config: Optional[Dict[str, Any]] = None, selection_config: Optional[Union[Dict[str, Any], SelectionConfig]] = None, priority: str = DockerRunScheduledPriority.MID, - runs_on: Optional[List[str]] = None + runs_on: Optional[List[str]] = None, ) -> str: """Schedules a run with the given configurations. 
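For context, the compute worker methods above compose as follows: a worker is registered once under a name and optional labels, and `schedule_compute_worker_run` queues a run that only workers with matching `runs_on` labels pick up. A minimal sketch with placeholder values; the selection dict assumes the worker `SelectionConfig` accepts an `n_samples` key:

    from lightly.api import ApiWorkflowClient

    client = ApiWorkflowClient(token="MY_AWESOME_TOKEN", dataset_id="MY_DATASET_ID")

    # registering again with the same name returns the existing worker id
    worker_id = client.register_compute_worker(name="default-worker", labels=["gpu"])

    # queue a run; only workers labeled "gpu" will process it
    scheduled_run_id = client.schedule_compute_worker_run(
        selection_config={"n_samples": 50},
        runs_on=["gpu"],
    )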
@@ -213,7 +216,10 @@ def schedule_compute_worker_run( selection_config=selection_config, ) request = DockerRunScheduledCreateRequest( - config_id=config_id, priority=priority, runs_on=runs_on, creator=self._creator, + config_id=config_id, + priority=priority, + runs_on=runs_on, + creator=self._creator, ) response = self._compute_worker_api.create_docker_run_scheduled_by_dataset_id( body=request, @@ -238,7 +244,7 @@ def get_compute_worker_runs( if dataset_id is not None: runs: List[DockerRunData] = utils.paginate_endpoint( self._compute_worker_api.get_docker_runs_query_by_dataset_id, - dataset_id=dataset_id + dataset_id=dataset_id, ) else: runs: List[DockerRunData] = utils.paginate_endpoint( @@ -254,19 +260,17 @@ def get_compute_worker_run(self, run_id: str) -> DockerRunData: ApiException: If no run with the given id exists. """ - return self._compute_worker_api.get_docker_run_by_id( - run_id=run_id - ) + return self._compute_worker_api.get_docker_run_by_id(run_id=run_id) def get_compute_worker_run_from_scheduled_run( - self, + self, scheduled_run_id: str, ) -> DockerRunData: """Returns a run given its scheduled run id. Raises: ApiException: - If no run with the given scheduled run id exists or if the scheduled + If no run with the given scheduled run id exists or if the scheduled run has not yet started being processed by a worker. """ return self._compute_worker_api.get_docker_run_by_scheduled_id( @@ -283,13 +287,14 @@ def get_scheduled_compute_worker_runs( Args: state: DockerRunScheduledState value. If specified, then only runs in the given - state are returned. If omitted, then runs which have not yet finished + state are returned. If omitted, then runs which have not yet finished (neither 'DONE' nor 'CANCELED') are returned. Valid states are 'OPEN', 'LOCKED', 'DONE', and 'CANCELED'. """ if state is not None: return self._compute_worker_api.get_docker_runs_scheduled_by_dataset_id( - dataset_id=self.dataset_id, state=state, + dataset_id=self.dataset_id, + state=state, ) return self._compute_worker_api.get_docker_runs_scheduled_by_dataset_id( dataset_id=self.dataset_id, @@ -448,6 +453,7 @@ def get_compute_worker_run_tags(self, run_id: str) -> List[TagData]: tags_in_dataset = [tag for tag in tags if tag.dataset_id == self.dataset_id] return tags_in_dataset + def selection_config_from_dict(cfg: Dict[str, Any]) -> SelectionConfig: """Recursively converts selection config from dict to a SelectionConfig instance.""" new_cfg = copy.deepcopy(cfg) @@ -462,11 +468,12 @@ def selection_config_from_dict(cfg: Dict[str, Any]) -> SelectionConfig: _T = TypeVar("_T") + def _get_deserialize( api_client: ApiClient, klass: Type[_T], ) -> Callable[[Dict[str, Any]], _T]: - """Returns the deserializer of the ApiClient class for class klass. + """Returns the deserializer of the ApiClient class for class klass. TODO(Philipp, 02/23): We should replace this by our own deserializer which accepts snake case strings as input. @@ -479,7 +486,7 @@ def _get_deserialize( def _config_to_camel_case(cfg: Dict[str, Any]) -> Dict[str, Any]: - """Converts all keys in the cfg dictionary to camelCase. """ + """Converts all keys in the cfg dictionary to camelCase.""" cfg_camel_case = {} for key, value in cfg.items(): key_camel_case = _snake_to_camel_case(key) @@ -491,11 +498,9 @@ def _config_to_camel_case(cfg: Dict[str, Any]) -> Dict[str, Any]: def _snake_to_camel_case(snake: str) -> str: - """Converts the snake_case input to camelCase. 
""" + """Converts the snake_case input to camelCase.""" components = snake.split("_") - return components[0] + "".join( - component.title() for component in components[1:] - ) + return components[0] + "".join(component.title() for component in components[1:]) def _validate_config( @@ -510,7 +515,7 @@ def _validate_config( Raises: TypeError: If obj is not of swagger type. - + """ if cfg is None: @@ -525,12 +530,11 @@ def _validate_config( if not hasattr(obj, key): possible_options = list(type(obj).swagger_types.keys()) closest_match = difflib.get_close_matches( - word=key, - possibilities=possible_options, - n=1, - cutoff=0. + word=key, possibilities=possible_options, n=1, cutoff=0.0 )[0] - error_msg = f"Option '{key}' does not exist! Did you mean '{closest_match}'?" + error_msg = ( + f"Option '{key}' does not exist! Did you mean '{closest_match}'?" + ) raise InvalidConfigurationError(error_msg) if isinstance(item, dict): _validate_config(item, getattr(obj, key)) diff --git a/lightly/api/api_workflow_datasources.py b/lightly/api/api_workflow_datasources.py index c9c46d477..6b81d17de 100644 --- a/lightly/api/api_workflow_datasources.py +++ b/lightly/api/api_workflow_datasources.py @@ -1,25 +1,25 @@ import time +import warnings from typing import Dict, List, Optional, Tuple, Union import tqdm -import warnings +from lightly.openapi_generated.swagger_client import DatasourceConfigVerifyDataErrors from lightly.openapi_generated.swagger_client.models.datasource_config import ( DatasourceConfig, ) -from lightly.openapi_generated.swagger_client.models.datasource_purpose import ( - DatasourcePurpose, -) from lightly.openapi_generated.swagger_client.models.datasource_processed_until_timestamp_request import ( DatasourceProcessedUntilTimestampRequest, ) from lightly.openapi_generated.swagger_client.models.datasource_processed_until_timestamp_response import ( DatasourceProcessedUntilTimestampResponse, ) +from lightly.openapi_generated.swagger_client.models.datasource_purpose import ( + DatasourcePurpose, +) from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data import ( DatasourceRawSamplesData, ) -from lightly.openapi_generated.swagger_client import DatasourceConfigVerifyDataErrors class _DatasourcesMixin: @@ -531,7 +531,7 @@ def set_s3_delegated_access_config( }, dataset_id=self.dataset_id, ) - + def set_obs_config( self, resource_path: str, diff --git a/lightly/api/api_workflow_download_dataset.py b/lightly/api/api_workflow_download_dataset.py index 7bd845f89..6cf697669 100644 --- a/lightly/api/api_workflow_download_dataset.py +++ b/lightly/api/api_workflow_download_dataset.py @@ -1,29 +1,26 @@ -from typing import Dict, List, Optional import io -import warnings import os -import tqdm +import warnings +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Dict, List, Optional from urllib.request import Request, urlopen -from PIL import Image - -from lightly.api.utils import paginate_endpoint, retry -from lightly.utils.hipify import bcolors -from concurrent.futures.thread import ThreadPoolExecutor +import tqdm +from PIL import Image from lightly.api import download from lightly.api.bitmask import BitMask +from lightly.api.utils import paginate_endpoint, retry from lightly.openapi_generated.swagger_client import ( DatasetEmbeddingData, - ImageType, FileNameFormat, + ImageType, ) +from lightly.utils.hipify import bcolors def _make_dir_and_save_image(output_dir: str, filename: str, img: Image): - """Saves the images and creates necessary subdirectories. 
- - """ + """Saves the images and creates necessary subdirectories.""" path = os.path.join(output_dir, filename) head = os.path.split(path)[0] @@ -35,10 +32,8 @@ def _make_dir_and_save_image(output_dir: str, filename: str, img: Image): def _get_image_from_read_url(read_url: str): - """Makes a get request to the signed read url and returns the image. - - """ - request = Request(read_url, method='GET') + """Makes a get request to the signed read url and returns the image.""" + request = Request(read_url, method="GET") with urlopen(request) as response: blob = response.read() img = Image.open(io.BytesIO(blob)) @@ -46,12 +41,13 @@ def _get_image_from_read_url(read_url: str): class _DownloadDatasetMixin: - - def download_dataset(self, - output_dir: str, - tag_name: str = 'initial-tag', - max_workers: int = 8, - verbose: bool = True): + def download_dataset( + self, + output_dir: str, + tag_name: str = "initial-tag", + max_workers: int = 8, + verbose: bool = True, + ): """Downloads images from the web-app and stores them in output_dir. Args: @@ -91,8 +87,7 @@ def download_dataset(self, # get sample ids sample_ids = self._mappings_api.get_sample_mappings_by_dataset_id( - self.dataset_id, - field='_id' + self.dataset_id, field="_id" ) indices = BitMask.from_hex(tag.bit_mask_data).to_indices() @@ -108,11 +103,11 @@ def download_dataset(self, max_workers = max(max_workers, 1) if verbose: - print(f'Downloading {bcolors.OKGREEN}{len(sample_ids)}{bcolors.ENDC} images (with {bcolors.OKGREEN}{max_workers}{bcolors.ENDC} workers):', flush=True) - pbar = tqdm.tqdm( - unit='imgs', - total=len(sample_ids) + print( + f"Downloading {bcolors.OKGREEN}{len(sample_ids)}{bcolors.ENDC} images (with {bcolors.OKGREEN}{max_workers}{bcolors.ENDC} workers):", + flush=True, ) + pbar = tqdm.tqdm(unit="imgs", total=len(sample_ids)) tqdm_lock = tqdm.tqdm.get_lock() # define lambda function for concurrent download @@ -128,10 +123,8 @@ def lambda_(i): img = _get_image_from_read_url(read_url) _make_dir_and_save_image(output_dir, filename, img) success = True - except Exception as e: # pylint: disable=broad-except - warnings.warn( - f'Downloading of image {filename} failed with error {e}' - ) + except Exception as e: # pylint: disable=broad-except + warnings.warn(f"Downloading of image {filename} failed with error {e}") success = False # update the progress bar @@ -143,12 +136,11 @@ def lambda_(i): return success with ThreadPoolExecutor(max_workers=max_workers) as executor: - results = list(executor.map( - lambda_, downloadables, chunksize=1)) + results = list(executor.map(lambda_, downloadables, chunksize=1)) if not all(results): - msg = 'Warning: Unsuccessful download! ' - msg += 'Failed at image: {}'.format(results.index(False)) + msg = "Warning: Unsuccessful download! " + msg += "Failed at image: {}".format(results.index(False)) warnings.warn(msg) def get_all_embedding_data(self) -> List[DatasetEmbeddingData]: @@ -174,16 +166,15 @@ def get_embedding_data_by_name(self, name: str) -> DatasetEmbeddingData: ) def download_embeddings_csv_by_id( - self, - embedding_id: str, + self, + embedding_id: str, output_path: str, ) -> None: """Downloads embeddings with the given embedding id from the dataset and saves them to the output path. 
""" read_url = self._embeddings_api.get_embeddings_csv_read_url_by_id( - dataset_id=self.dataset_id, - embedding_id=embedding_id + dataset_id=self.dataset_id, embedding_id=embedding_id ) download.download_and_write_file(url=read_url, output_path=output_path) @@ -204,11 +195,10 @@ def download_embeddings_csv(self, output_path: str) -> None: f"Could not find embeddings for dataset with id '{self.dataset_id}'." ) self.download_embeddings_csv_by_id( - embedding_id=last_embedding.id, + embedding_id=last_embedding.id, output_path=output_path, ) - def export_label_studio_tasks_by_tag_id( self, tag_id: str, @@ -230,7 +220,7 @@ def export_label_studio_tasks_by_tag_id( self._tags_api.export_tag_to_label_studio_tasks, page_size=20000, dataset_id=self.dataset_id, - tag_id=tag_id + tag_id=tag_id, ) return label_studio_tasks @@ -284,7 +274,7 @@ def export_label_box_data_rows_by_tag_id( self._tags_api.export_tag_to_label_box_data_rows, page_size=20000, dataset_id=self.dataset_id, - tag_id=tag_id + tag_id=tag_id, ) return label_box_data_rows @@ -317,7 +307,6 @@ def export_label_box_data_rows_by_tag_name( tag = self.get_tag_by_name(tag_name) return self.export_label_box_data_rows_by_tag_id(tag.id) - def export_filenames_by_tag_id( self, tag_id: str, @@ -365,7 +354,6 @@ def export_filenames_by_tag_name( tag = self.get_tag_by_name(tag_name) return self.export_filenames_by_tag_id(tag.id) - def export_filenames_and_read_urls_by_tag_id( self, tag_id: str, @@ -436,7 +424,7 @@ def export_filenames_and_read_urls_by_tag_name( def _get_latest_default_embedding_data( - embeddings: List[DatasetEmbeddingData] + embeddings: List[DatasetEmbeddingData], ) -> Optional[DatasetEmbeddingData]: """Returns the latest embedding data with a default name or None if no such default embedding exists. 
@@ -446,4 +434,4 @@ def _get_latest_default_embedding_data( last_embedding = sorted(default_embeddings, key=lambda e: e.created_at)[-1] return last_embedding else: - return None \ No newline at end of file + return None diff --git a/lightly/api/api_workflow_predictions.py b/lightly/api/api_workflow_predictions.py index d6738d3e9..fdbe4fd17 100644 --- a/lightly/api/api_workflow_predictions.py +++ b/lightly/api/api_workflow_predictions.py @@ -55,7 +55,9 @@ def create_or_update_prediction_task_schema( def create_or_update_predictions( self, - sample_id_to_prediction_singletons: Mapping[str, Sequence[PredictionSingletonRepr]], + sample_id_to_prediction_singletons: Mapping[ + str, Sequence[PredictionSingletonRepr] + ], prediction_version_id: int = -1, progress_bar: Optional[tqdm.tqdm] = None, max_workers: int = 8, diff --git a/lightly/api/api_workflow_selection.py b/lightly/api/api_workflow_selection.py index 6113b9a05..5813e29f3 100644 --- a/lightly/api/api_workflow_selection.py +++ b/lightly/api/api_workflow_selection.py @@ -7,18 +7,23 @@ from lightly.active_learning.config.selection_config import SelectionConfig from lightly.openapi_generated.swagger_client import ActiveLearningScoreCreateRequest from lightly.openapi_generated.swagger_client.models.job_state import JobState -from lightly.openapi_generated.swagger_client.models.job_status_data import JobStatusData +from lightly.openapi_generated.swagger_client.models.job_status_data import ( + JobStatusData, +) +from lightly.openapi_generated.swagger_client.models.sampling_config import ( + SamplingConfig, +) +from lightly.openapi_generated.swagger_client.models.sampling_config_stopping_condition import ( + SamplingConfigStoppingCondition, +) +from lightly.openapi_generated.swagger_client.models.sampling_create_request import ( + SamplingCreateRequest, +) from lightly.openapi_generated.swagger_client.models.tag_data import TagData -from lightly.openapi_generated.swagger_client.models.sampling_config import SamplingConfig -from lightly.openapi_generated.swagger_client.models.sampling_create_request import SamplingCreateRequest -from lightly.openapi_generated.swagger_client.models.sampling_config_stopping_condition import \ - SamplingConfigStoppingCondition def _parse_active_learning_scores(scores: Union[np.ndarray, List]): - """Makes list/np.array of active learning scores serializable. - - """ + """Makes list/np.array of active learning scores serializable.""" # the api only accepts float64s if isinstance(scores, np.ndarray): scores = scores.astype(np.float64) @@ -28,9 +33,7 @@ def _parse_active_learning_scores(scores: Union[np.ndarray, List]): class _SelectionMixin: - def upload_scores(self, al_scores: Dict[str, np.ndarray], query_tag_id: str = None): - tags = self.get_all_tags() # upload the active learning scores to the api @@ -39,13 +42,13 @@ def upload_scores(self, al_scores: Dict[str, np.ndarray], query_tag_id: str = No # will be the query tag (i.e. 
query_tag = initial-tag) # set the query tag to the initial-tag if necessary if query_tag_id is None: - query_tag = next(t for t in tags if t.name == 'initial-tag') + query_tag = next(t for t in tags if t.name == "initial-tag") query_tag_id = query_tag.id # iterate over all available score types and upload them for score_type, score_values in al_scores.items(): body = ActiveLearningScoreCreateRequest( score_type=score_type, - scores=_parse_active_learning_scores(score_values) + scores=_parse_active_learning_scores(score_values), ) self._scores_api.create_or_update_active_learning_score_by_tag_id( body, @@ -63,8 +66,12 @@ def sampling(self, *args, **kwargs): ) return self.selection(*args, **kwargs) - def selection(self, selection_config: SelectionConfig, preselected_tag_id: str = None, query_tag_id: str = None) \ - -> TagData: + def selection( + self, + selection_config: SelectionConfig, + preselected_tag_id: str = None, + query_tag_id: str = None, + ) -> TagData: """Performs a selection given the arguments. Args: @@ -88,9 +95,11 @@ def selection(self, selection_config: SelectionConfig, preselected_tag_id: str = # make sure the tag name does not exist yet tags = self.get_all_tags() if selection_config.name in [tag.name for tag in tags]: - raise RuntimeError(f'There already exists a tag with tag_name {selection_config.name}.') + raise RuntimeError( + f"There already exists a tag with tag_name {selection_config.name}." + ) if len(tags) == 0: - raise RuntimeError('There exists no initial-tag for this dataset.') + raise RuntimeError("There exists no initial-tag for this dataset.") # make sure we have an embedding id try: @@ -99,9 +108,13 @@ def selection(self, selection_config: SelectionConfig, preselected_tag_id: str = self.set_embedding_id_to_latest() # trigger the selection - payload = self._create_selection_create_request(selection_config, preselected_tag_id, query_tag_id) + payload = self._create_selection_create_request( + selection_config, preselected_tag_id, query_tag_id + ) payload.row_count = self.get_all_tags()[0].tot_size - response = self._selection_api.trigger_sampling_by_id(payload, self.dataset_id, self.embedding_id) + response = self._selection_api.trigger_sampling_by_id( + payload, self.dataset_id, self.embedding_id + ) job_id = response.job_id # poll the job status till the job is not running anymore @@ -109,35 +122,49 @@ def selection(self, selection_config: SelectionConfig, preselected_tag_id: str = job_status_data = None wait_time_till_next_poll = getattr(self, "wait_time_till_next_poll", 1) - while job_status_data is None \ - or job_status_data.status == JobState.RUNNING \ - or job_status_data.status == JobState.WAITING \ - or job_status_data.status == JobState.UNKNOWN: + while ( + job_status_data is None + or job_status_data.status == JobState.RUNNING + or job_status_data.status == JobState.WAITING + or job_status_data.status == JobState.UNKNOWN + ): # sleep before polling again time.sleep(wait_time_till_next_poll) # try to read the sleep time until the next poll from the status data try: - job_status_data: JobStatusData = self._jobs_api.get_job_status_by_id(job_id=job_id) + job_status_data: JobStatusData = self._jobs_api.get_job_status_by_id( + job_id=job_id + ) wait_time_till_next_poll = job_status_data.wait_time_till_next_poll except Exception as err: exception_counter += 1 if exception_counter == 20: - print(f"Selection job with job_id {job_id} could not be started because of error: {err}") + print( + f"Selection job with job_id {job_id} could not be started because 
of error: {err}" + ) raise err if job_status_data.status == JobState.FAILED: - raise RuntimeError(f"Selection job with job_id {job_id} failed with error {job_status_data.error}") + raise RuntimeError( + f"Selection job with job_id {job_id} failed with error {job_status_data.error}" + ) # get the new tag from the job status new_tag_id = job_status_data.result.data if new_tag_id is None: raise RuntimeError(f"TagId returned by job with job_id {job_id} is None.") - new_tag_data = self._tags_api.get_tag_by_tag_id(self.dataset_id, tag_id=new_tag_id) + new_tag_data = self._tags_api.get_tag_by_tag_id( + self.dataset_id, tag_id=new_tag_id + ) return new_tag_data - def _create_selection_create_request(self, selection_config: SelectionConfig, preselected_tag_id: str, query_tag_id: str - ) -> SamplingCreateRequest: + def _create_selection_create_request( + self, + selection_config: SelectionConfig, + preselected_tag_id: str, + query_tag_id: str, + ) -> SamplingCreateRequest: """Creates a SamplingCreateRequest First, it checks how many samples are already labeled by @@ -151,12 +178,14 @@ def _create_selection_create_request(self, selection_config: SelectionConfig, pr sampling_config = SamplingConfig( stopping_condition=SamplingConfigStoppingCondition( n_samples=selection_config.n_samples, - min_distance=selection_config.min_distance + min_distance=selection_config.min_distance, ) ) - sampling_create_request = SamplingCreateRequest(new_tag_name=selection_config.name, - method=selection_config.method, - config=sampling_config, - preselected_tag_id=preselected_tag_id, - query_tag_id=query_tag_id) + sampling_create_request = SamplingCreateRequest( + new_tag_name=selection_config.name, + method=selection_config.method, + config=sampling_config, + preselected_tag_id=preselected_tag_id, + query_tag_id=query_tag_id, + ) return sampling_create_request diff --git a/lightly/api/api_workflow_tags.py b/lightly/api/api_workflow_tags.py index 07ca004d9..80331ffea 100644 --- a/lightly/api/api_workflow_tags.py +++ b/lightly/api/api_workflow_tags.py @@ -2,9 +2,9 @@ from lightly.api.bitmask import BitMask from lightly.openapi_generated.swagger_client import ( - TagArithmeticsRequest, - TagArithmeticsOperation, - TagBitMaskResponse, + TagArithmeticsOperation, + TagArithmeticsRequest, + TagBitMaskResponse, TagData, ) @@ -12,8 +12,8 @@ class TagDoesNotExistError(ValueError): pass -class _TagsMixin: +class _TagsMixin: def get_all_tags(self) -> List[TagData]: """Gets all tags in the Lightly Platform for current dataset id. 
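The `selection` call shown in the hunks above is blocking: it polls the job status endpoint, sleeping `wait_time_till_next_poll` seconds between polls, and resolves the newly created tag once the job leaves the RUNNING/WAITING/UNKNOWN states. A minimal usage sketch, assuming the active-learning `SelectionConfig` from the imports (all values are illustrative):

    from lightly.active_learning.config.selection_config import SelectionConfig
    from lightly.api import ApiWorkflowClient

    client = ApiWorkflowClient(token="MY_AWESOME_TOKEN", dataset_id="MY_DATASET_ID")

    # blocks until the selection job is done, then returns the new TagData
    config = SelectionConfig(n_samples=100, name="my-selection")
    new_tag = client.selection(selection_config=config)
    print(new_tag.name, new_tag.tot_size)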
@@ -51,14 +51,14 @@ def get_tag_by_name(self, tag_name: str) -> TagData: tag_name_id_dict = {tag.name: tag.id for tag in self.get_all_tags()} tag_id = tag_name_id_dict.get(tag_name, None) if tag_id is None: - raise TagDoesNotExistError(f'Your tag_name does not exist: {tag_name}.') + raise TagDoesNotExistError(f"Your tag_name does not exist: {tag_name}.") return self.get_tag_by_id(tag_id) def get_filenames_in_tag( - self, - tag_data: TagData, - filenames_on_server: List[str] = None, - exclude_parent_tag: bool = False, + self, + tag_data: TagData, + filenames_on_server: List[str] = None, + exclude_parent_tag: bool = False, ) -> List[str]: """Gets the filenames of a tag @@ -80,12 +80,15 @@ if exclude_parent_tag: parent_tag_id = tag_data.prev_tag_id tag_arithmetics_request = TagArithmeticsRequest( - tag_id1=tag_data.id, tag_id2=parent_tag_id, - operation=TagArithmeticsOperation.DIFFERENCE) - bit_mask_response: TagBitMaskResponse = \ + tag_id1=tag_data.id, + tag_id2=parent_tag_id, + operation=TagArithmeticsOperation.DIFFERENCE, + ) + bit_mask_response: TagBitMaskResponse = ( self._tags_api.perform_tag_arithmetics_bitmask( body=tag_arithmetics_request, dataset_id=self.dataset_id ) + ) bit_mask_data = bit_mask_response.bit_mask_data else: bit_mask_data = tag_data.bit_mask_data @@ -93,16 +96,14 @@ if not filenames_on_server: filenames_on_server = self.get_filenames() - filenames_tag = BitMask.from_hex(bit_mask_data).\ - masked_select_from_list(filenames_on_server) + filenames_tag = BitMask.from_hex(bit_mask_data).masked_select_from_list( + filenames_on_server + ) return filenames_tag def create_tag_from_filenames( - self, - fnames_new_tag: List[str], - new_tag_name: str, - parent_tag_id: str = None + self, fnames_new_tag: List[str], new_tag_name: str, parent_tag_id: str = None ) -> TagData: """Creates a new tag from a list of filenames. @@ -120,17 +121,19 @@ def create_tag_from_filenames( Raises: RuntimeError """ - + # make sure the tag name does not exist yet tags = self.get_all_tags() if new_tag_name in [tag.name for tag in tags]: - raise RuntimeError(f'There already exists a tag with tag_name {new_tag_name}.') + raise RuntimeError( + f"There already exists a tag with tag_name {new_tag_name}." + ) if len(tags) == 0: - raise RuntimeError('There exists no initial-tag for this dataset.') + raise RuntimeError("There exists no initial-tag for this dataset.") # fallback to initial tag if no parent tag is provided if parent_tag_id is None: - parent_tag_id = next(tag.id for tag in tags if tag.name=='initial-tag') + parent_tag_id = next(tag.id for tag in tags if tag.name == "initial-tag") # get list of filenames from tag fnames_server = self.get_filenames() @@ -142,26 +145,26 @@ def create_tag_from_filenames( for i, fname in enumerate(fnames_server): if fname in fnames_new_tag: bitmask.set_kth_bit(i) - + # quick sanity check num_selected_samples = len(bitmask.to_indices()) if num_selected_samples != len(fnames_new_tag): raise RuntimeError( - f'An error occured when creating the new subset! ' - f'Out of the {len(fnames_new_tag)} filenames you provided ' - f'to create a new tag, only {num_selected_samples} have been ' - f'found on the server. ' - f'Make sure you use the correct filenames. ' - f'Valid filename example from the dataset: {fnames_server[0]}' - ) + f"An error occurred when creating the new subset!
" + f"Out of the {len(fnames_new_tag)} filenames you provided " + f"to create a new tag, only {num_selected_samples} have been " + f"found on the server. " + f"Make sure you use the correct filenames. " + f"Valid filename example from the dataset: {fnames_server[0]}" + ) # create new tag tag_data_dict = { - 'name': new_tag_name, - 'prevTagId': parent_tag_id, - 'bitMaskData': bitmask.to_hex(), - 'totSize': tot_size, - 'creator': self._creator, + "name": new_tag_name, + "prevTagId": parent_tag_id, + "bitMaskData": bitmask.to_hex(), + "totSize": tot_size, + "creator": self._creator, } new_tag = self._tags_api.create_tag_by_dataset_id( @@ -173,17 +176,17 @@ def create_tag_from_filenames( def delete_tag_by_id(self, tag_id: str) -> None: """Deletes a tag from the current dataset on the Lightly Platform. - + Args: tag_id: The id of the tag to be deleted. - + """ self._tags_api.delete_tag_by_tag_id(self.dataset_id, tag_id) def delete_tag_by_name(self, tag_name: str) -> None: """Deletes a tag from the current dataset on the Lightly Platform. - + Args: tag_name: The name of the tag to be deleted. diff --git a/lightly/api/api_workflow_upload_dataset.py b/lightly/api/api_workflow_upload_dataset.py index c67e22e11..fc94d3396 100644 --- a/lightly/api/api_workflow_upload_dataset.py +++ b/lightly/api/api_workflow_upload_dataset.py @@ -3,59 +3,66 @@ import os import warnings -from typing import Union, Dict -from datetime import datetime from concurrent.futures.thread import ThreadPoolExecutor +from datetime import datetime +from typing import Dict, Union import tqdm from lightly_utils import image_processing -from lightly.utils.hipify import bcolors - -from lightly.api.utils import check_filename -from lightly.api.utils import MAXIMUM_FILENAME_LENGTH -from lightly.api.utils import retry -from lightly.api.utils import build_azure_signed_url_write_headers +from lightly.api.utils import ( + MAXIMUM_FILENAME_LENGTH, + build_azure_signed_url_write_headers, + check_filename, + retry, +) from lightly.openapi_generated.swagger_client import SampleWriteUrls -from lightly.openapi_generated.swagger_client.models.sample_create_request \ - import SampleCreateRequest -from lightly.openapi_generated.swagger_client.models.sample_partial_mode \ - import SamplePartialMode - -from lightly.openapi_generated.swagger_client.models.tag_upsize_request \ - import TagUpsizeRequest -from lightly.openapi_generated.swagger_client.models.initial_tag_create_request\ - import InitialTagCreateRequest -from lightly.openapi_generated.swagger_client.models.job_status_meta \ - import JobStatusMeta -from lightly.openapi_generated.swagger_client.models.job_status_upload_method \ - import JobStatusUploadMethod - -from lightly.openapi_generated.swagger_client.models.datasource_config_base import DatasourceConfigBase +from lightly.openapi_generated.swagger_client.models.datasource_config_base import ( + DatasourceConfigBase, +) +from lightly.openapi_generated.swagger_client.models.initial_tag_create_request import ( + InitialTagCreateRequest, +) +from lightly.openapi_generated.swagger_client.models.job_status_meta import ( + JobStatusMeta, +) +from lightly.openapi_generated.swagger_client.models.job_status_upload_method import ( + JobStatusUploadMethod, +) +from lightly.openapi_generated.swagger_client.models.sample_create_request import ( + SampleCreateRequest, +) +from lightly.openapi_generated.swagger_client.models.sample_partial_mode import ( + SamplePartialMode, +) +from lightly.openapi_generated.swagger_client.models.tag_upsize_request 
import ( + TagUpsizeRequest, +) from lightly.openapi_generated.swagger_client.rest import ApiException - +from lightly.utils.hipify import bcolors try: from lightly.data import LightlyDataset + _lightly_dataset_available = True except ( - RuntimeError, # Different CUDA versions for torch and torchvision - OSError, # Different CUDA versions for torch and torchvision (old) + RuntimeError, # Different CUDA versions for torch and torchvision + OSError, # Different CUDA versions for torch and torchvision (old) ImportError, # No installation of torch or torchvision ): _lightly_dataset_available = False class _UploadDatasetMixin: - """Mixin to upload datasets to the Lightly Api. - - """ - - def upload_dataset(self, - input: Union[str, "LightlyDataset"], - max_workers: int = 8, - mode: str = "thumbnails", - custom_metadata: Union[Dict, None] = None): + """Mixin to upload datasets to the Lightly Api.""" + + def upload_dataset( + self, + input: Union[str, "LightlyDataset"], + max_workers: int = 8, + mode: str = "thumbnails", + custom_metadata: Union[Dict, None] = None, + ): """Uploads a dataset to the Lightly cloud solution. Args: @@ -82,8 +89,8 @@ def upload_dataset(self, tags = self.get_all_tags() if len(tags) > 0: print( - f'Dataset with id {self.dataset_id} has {bcolors.OKGREEN}{len(tags)}{bcolors.ENDC} tags.', - flush=True + f"Dataset with id {self.dataset_id} has {bcolors.OKGREEN}{len(tags)}{bcolors.ENDC} tags.", + flush=True, ) # parse "input" variable @@ -103,37 +110,35 @@ def upload_dataset(self, # upload the samples print( - f'Uploading {bcolors.OKGREEN}{len(dataset)}{bcolors.ENDC} images (with {bcolors.OKGREEN}{max_workers}{bcolors.ENDC} workers).', - flush=True + f"Uploading {bcolors.OKGREEN}{len(dataset)}{bcolors.ENDC} images (with {bcolors.OKGREEN}{max_workers}{bcolors.ENDC} workers).", + flush=True, ) # TODO: remove _size_in_bytes from image_processing - image_processing.metadata._size_in_bytes = \ - lambda img: 0 # pylint: disable=protected-access + image_processing.metadata._size_in_bytes = ( + lambda img: 0 + ) # pylint: disable=protected-access # get the filenames of the samples already on the server samples = retry( self._samples_api.get_samples_partial_by_dataset_id, dataset_id=self.dataset_id, - mode=SamplePartialMode.FILENAMES + mode=SamplePartialMode.FILENAMES, ) filenames_on_server = [sample.file_name for sample in samples] filenames_on_server_set = set(filenames_on_server) if len(filenames_on_server) > 0: print( - f'Found {bcolors.OKGREEN}{len(filenames_on_server)}{bcolors.ENDC} images already on the server' - ', they are skipped during the upload.' + f"Found {bcolors.OKGREEN}{len(filenames_on_server)}{bcolors.ENDC} images already on the server" + ", they are skipped during the upload."
) # check the maximum allowed dataset size - total_filenames = set(dataset.get_filenames()).union( - filenames_on_server_set - ) - max_dataset_size = \ - int(self._quota_api.get_quota_maximum_dataset_size()) + total_filenames = set(dataset.get_filenames()).union(filenames_on_server_set) + max_dataset_size = int(self._quota_api.get_quota_maximum_dataset_size()) if len(total_filenames) > max_dataset_size: - msg = f'Your dataset has {bcolors.OKGREEN}{len(dataset)}{bcolors.ENDC} samples which' - msg += f' is more than the allowed maximum of {bcolors.OKGREEN}{max_dataset_size}{bcolors.ENDC}' + msg = f"Your dataset has {bcolors.OKGREEN}{len(dataset)}{bcolors.ENDC} samples which" + msg += f" is more than the allowed maximum of {bcolors.OKGREEN}{max_dataset_size}{bcolors.ENDC}" raise ValueError(msg) # index custom metadata by filename (only if it exists) @@ -147,9 +152,9 @@ def upload_dataset(self, # get the datasource try: datasource_config: DatasourceConfigBase = self.get_datasource() - datasource_type = datasource_config['type'] + datasource_type = datasource_config["type"] except ApiException: - datasource_type = 'LIGHTLY' # default to lightly datasource + datasource_type = "LIGHTLY" # default to lightly datasource # register dataset upload job_status_meta = JobStatusMeta( n_samples=len(total_filenames), is_registered=True, upload_method=JobStatusUploadMethod.USER_PIP, ) self._datasets_api.register_dataset_upload_by_id( - job_status_meta, - self.dataset_id + job_status_meta, self.dataset_id ) pbar = tqdm.tqdm( - unit='imgs', + unit="imgs", total=len(total_filenames) - len(filenames_on_server), ) tqdm_lock = tqdm.tqdm.get_lock() @@ -193,10 +197,8 @@ def lambda_(i): datasource_type=datasource_type, ) success = True - except Exception as e: # pylint: disable=broad-except - warnings.warn( - f'Upload of image {filename} failed with error {e}' - ) + except Exception as e: # pylint: disable=broad-except + warnings.warn(f"Upload of image {filename} failed with error {e}") success = False # update the progress bar @@ -207,22 +209,23 @@ def lambda_(i): return success with ThreadPoolExecutor(max_workers=max_workers) as executor: - results = list(executor.map( - lambda_, [i for i in range(len(dataset))], chunksize=1)) + results = list( + executor.map(lambda_, [i for i in range(len(dataset))], chunksize=1) + ) if not all(results): - msg = 'Warning: Unsuccessful upload(s)! ' - msg += 'This could cause problems when uploading embeddings.' - msg += 'Failed at image: {}'.format(results.index(False)) + msg = "Warning: Unsuccessful upload(s)! " + msg += "This could cause problems when uploading embeddings. "
+ msg += "Failed at image: {}".format(results.index(False)) warnings.warn(msg) # set image type of data and create initial tag - if mode == 'full': - img_type = 'full' - elif mode == 'thumbnails': - img_type = 'thumbnail' + if mode == "full": + img_type = "full" + elif mode == "thumbnails": + img_type = "thumbnail" else: - img_type = 'meta' + img_type = "meta" if len(tags) == 0: # create initial tag @@ -237,7 +240,7 @@ def lambda_(i): else: # upsize existing tags upsize_tags_request = TagUpsizeRequest( - upsize_tag_name=datetime.now().strftime('%Y%m%d_%Hh%Mm%Ss'), + upsize_tag_name=datetime.now().strftime("%Y%m%d_%Hh%Mm%Ss"), upsize_tag_creator=self._creator, ) self._tags_api.upsize_tags_by_dataset_id( @@ -245,36 +248,38 @@ def lambda_(i): dataset_id=self.dataset_id, ) - def _upload_single_image(self, - image, - filename: str, - filepath: str, - mode: str, - custom_metadata: Union[Dict, None] = None, - datasource_type: str = 'LIGHTLY'): - """Uploads a single image to the Lightly platform. - - """ + def _upload_single_image( + self, + image, + filename: str, + filepath: str, + mode: str, + custom_metadata: Union[Dict, None] = None, + datasource_type: str = "LIGHTLY", + ): + """Uploads a single image to the Lightly platform.""" # check whether the filepath is too long if not check_filename(filepath): - msg = ('Filepath {filepath} is longer than the allowed maximum of ' - f'{MAXIMUM_FILENAME_LENGTH} characters and will be skipped.') + msg = ( + "Filepath {filepath} is longer than the allowed maximum of " + f"{MAXIMUM_FILENAME_LENGTH} characters and will be skipped." + ) raise ValueError(msg) # calculate metadata, and check if corrupted metadata = image_processing.Metadata(image).to_dict() - metadata['sizeInBytes'] = os.path.getsize(filepath) + metadata["sizeInBytes"] = os.path.getsize(filepath) # try to get exif data try: exifdata = image_processing.Exifdata(image) - except Exception: # pylint disable=broad-except + except Exception: # pylint disable=broad-except exifdata = None # generate thumbnail if necessary thumbname = None - if not metadata['is_corrupted'] and mode in ['thumbnails', 'full']: - thumbname = '.'.join(filename.split('.')[:-1]) + '_thumb.webp' + if not metadata["is_corrupted"] and mode in ["thumbnails", "full"]: + thumbname = ".".join(filename.split(".")[:-1]) + "_thumb.webp" body = SampleCreateRequest( file_name=filename, @@ -286,21 +291,19 @@ def _upload_single_image(self, sample_id = retry( self._samples_api.create_sample_by_dataset_id, body=body, - dataset_id=self.dataset_id + dataset_id=self.dataset_id, ).id - if not metadata['is_corrupted'] and mode in ['thumbnails', 'full']: + if not metadata["is_corrupted"] and mode in ["thumbnails", "full"]: def upload_thumbnail(image, signed_url): thumbnail = image_processing.Thumbnail(image) image_to_upload = thumbnail.to_bytes() headers = None - if datasource_type == 'AZURE': + if datasource_type == "AZURE": # build headers for Azure blob storage size_in_bytes = str(image_to_upload.getbuffer().nbytes) - headers = build_azure_signed_url_write_headers( - size_in_bytes - ) + headers = build_azure_signed_url_write_headers(size_in_bytes) retry( self.upload_file_with_signed_url, image_to_upload, @@ -310,36 +313,34 @@ def upload_thumbnail(image, signed_url): thumbnail.thumbnail.close() def upload_full_image(filepath, signed_url): - with open(filepath, 'rb') as image_to_upload: + with open(filepath, "rb") as image_to_upload: headers = None - if datasource_type == 'AZURE': + if datasource_type == "AZURE": # build headers for Azure blob 
storage image_to_upload.seek(0, 2) size_in_bytes = str(image_to_upload.tell()) image_to_upload.seek(0, 0) - headers = build_azure_signed_url_write_headers( - size_in_bytes - ) + headers = build_azure_signed_url_write_headers(size_in_bytes) retry( self.upload_file_with_signed_url, image_to_upload, signed_url, - headers=headers + headers=headers, ) - if mode == 'thumbnails': + if mode == "thumbnails": thumbnail_url = retry( self._samples_api.get_sample_image_write_url_by_id, dataset_id=self.dataset_id, sample_id=sample_id, - is_thumbnail=True + is_thumbnail=True, ) upload_thumbnail(image, thumbnail_url) - elif mode == 'full': + elif mode == "full": sample_write_urls: SampleWriteUrls = retry( self._samples_api.get_sample_image_write_urls_by_id, dataset_id=self.dataset_id, - sample_id=sample_id + sample_id=sample_id, ) upload_thumbnail(image, sample_write_urls.thumb) upload_full_image(filepath, sample_write_urls.full) diff --git a/lightly/api/api_workflow_upload_embeddings.py b/lightly/api/api_workflow_upload_embeddings.py index d01623a9c..be5046d45 100644 --- a/lightly/api/api_workflow_upload_embeddings.py +++ b/lightly/api/api_workflow_upload_embeddings.py @@ -1,19 +1,23 @@ -import io import csv -import tempfile import hashlib +import io +import tempfile from datetime import datetime from typing import List from urllib.request import Request, urlopen -from lightly.api.utils import retry -from lightly.openapi_generated.swagger_client import \ - DimensionalityReductionMethod, Trigger2dEmbeddingJobRequest -from lightly.openapi_generated.swagger_client.models.dataset_embedding_data \ - import DatasetEmbeddingData -from lightly.openapi_generated.swagger_client.models.write_csv_url_data \ - import WriteCSVUrlData -from lightly.utils.io import check_filenames, check_embeddings +from lightly.api.utils import retry +from lightly.openapi_generated.swagger_client import ( + DimensionalityReductionMethod, + Trigger2dEmbeddingJobRequest, +) +from lightly.openapi_generated.swagger_client.models.dataset_embedding_data import ( + DatasetEmbeddingData, +) +from lightly.openapi_generated.swagger_client.models.write_csv_url_data import ( + WriteCSVUrlData, +) +from lightly.utils.io import check_embeddings, check_filenames class EmbeddingDoesNotExistError(ValueError): @@ -21,33 +25,31 @@ class EmbeddingDoesNotExistError(ValueError): class _UploadEmbeddingsMixin: - def _get_csv_reader_from_read_url(self, read_url: str): - """Makes a get request to the signed read url and returns the .csv file. - - """ - request = Request(read_url, method='GET') + """Makes a get request to the signed read url and returns the .csv file.""" + request = Request(read_url, method="GET") with urlopen(request) as response: - buffer = io.StringIO(response.read().decode('utf-8')) + buffer = io.StringIO(response.read().decode("utf-8")) reader = csv.reader(buffer) return reader def set_embedding_id_to_latest(self): - """Sets the self.embedding_id to the one of the latest on the server. 
- - """ - embeddings_on_server: List[DatasetEmbeddingData] = \ - self._embeddings_api.get_embeddings_by_dataset_id( - dataset_id=self.dataset_id - ) + """Sets the self.embedding_id to the one of the latest on the server.""" + embeddings_on_server: List[ + DatasetEmbeddingData + ] = self._embeddings_api.get_embeddings_by_dataset_id( + dataset_id=self.dataset_id + ) if len(embeddings_on_server) == 0: - raise RuntimeError(f"There are no known embeddings for dataset_id {self.dataset_id}.") + raise RuntimeError( + f"There are no known embeddings for dataset_id {self.dataset_id}." + ) # return first entry as the API returns newest first self.embedding_id = embeddings_on_server[0].id def get_embedding_by_name( - self, name: str, ignore_suffix: bool = True + self, name: str, ignore_suffix: bool = True ) -> DatasetEmbeddingData: """Gets an embedding form the server by name. @@ -67,20 +69,23 @@ def get_embedding_by_name( on the server. """ - embeddings_on_server: List[DatasetEmbeddingData] = \ - self._embeddings_api.get_embeddings_by_dataset_id( - dataset_id=self.dataset_id - ) + embeddings_on_server: List[ + DatasetEmbeddingData + ] = self._embeddings_api.get_embeddings_by_dataset_id( + dataset_id=self.dataset_id + ) try: if ignore_suffix: embedding = next( - embedding for embedding in embeddings_on_server if - embedding.name.startswith(name) + embedding + for embedding in embeddings_on_server + if embedding.name.startswith(name) ) else: embedding = next( - embedding for embedding in embeddings_on_server if - embedding.name == name + embedding + for embedding in embeddings_on_server + if embedding.name == name ) except StopIteration: raise EmbeddingDoesNotExistError( @@ -110,20 +115,24 @@ def upload_embeddings(self, path_to_embeddings_csv: str, name: str): try: embedding = self.get_embedding_by_name(name, ignore_suffix=True) # -> append rows from server - print('Appending embeddings from server.') + print("Appending embeddings from server.") self.append_embeddings(path_to_embeddings_csv, embedding.id) - now = datetime.now().strftime('%Y%m%d_%Hh%Mm%Ss') - name = f'{name}_{now}' + now = datetime.now().strftime("%Y%m%d_%Hh%Mm%Ss") + name = f"{name}_{now}" except EmbeddingDoesNotExistError: pass # create a new csv with the filenames in the desired order rows_csv = self._order_csv_by_filenames( - path_to_embeddings_csv=path_to_embeddings_csv) + path_to_embeddings_csv=path_to_embeddings_csv + ) # get the URL to upload the csv to - response: WriteCSVUrlData = \ - self._embeddings_api.get_embeddings_csv_write_url_by_id(self.dataset_id, name=name) + response: WriteCSVUrlData = ( + self._embeddings_api.get_embeddings_csv_write_url_by_id( + self.dataset_id, name=name + ) + ) self.embedding_id = response.embedding_id signed_write_url = response.signed_write_url @@ -133,37 +142,32 @@ def upload_embeddings(self, path_to_embeddings_csv: str, name: str): writer = csv.writer(f) writer.writerows(rows_csv) f.seek(0) - embeddings_csv_as_bytes = f.read().encode('utf-8') + embeddings_csv_as_bytes = f.read().encode("utf-8") # write the bytes to a temporary in-memory byte file - with tempfile.SpooledTemporaryFile(mode='r+b') as f_bytes: + with tempfile.SpooledTemporaryFile(mode="r+b") as f_bytes: f_bytes.write(embeddings_csv_as_bytes) f_bytes.seek(0) retry( self.upload_file_with_signed_url, file=f_bytes, - signed_write_url=signed_write_url + signed_write_url=signed_write_url, ) # trigger the 2d embeddings job for dimensionality_reduction_method in [ DimensionalityReductionMethod.PCA, DimensionalityReductionMethod.TSNE, - 
DimensionalityReductionMethod.UMAP + DimensionalityReductionMethod.UMAP, ]: - body = Trigger2dEmbeddingJobRequest( - dimensionality_reduction_method=dimensionality_reduction_method) + dimensionality_reduction_method=dimensionality_reduction_method + ) self._embeddings_api.trigger2d_embeddings_job( - body=body, - dataset_id=self.dataset_id, - embedding_id=self.embedding_id + body=body, dataset_id=self.dataset_id, embedding_id=self.embedding_id ) - - def append_embeddings(self, - path_to_embeddings_csv: str, - embedding_id: str): + def append_embeddings(self, path_to_embeddings_csv: str, embedding_id: str): """Concatenates the embeddings from the server to the local ones. Loads the embedding csv file belonging to the embedding_id, and @@ -180,43 +184,43 @@ def append_embeddings(self, RuntimeError: If the number of columns in the local and the remote embeddings file mismatch. - + """ # read embedding from API - embedding_read_url = self._embeddings_api \ - .get_embeddings_csv_read_url_by_id(self.dataset_id, embedding_id) + embedding_read_url = self._embeddings_api.get_embeddings_csv_read_url_by_id( + self.dataset_id, embedding_id + ) embedding_reader = self._get_csv_reader_from_read_url(embedding_read_url) rows = list(embedding_reader) header, online_rows = rows[0], rows[1:] # read local embedding - with open(path_to_embeddings_csv, 'r') as f: + with open(path_to_embeddings_csv, "r") as f: local_rows = list(csv.reader(f)) if len(local_rows[0]) != len(header): raise RuntimeError( - 'Column mismatch! Number of columns in local and remote' - f' embeddings files must match but are {len(local_rows[0])}' - f' and {len(header)} respectively.' + "Column mismatch! Number of columns in local and remote" + f" embeddings files must match but are {len(local_rows[0])}" + f" and {len(header)} respectively." ) local_rows = local_rows[1:] # combine online and local embeddings total_rows = [header] - filename_to_local_row = { row[0]: row for row in local_rows } + filename_to_local_row = {row[0]: row for row in local_rows} for row in online_rows: # pick local over online filename if it exists total_rows.append(filename_to_local_row.pop(row[0], row)) # add all local rows which were not added yet total_rows.extend(list(filename_to_local_row.values())) - + # save embeddings again - with open(path_to_embeddings_csv, 'w') as f: + with open(path_to_embeddings_csv, "w") as f: writer = csv.writer(f) writer.writerows(total_rows) - def _order_csv_by_filenames(self, path_to_embeddings_csv: str) -> List[str]: """Orders the rows in a csv according to the order specified on the server and saves it as a new file. @@ -229,27 +233,32 @@ def _order_csv_by_filenames(self, path_to_embeddings_csv: str) -> List[str]: the filepath to the new csv """ - with open(path_to_embeddings_csv, 'r') as f: + with open(path_to_embeddings_csv, "r") as f: data = csv.reader(f) rows = list(data) header_row = rows[0] rows_without_header = rows[1:] - index_filenames = header_row.index('filenames') + index_filenames = header_row.index("filenames") filenames = [row[index_filenames] for row in rows_without_header] filenames_on_server = self.get_filenames() if len(filenames) != len(filenames_on_server): - raise ValueError(f'There are {len(filenames)} rows in the embedding file, but ' - f'{len(filenames_on_server)} filenames/samples on the server.') + raise ValueError( + f"There are {len(filenames)} rows in the embedding file, but " + f"{len(filenames_on_server)} filenames/samples on the server." 
+ ) if set(filenames) != set(filenames_on_server): - raise ValueError(f'The filenames in the embedding file and ' - f'the filenames on the server do not align') + raise ValueError( + f"The filenames in the embedding file and " + f"the filenames on the server do not align" + ) check_filenames(filenames) - rows_without_header_ordered = \ - self._order_list_by_filenames(filenames, rows_without_header) + rows_without_header_ordered = self._order_list_by_filenames( + filenames, rows_without_header + ) rows_csv = [header_row] rows_csv += rows_without_header_ordered diff --git a/lightly/api/api_workflow_upload_metadata.py b/lightly/api/api_workflow_upload_metadata.py index edbf2b6d4..650248b55 100644 --- a/lightly/api/api_workflow_upload_metadata.py +++ b/lightly/api/api_workflow_upload_metadata.py @@ -1,16 +1,23 @@ - +from bisect import bisect_left from concurrent.futures import ThreadPoolExecutor from typing import Dict, List, Union -from bisect import bisect_left from tqdm import tqdm from lightly.api.utils import retry +from lightly.openapi_generated.swagger_client.models.configuration_entry import ( + ConfigurationEntry, +) +from lightly.openapi_generated.swagger_client.models.configuration_set_request import ( + ConfigurationSetRequest, +) +from lightly.openapi_generated.swagger_client.models.sample_partial_mode import ( + SamplePartialMode, +) +from lightly.openapi_generated.swagger_client.models.sample_update_request import ( + SampleUpdateRequest, +) from lightly.utils.hipify import print_as_warning -from lightly.openapi_generated.swagger_client.models.sample_update_request import SampleUpdateRequest -from lightly.openapi_generated.swagger_client.models.configuration_entry import ConfigurationEntry -from lightly.openapi_generated.swagger_client.models.configuration_set_request import ConfigurationSetRequest -from lightly.openapi_generated.swagger_client.models.sample_partial_mode import SamplePartialMode from lightly.utils.io import COCO_ANNOTATION_KEYS @@ -19,20 +26,16 @@ class InvalidCustomMetadataWarning(Warning): def _assert_key_exists_in_custom_metadata(key: str, dictionary: Dict): - """Raises a formatted KeyError if key is not a key of the dictionary. - - """ + """Raises a formatted KeyError if key is not a key of the dictionary.""" if key not in dictionary.keys(): raise KeyError( - f'Key {key} not found in custom metadata.\n' - f'Found keys: {dictionary.keys()}' + f"Key {key} not found in custom metadata.\n" + f"Found keys: {dictionary.keys()}" ) class _UploadCustomMetadataMixin: - """Mixin of helpers to allow upload of custom metadata. - - """ + """Mixin of helpers to allow upload of custom metadata.""" def verify_custom_metadata_format(self, custom_metadata: Dict): """Verifies that the custom metadata is in the correct format. @@ -54,8 +57,9 @@ def verify_custom_metadata_format(self, custom_metadata: Dict): COCO_ANNOTATION_KEYS.custom_metadata, custom_metadata ) - def index_custom_metadata_by_filename(self, custom_metadata: Dict)\ - -> Dict[str, Union[Dict, None]]: + def index_custom_metadata_by_filename( + self, custom_metadata: Dict + ) -> Dict[str, Union[Dict, None]]: """Creates an index to lookup custom metadata by filename. Args: @@ -73,30 +77,24 @@ def index_custom_metadata_by_filename(self, custom_metadata: Dict)\ # The mapping is filename -> image_id -> custom_metadata # This mapping is created in linear time. 
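The linear-time index described by the comment above is built by the dict comprehensions that follow. As a self-contained illustration (the literal keys "images", "file_name", "id", "image_id", and "metadata" below are stand-ins for the COCO_ANNOTATION_KEYS constants, not the exact values used by lightly):

```python
# Sketch of the linear-time lookup: filename -> image_id -> custom metadata.
custom_metadata = {
    "images": [
        {"id": 0, "file_name": "a.png"},
        {"id": 1, "file_name": "b.png"},  # has no metadata entry
    ],
    "metadata": [
        {"image_id": 0, "weather": "sunny"},
    ],
}

filename_to_image_id = {
    image["file_name"]: image["id"] for image in custom_metadata["images"]
}
image_id_to_metadata = {
    entry["image_id"]: entry for entry in custom_metadata["metadata"]
}
filename_to_metadata = {
    filename: image_id_to_metadata.get(image_id)  # None when unannotated
    for filename, image_id in filename_to_image_id.items()
}

assert filename_to_metadata["a.png"]["weather"] == "sunny"
assert filename_to_metadata["b.png"] is None
```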
filename_to_image_id = { - image_info[COCO_ANNOTATION_KEYS.images_filename]: - image_info[COCO_ANNOTATION_KEYS.images_id] - for image_info - in custom_metadata[COCO_ANNOTATION_KEYS.images] + image_info[COCO_ANNOTATION_KEYS.images_filename]: image_info[ + COCO_ANNOTATION_KEYS.images_id + ] + for image_info in custom_metadata[COCO_ANNOTATION_KEYS.images] } image_id_to_custom_metadata = { - metadata[COCO_ANNOTATION_KEYS.custom_metadata_image_id]: - metadata - for metadata - in custom_metadata[COCO_ANNOTATION_KEYS.custom_metadata] + metadata[COCO_ANNOTATION_KEYS.custom_metadata_image_id]: metadata + for metadata in custom_metadata[COCO_ANNOTATION_KEYS.custom_metadata] } filename_to_metadata = { filename: image_id_to_custom_metadata.get(image_id, None) - for (filename, image_id) - in filename_to_image_id.items() + for (filename, image_id) in filename_to_image_id.items() } return filename_to_metadata - - - def upload_custom_metadata(self, - custom_metadata: Dict, - verbose: bool = False, - max_workers: int = 8): + def upload_custom_metadata( + self, custom_metadata: Dict, verbose: bool = False, max_workers: int = 8 + ): """Uploads custom metadata to the Lightly platform. The custom metadata is expected in a format similar to the COCO annotations: @@ -148,28 +146,24 @@ def upload_custom_metadata(self, self.verify_custom_metadata_format(custom_metadata) - - # For each metadata, we need the corresponding sample_id # on the server. The mapping is: # metadata -> image_id -> filename -> sample_id image_id_to_filename = { - image_info[COCO_ANNOTATION_KEYS.images_id]: - image_info[COCO_ANNOTATION_KEYS.images_filename] + image_info[COCO_ANNOTATION_KEYS.images_id]: image_info[ + COCO_ANNOTATION_KEYS.images_filename + ] for image_info in custom_metadata[COCO_ANNOTATION_KEYS.images] } samples = retry( self._samples_api.get_samples_partial_by_dataset_id, dataset_id=self.dataset_id, - mode=SamplePartialMode.FILENAMES + mode=SamplePartialMode.FILENAMES, ) - filename_to_sample_id = { - sample.file_name: sample.id - for sample in samples - } + filename_to_sample_id = {sample.file_name: sample.id for sample in samples} upload_requests = [] for metadata in custom_metadata[COCO_ANNOTATION_KEYS.custom_metadata]: @@ -177,21 +171,21 @@ def upload_custom_metadata(self, filename = image_id_to_filename.get(image_id, None) if filename is None: print_as_warning( - f'No image found for custom metadata annotation ' - f'with image_id {image_id}. ' - f'This custom metadata annotation is skipped. ', - InvalidCustomMetadataWarning + f"No image found for custom metadata annotation " + f"with image_id {image_id}. " + f"This custom metadata annotation is skipped. ", + InvalidCustomMetadataWarning, ) continue sample_id = filename_to_sample_id.get(filename, None) if sample_id is None: print_as_warning( - f'You tried to upload custom metadata for a sample with ' - f'filename {{{filename}}}, ' - f'but a sample with this filename ' - f'does not exist on the server. ' - f'This custom metadata annotation is skipped. ', - InvalidCustomMetadataWarning + f"You tried to upload custom metadata for a sample with " + f"filename {{{filename}}}, " + f"but a sample with this filename " + f"does not exist on the server. " + f"This custom metadata annotation is skipped. 
", + InvalidCustomMetadataWarning, ) continue upload_request = (metadata, sample_id) @@ -213,18 +207,12 @@ def upload_sample_metadata(upload_request): # get iterator over results results = executor.map(upload_sample_metadata, upload_requests) if verbose: - results = tqdm( - results, - unit='metadata', - total=len(upload_requests) - ) + results = tqdm(results, unit="metadata", total=len(upload_requests)) # iterate over results to make sure they are completed list(results) def create_custom_metadata_config( - self, - name: str, - configs: List[ConfigurationEntry] + self, name: str, configs: List[ConfigurationEntry] ): """Creates custom metadata config from a list of configurations. @@ -245,13 +233,13 @@ def create_custom_metadata_config( >>> default_value='unknown', >>> value_data_type='CATEGORICAL_STRING', >>> ) - >>> + >>> >>> client.create_custom_metadata_config( >>> 'My Custom Metadata', >>> [entry], >>> ) - - + + """ config_set_request = ConfigurationSetRequest(name=name, configs=configs) resp = self._metadata_configurations_api.create_meta_data_configuration( @@ -259,4 +247,3 @@ def create_custom_metadata_config( dataset_id=self.dataset_id, ) return resp - diff --git a/lightly/api/bitmask.py b/lightly/api/bitmask.py index 893e718d6..ddb935571 100644 --- a/lightly/api/bitmask.py +++ b/lightly/api/bitmask.py @@ -7,32 +7,27 @@ def _hex_to_int(hexstring: str) -> int: - """Converts a hex string representation of an integer to an integer. - """ + """Converts a hex string representation of an integer to an integer.""" return int(hexstring, 16) def _bin_to_int(binstring: str) -> int: - """Converts a binary string representation of an integer to an integer. - """ + """Converts a binary string representation of an integer to an integer.""" return int(binstring, 2) def _int_to_hex(x: int) -> str: - """Converts an integer to a hex string representation. - """ + """Converts an integer to a hex string representation.""" return hex(x) def _int_to_bin(x: int) -> str: - """Converts an integer to a binary string representation. - """ + """Converts an integer to a binary string representation.""" return bin(x) def _get_nonzero_bits(x: int) -> List[int]: - """Returns a list of indices of nonzero bits in x. - """ + """Returns a list of indices of nonzero bits in x.""" offset = 0 nonzero_bit_indices = [] while x > 0: @@ -46,41 +41,35 @@ def _get_nonzero_bits(x: int) -> List[int]: def _invert(x: int, total_size: int) -> int: - """Flips every bit of x as if x was an unsigned integer. - """ + """Flips every bit of x as if x was an unsigned integer.""" # use XOR of x and 0xFFFFFF to get the inverse - return x ^ (2 ** total_size - 1) + return x ^ (2**total_size - 1) def _union(x: int, y: int) -> int: - """Uses bitwise OR to get the union of the two masks. - """ + """Uses bitwise OR to get the union of the two masks.""" return x | y def _intersection(x: int, y: int) -> int: - """Uses bitwise AND to get the intersection of the two masks. - """ + """Uses bitwise AND to get the intersection of the two masks.""" return x & y def _get_kth_bit(x: int, k: int) -> int: - """Returns the kth bit in the mask from the right. - """ + """Returns the kth bit in the mask from the right.""" mask = 1 << k return x & mask def _set_kth_bit(x: int, k: int) -> int: - """Sets the kth bit in the mask from the right. - """ + """Sets the kth bit in the mask from the right.""" mask = 1 << k return x | mask def _unset_kth_bit(x: int, k: int) -> int: - """Clears the kth bit in the mask from the right. 
- """ + """Clears the kth bit in the mask from the right.""" mask = ~(1 << k) return x & mask @@ -109,31 +98,26 @@ def __init__(self, x): @classmethod def from_hex(cls, hexstring: str): - """Creates a bit mask object from a hexstring. - """ + """Creates a bit mask object from a hexstring.""" return cls(_hex_to_int(hexstring)) @classmethod def from_bin(cls, binstring: str): - """Creates a BitMask from a binary string. - """ + """Creates a BitMask from a binary string.""" return cls(_bin_to_int(binstring)) @classmethod def from_length(cls, length: int): - """Creates a all-true bitmask of a predefined length - """ - binstring = '0b' + '1' * length + """Creates a all-true bitmask of a predefined length""" + binstring = "0b" + "1" * length return cls.from_bin(binstring) def to_hex(self): - """Creates a BitMask from a hex string. - """ + """Creates a BitMask from a hex string.""" return _int_to_hex(self.x) def to_bin(self): - """Returns a binary string representing the bit mask. - """ + """Returns a binary string representing the bit mask.""" return _int_to_bin(self.x) def to_indices(self) -> List[int]: @@ -156,8 +140,7 @@ def invert(self, total_size: int): self.x = _invert(self.x, total_size) def complement(self): - """Same as invert but with the appropriate name. - """ + """Same as invert but with the appropriate name.""" self.invert() def union(self, other): @@ -216,8 +199,7 @@ def masked_select_from_list(self, list_: List): return [list_[index] for index in indices] def get_kth_bit(self, k: int) -> bool: - """Returns the boolean value of the kth bit from the right. - """ + """Returns the boolean value of the kth bit from the right.""" return _get_kth_bit(self.x, k) > 0 def set_kth_bit(self, k: int): diff --git a/lightly/api/download.py b/lightly/api/download.py index 236001741..924dc7214 100644 --- a/lightly/api/download.py +++ b/lightly/api/download.py @@ -10,6 +10,7 @@ import PIL import requests import tqdm + from lightly.api import utils try: @@ -21,14 +22,16 @@ "installation instructions." ) -DEFAULT_VIDEO_TIMEOUT = 60 * 5 # seconds +DEFAULT_VIDEO_TIMEOUT = 60 * 5 # seconds + def _check_av_available() -> None: if isinstance(av, Exception): raise av + def download_image( - url: str, + url: str, session: requests.Session = None, retry_fn: Callable = utils.retry, request_kwargs: Optional[Dict] = None, @@ -36,9 +39,9 @@ def download_image( """Downloads an image from a url. Args: - url: + url: The url where the image is downloaded from. - session: + session: Session object to persist certain parameters across requests. retry_fn: Retry function that handles failed downloads. 
@@ -50,9 +53,9 @@ def download_image( """ request_kwargs = request_kwargs or {} - request_kwargs.setdefault('stream', True) - request_kwargs.setdefault('timeout', 10) - + request_kwargs.setdefault("stream", True) + request_kwargs.setdefault("timeout", 10) + def load_image(url, req, request_kwargs): with req.get(url=url, **request_kwargs) as response: response.raise_for_status() @@ -64,6 +67,7 @@ def load_image(url, req, request_kwargs): image = retry_fn(load_image, url, req, request_kwargs) return image + if not isinstance(av, ModuleNotFoundError): def download_all_video_frames( @@ -130,9 +134,9 @@ def download_all_video_frames( else: yield frame - - def download_video_frame(url: str, timestamp: int, *args, **kwargs - ) -> Union[PIL.Image.Image, av.VideoFrame, None]: + def download_video_frame( + url: str, timestamp: int, *args, **kwargs + ) -> Union[PIL.Image.Image, av.VideoFrame, None]: """ Wrapper around download_video_frames_at_timestamps for downloading only a single frame. @@ -143,7 +147,6 @@ def download_video_frame(url: str, timestamp: int, *args, **kwargs frames = list(frames) return frames[0] - def video_frame_count( url: str, video_channel: int = 0, @@ -186,7 +189,7 @@ def video_frame_count( with retry_fn(av.open, url, timeout=timeout) as container: stream = container.streams.video[video_channel] num_frames = 0 if ignore_metadata else stream.frames - # If number of frames not stored in the video file we have to decode all + # If number of frames not stored in the video file we have to decode all # frames and count them. if num_frames == 0: stream.thread_type = thread_type @@ -201,7 +204,9 @@ def all_video_frame_counts( thread_type: av.codec.context.ThreadType = av.codec.context.ThreadType.AUTO, ignore_metadata: bool = False, retry_fn: Callable = utils.retry, - exceptions_indicating_empty_video: Tuple[Type[BaseException], ...] = (RuntimeError,), + exceptions_indicating_empty_video: Tuple[Type[BaseException], ...] = ( + RuntimeError, + ), progress_bar: Optional[tqdm.tqdm] = None, ) -> List[Optional[int]]: """Finds the number of frames in the videos at the given urls. @@ -259,138 +264,131 @@ def job(url): if count is not None: total_count += count progress_bar.update(1) - progress_bar.set_description(f'Total frames found: {total_count}') + progress_bar.set_description(f"Total frames found: {total_count}") return frame_counts def download_video_frames_at_timestamps( - url: str, - timestamps: List[int], - as_pil_image: int = True, - thread_type: av.codec.context.ThreadType = av.codec.context.ThreadType.AUTO, - video_channel: int = 0, - seek_to_first_frame: bool = True, - retry_fn: Callable = utils.retry, - timeout: Optional[Union[float, Tuple[float, float]]] = DEFAULT_VIDEO_TIMEOUT, - ) -> Iterable[Union[PIL.Image.Image, av.VideoFrame]]: - """Lazily retrieves frames from a video at a specific timestamp stored at the given url. - - Args: - url: - The url where video is downloaded from. - timestamps: - Timestamps in pts from the start of the video. The images - at these timestamps are returned. - The timestamps must be strictly monotonically ascending. - See https://pyav.org/docs/develop/api/time.html#time - for details on pts. - as_pil_image: - Whether to return the frame as PIL.Image. - thread_type: - Which multithreading method to use for decoding the video. - See https://pyav.org/docs/stable/api/codec.html#av.codec.context.ThreadType - for details. - video_channel: - The video channel from which frames are loaded. 
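As the video_frame_count hunk above notes, the count falls back to decoding every frame when the container metadata reports zero. A minimal standalone sketch of that fallback with PyAV (hypothetical local file path; the real function opens a signed URL through a retry wrapper with a timeout):

```python
import av  # PyAV, the same dependency guarded by the try/except import above

def count_frames(path: str, video_channel: int = 0) -> int:
    with av.open(path) as container:
        stream = container.streams.video[video_channel]
        num_frames = stream.frames  # 0 when the metadata carries no count
        if num_frames == 0:
            # decode and count every frame as a fallback
            num_frames = sum(1 for _ in container.decode(stream))
        return num_frames

# count_frames("video.mp4")  # hypothetical input file
```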
- seek_to_first_frame: - Boolean indicating whether to seek to the first frame. - retry_fn: - Retry function that handles errors when opening the video container. - timeout: - Time in seconds to wait for new video data before giving up. - Timeout must either be an (open_timeout, read_timeout) tuple - or a single value which will be used as open and read timeout. - Timeouts only apply to individual steps during the download, - the complete video download can take much longer. - See https://pyav.org/docs/stable/api/_globals.html?highlight=av%20open#av.open - for details. - - Returns: - A generator that loads and returns a single frame per step. - - """ - _check_av_available() - - if len(timestamps) == 0: - return [] - - if any( - timestamps[i+1] <= timestamps[i] - for i - in range(len(timestamps) - 1) - ): - raise ValueError("The timestamps must be sorted " - "strictly monotonically ascending, but are not.") - min_timestamp = timestamps[0] - - if min_timestamp < 0: - raise ValueError(f"Negative timestamp is not allowed: {min_timestamp}") - - with retry_fn(av.open, url, timeout=timeout) as container: - stream = container.streams.video[video_channel] - stream.thread_type = thread_type - - if seek_to_first_frame: - # seek to last keyframe before the min_timestamp - container.seek( - min_timestamp, - any_frame=False, - backward=True, - stream=stream - ) - - index_timestamp = 0 - for frame in container.decode(stream): + url: str, + timestamps: List[int], + as_pil_image: int = True, + thread_type: av.codec.context.ThreadType = av.codec.context.ThreadType.AUTO, + video_channel: int = 0, + seek_to_first_frame: bool = True, + retry_fn: Callable = utils.retry, + timeout: Optional[Union[float, Tuple[float, float]]] = DEFAULT_VIDEO_TIMEOUT, + ) -> Iterable[Union[PIL.Image.Image, av.VideoFrame]]: + """Lazily retrieves frames from a video at a specific timestamp stored at the given url. - # advance from keyframe until correct timestamp is reached - if frame.pts > timestamps[index_timestamp]: + Args: + url: + The url where video is downloaded from. + timestamps: + Timestamps in pts from the start of the video. The images + at these timestamps are returned. + The timestamps must be strictly monotonically ascending. + See https://pyav.org/docs/develop/api/time.html#time + for details on pts. + as_pil_image: + Whether to return the frame as PIL.Image. + thread_type: + Which multithreading method to use for decoding the video. + See https://pyav.org/docs/stable/api/codec.html#av.codec.context.ThreadType + for details. + video_channel: + The video channel from which frames are loaded. + seek_to_first_frame: + Boolean indicating whether to seek to the first frame. + retry_fn: + Retry function that handles errors when opening the video container. + timeout: + Time in seconds to wait for new video data before giving up. + Timeout must either be an (open_timeout, read_timeout) tuple + or a single value which will be used as open and read timeout. + Timeouts only apply to individual steps during the download, + the complete video download can take much longer. + See https://pyav.org/docs/stable/api/_globals.html?highlight=av%20open#av.open + for details. - # dropped frames! - break + Returns: + A generator that loads and returns a single frame per step. 
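A usage sketch for the generator documented above (hypothetical signed URL and pts values; the timestamps must be strictly ascending and PyAV must be installed):

```python
from lightly.api.download import download_video_frames_at_timestamps

# hypothetical URL; pts values must be strictly monotonically ascending
frames = download_video_frames_at_timestamps(
    "https://example.com/video.mp4",
    timestamps=[0, 512, 1024],
    as_pil_image=True,
)
for frame in frames:
    print(frame.size)  # PIL.Image.Image instances when as_pil_image=True
```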
- # it's ok to check by equality because timestamps are ints - if frame.pts == timestamps[index_timestamp]: + """ + _check_av_available() - # yield next frame - if as_pil_image: - yield frame.to_image() - else: - yield frame + if len(timestamps) == 0: + return [] - # update the timestamp - index_timestamp += 1 + if any(timestamps[i + 1] <= timestamps[i] for i in range(len(timestamps) - 1)): + raise ValueError( + "The timestamps must be sorted " + "strictly monotonically ascending, but are not." + ) + min_timestamp = timestamps[0] - if index_timestamp >= len(timestamps): - return + if min_timestamp < 0: + raise ValueError(f"Negative timestamp is not allowed: {min_timestamp}") - leftovers = timestamps[index_timestamp:] + with retry_fn(av.open, url, timeout=timeout) as container: + stream = container.streams.video[video_channel] + stream.thread_type = thread_type - # sometimes frames are skipped when we seek to the first frame - # let's retry downloading these frames without seeking - retry_skipped_timestamps = seek_to_first_frame - if retry_skipped_timestamps: - warnings.warn( - f'Timestamps {leftovers} could not be decoded! Retrying from the start...' - ) - frames = download_video_frames_at_timestamps( - url, - leftovers, - as_pil_image=as_pil_image, - thread_type=thread_type, - video_channel=video_channel, - seek_to_first_frame=False, - retry_fn=retry_fn, + if seek_to_first_frame: + # seek to last keyframe before the min_timestamp + container.seek( + min_timestamp, any_frame=False, backward=True, stream=stream ) - for frame in frames: - yield frame - return - raise RuntimeError( - f'Timestamps {leftovers} in video {url} could not be decoded!' + index_timestamp = 0 + for frame in container.decode(stream): + # advance from keyframe until correct timestamp is reached + if frame.pts > timestamps[index_timestamp]: + # dropped frames! + break + + # it's ok to check by equality because timestamps are ints + if frame.pts == timestamps[index_timestamp]: + # yield next frame + if as_pil_image: + yield frame.to_image() + else: + yield frame + + # update the timestamp + index_timestamp += 1 + + if index_timestamp >= len(timestamps): + return + + leftovers = timestamps[index_timestamp:] + + # sometimes frames are skipped when we seek to the first frame + # let's retry downloading these frames without seeking + retry_skipped_timestamps = seek_to_first_frame + if retry_skipped_timestamps: + warnings.warn( + f"Timestamps {leftovers} could not be decoded! Retrying from the start..." + ) + frames = download_video_frames_at_timestamps( + url, + leftovers, + as_pil_image=as_pil_image, + thread_type=thread_type, + video_channel=video_channel, + seek_to_first_frame=False, + retry_fn=retry_fn, ) + for frame in frames: + yield frame + return + + raise RuntimeError( + f"Timestamps {leftovers} in video {url} could not be decoded!" + ) def download_and_write_file( - url: str, output_path: str, + url: str, + output_path: str, session: requests.Session = None, retry_fn: Callable = utils.retry, request_kwargs: Optional[Dict] = None, @@ -410,8 +408,8 @@ def download_and_write_file( Additional parameters passed to requests.get(). 
""" request_kwargs = request_kwargs or {} - request_kwargs.setdefault('stream', True) - request_kwargs.setdefault('timeout', 10) + request_kwargs.setdefault("stream", True) + request_kwargs.setdefault("timeout", 10) req = requests if session is None else session out_path = pathlib.Path(output_path) out_path.parent.mkdir(parents=True, exist_ok=True) @@ -480,12 +478,12 @@ def job(**kwargs): with ThreadPoolExecutor(max_workers=max_workers) as executor: futures_to_file_info = { executor.submit( - job, - file_info=file_info, - output_dir=output_dir, - lock=lock, - sessions=sessions, - retry_fn=retry_fn, + job, + file_info=file_info, + output_dir=output_dir, + lock=lock, + sessions=sessions, + retry_fn=retry_fn, request_kwargs=request_kwargs, ): file_info for file_info in file_infos @@ -500,6 +498,7 @@ def job(**kwargs): except Exception as ex: warnings.warn(f"Could not download {filename} from {url}") + def download_prediction_file( url: str, session: requests.Session = None, @@ -512,6 +511,7 @@ def download_prediction_file( """ return download_json_file(url, session=session, request_kwargs=request_kwargs) + def download_json_file( url: str, session: requests.Session = None, @@ -522,7 +522,7 @@ def download_json_file( Args: url: Url of the file to download. - session: + session: Session object to persist certain parameters across requests. request_kwargs: Additional parameters passed to requests.get(). @@ -532,8 +532,8 @@ def download_json_file( """ request_kwargs = request_kwargs or {} - request_kwargs.setdefault('stream', True) - request_kwargs.setdefault('timeout', 10) + request_kwargs.setdefault("stream", True) + request_kwargs.setdefault("timeout", 10) req = requests if session is None else session response = req.get(url, **request_kwargs) diff --git a/lightly/api/patch_rest_client.py b/lightly/api/patch_rest_client.py index 1a43cb42c..ade415872 100644 --- a/lightly/api/patch_rest_client.py +++ b/lightly/api/patch_rest_client.py @@ -24,9 +24,17 @@ def patch_rest_client(rest_client: Type): """ request = rest_client.request - def request_patched(self, method, url, query_params=None, headers=None, - body=None, post_params=None, _preload_content=True, - _request_timeout=None): + def request_patched( + self, + method, + url, + query_params=None, + headers=None, + body=None, + post_params=None, + _preload_content=True, + _request_timeout=None, + ): if query_params is not None: new_query_params = [] for name, value in query_params: @@ -35,6 +43,16 @@ def request_patched(self, method, url, query_params=None, headers=None, else: new_query_params.append((name, value)) query_params = new_query_params - return request(self, method=method, url=url, query_params=query_params, headers=headers, body=body, post_params=post_params, _preload_content=_preload_content, _request_timeout=_request_timeout) - - rest_client.request = request_patched \ No newline at end of file + return request( + self, + method=method, + url=url, + query_params=query_params, + headers=headers, + body=body, + post_params=post_params, + _preload_content=_preload_content, + _request_timeout=_request_timeout, + ) + + rest_client.request = request_patched diff --git a/lightly/api/prediction_singletons.py b/lightly/api/prediction_singletons.py index b05fefeb3..36ccf8704 100644 --- a/lightly/api/prediction_singletons.py +++ b/lightly/api/prediction_singletons.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Optional, List +from typing import List, Optional from lightly.openapi_generated.swagger_client import TaskType @@ -84,7 
+84,7 @@ def __init__( self.probabilities = probabilities -# Not used +# Not used class PredictionSingletonInstanceSegmentationRepr(PredictionSingletonRepr): def __init__( self, diff --git a/lightly/api/utils.py b/lightly/api/utils.py index 9018059f9..01f892fcc 100644 --- a/lightly/api/utils.py +++ b/lightly/api/utils.py @@ -5,14 +5,15 @@ import io import os -import time import random +import time from enum import Enum from typing import List, Optional # the following two lines are needed because # PIL misidentifies certain jpeg images as MPOs from PIL import JpegImagePlugin + JpegImagePlugin._getmp = lambda: None from lightly.openapi_generated.swagger_client.configuration import Configuration @@ -42,7 +43,7 @@ def retry(func, *args, **kwargs): """ # config - backoff = 1. + random.random() * 0.1 + backoff = 1.0 + random.random() * 0.1 max_backoff = RETRY_MAX_BACKOFF max_retries = RETRY_MAX_RETRIES @@ -61,8 +62,8 @@ def retry(func, *args, **kwargs): # max retries exceeded if current_retries >= max_retries: raise RuntimeError( - f'Maximum retries exceeded! Original exception: {type(e)}: {str(e)}') from e - + f"Maximum retries exceeded! Original exception: {type(e)}: {str(e)}" + ) from e def paginate_endpoint(fn, page_size=5000, *args, **kwargs) -> List: @@ -91,11 +92,9 @@ def paginate_endpoint(fn, page_size=5000, *args, **kwargs) -> List: return entries - - def getenv(key: str, default: str): """Return the value of the environment variable key if it exists, - or default if it doesn’t. + or default if it doesn’t. """ try: @@ -109,15 +108,13 @@ def getenv(key: str, default: str): return default -def PIL_to_bytes(img, ext: str = 'png', quality: int = None): - """Return the PIL image as byte stream. Useful to send image via requests. - - """ +def PIL_to_bytes(img, ext: str = "png", quality: int = None): + """Return the PIL image as byte stream. Useful to send image via requests.""" bytes_io = io.BytesIO() if quality is not None: img.save(bytes_io, format=ext, quality=quality) else: - subsampling = -1 if ext.lower() in ['jpg', 'jpeg'] else 0 + subsampling = -1 if ext.lower() in ["jpg", "jpeg"] else 0 img.save(bytes_io, format=ext, quality=100, subsampling=subsampling) bytes_io.seek(0) return bytes_io @@ -134,10 +131,12 @@ def check_filename(basename): return len(basename) <= MAXIMUM_FILENAME_LENGTH -def build_azure_signed_url_write_headers(content_length: str, - x_ms_blob_type: str = 'BlockBlob', - accept: str = '*/*', - accept_encoding: str = '*'): +def build_azure_signed_url_write_headers( + content_length: str, + x_ms_blob_type: str = "BlockBlob", + accept: str = "*/*", + accept_encoding: str = "*", +): """Builds the headers required for a SAS PUT to Azure blob storage. 
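To make the Azure header builder above concrete, a hedged sketch of a SAS PUT built on these headers (the signed URL is hypothetical and its token query string is elided; requests.put is standard requests API):

```python
import requests

from lightly.api.utils import build_azure_signed_url_write_headers

payload = b"example blob content"  # illustrative payload
headers = build_azure_signed_url_write_headers(
    content_length=str(len(payload)),
)
# hypothetical SAS URL; the real token parameters are elided
signed_url = "https://myaccount.blob.core.windows.net/container/blob?sv=..."
response = requests.put(signed_url, data=payload, headers=headers)
response.raise_for_status()
```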
Args: @@ -155,11 +154,11 @@ def build_azure_signed_url_write_headers(content_length: str, """ headers = { - 'x-ms-blob-type': x_ms_blob_type, - 'Accept': accept, - 'Content-Length': content_length, - 'x-ms-original-content-length': content_length, - 'Accept-Encoding': accept_encoding, + "x-ms-blob-type": x_ms_blob_type, + "Accept": accept, + "Content-Length": content_length, + "x-ms-original-content-length": content_length, + "Accept-Encoding": accept_encoding, } return headers @@ -171,7 +170,7 @@ class DatasourceType(Enum): LOCAL = "LOCAL" -def get_signed_url_destination(signed_url: str = '') -> DatasourceType: +def get_signed_url_destination(signed_url: str = "") -> DatasourceType: """ Tries to figure out which cloud provider/datasource type a signed url comes from (S3, GCS, Azure) Args: @@ -183,11 +182,11 @@ def get_signed_url_destination(signed_url: str = '') -> DatasourceType: assert isinstance(signed_url, str) - if 'storage.googleapis.com/' in signed_url: + if "storage.googleapis.com/" in signed_url: return DatasourceType.GCS - if '.amazonaws.com/' in signed_url and '.s3.' in signed_url: + if ".amazonaws.com/" in signed_url and ".s3." in signed_url: return DatasourceType.S3 - if '.windows.net/' in signed_url: + if ".windows.net/" in signed_url: return DatasourceType.AZURE # default to local as it must be some special setup return DatasourceType.LOCAL @@ -197,7 +196,6 @@ def get_api_client_configuration( token: Optional[str] = None, raise_if_no_token_specified: bool = True, ) -> Configuration: - host = getenv("LIGHTLY_SERVER_LOCATION", "https://api.lightly.ai") ssl_ca_cert = getenv("LIGHTLY_CA_CERTS", None) @@ -213,4 +211,4 @@ configuration.ssl_ca_cert = ssl_ca_cert configuration.host = host - return configuration \ No newline at end of file + return configuration diff --git a/lightly/api/version_checking.py b/lightly/api/version_checking.py index a9bed727c..fa0c69fb8 100644 --- a/lightly/api/version_checking.py +++ b/lightly/api/version_checking.py @@ -2,16 +2,16 @@ import warnings from typing import Tuple +from lightly.api import utils from lightly.openapi_generated.swagger_client import VersioningApi from lightly.openapi_generated.swagger_client.api_client import ApiClient - -from lightly.api import utils from lightly.utils.version_compare import version_compare class LightlyAPITimeoutException(Exception): pass + class TimeoutDecorator: def __init__(self, seconds): self.seconds = seconds @@ -55,7 +55,9 @@ def get_versioning_api() -> VersioningApi: def get_latest_version(current_version: str) -> Tuple[None, str]: try: versioning_api = get_versioning_api() - version_number: str = versioning_api.get_latest_pip_version(current_version=current_version) + version_number: str = versioning_api.get_latest_pip_version( + current_version=current_version + ) return version_number except Exception as e: return None @@ -68,8 +70,10 @@ def get_minimum_compatible_version(): def pretty_print_latest_version(current_version, latest_version, width=70): - warning = f"You are using lightly version {current_version}. " \ - f"There is a newer version of the package available. " \ - f"For compatability reasons, please upgrade your current version: " \ - f"pip install lightly=={latest_version}" + warning = ( + f"You are using lightly version {current_version}. " + f"There is a newer version of the package available. 
" + f"For compatability reasons, please upgrade your current version: " + f"pip install lightly=={latest_version}" + ) warnings.warn(Warning(warning)) diff --git a/lightly/cli/__init__.py b/lightly/cli/__init__.py index 5eb16d522..4cf371932 100644 --- a/lightly/cli/__init__.py +++ b/lightly/cli/__init__.py @@ -6,9 +6,9 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved +from lightly.cli.crop_cli import crop_cli +from lightly.cli.download_cli import download_cli +from lightly.cli.embed_cli import embed_cli from lightly.cli.lightly_cli import lightly_cli from lightly.cli.train_cli import train_cli -from lightly.cli.embed_cli import embed_cli from lightly.cli.upload_cli import upload_cli -from lightly.cli.download_cli import download_cli -from lightly.cli.crop_cli import crop_cli diff --git a/lightly/cli/_cli_simclr.py b/lightly/cli/_cli_simclr.py index 282602896..e16996667 100644 --- a/lightly/cli/_cli_simclr.py +++ b/lightly/cli/_cli_simclr.py @@ -14,29 +14,22 @@ class _SimCLR(nn.Module): """Implementation of SimCLR used by the command-line interface. - Provides backwards compatability with old checkpoints. + Provides backwards compatability with old checkpoints. """ - def __init__(self, backbone: nn.Module, num_ftrs: int = 32, - out_dim: int = 128): - + def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128): super(_SimCLR, self).__init__() self.backbone = backbone - self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs, - out_dim) - + self.projection_head = SimCLRProjectionHead(num_ftrs, num_ftrs, out_dim) def forward(self, x0: torch.Tensor, x1: torch.Tensor = None): - """Embeds and projects the input images. - - """ + """Embeds and projects the input images.""" # forward pass of first input x0 f0 = self.backbone(x0).flatten(start_dim=1) out0 = self.projection_head(f0) - # return out0 if x1 is None if x1 is None: return out0 @@ -46,4 +39,4 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor = None): out1 = self.projection_head(f1) # return both outputs - return out0, out1 \ No newline at end of file + return out0, out1 diff --git a/lightly/cli/_helpers.py b/lightly/cli/_helpers.py index ac4449d1c..31b6ce5f1 100644 --- a/lightly/cli/_helpers.py +++ b/lightly/cli/_helpers.py @@ -4,22 +4,17 @@ # All Rights Reserved import os -import torch import hydra +import torch from hydra import utils from torch import nn as nn -from lightly.utils.version_compare import version_compare - from lightly.cli._cli_simclr import _SimCLR from lightly.embedding import SelfSupervisedEmbedding - -from lightly.models import ZOO as model_zoo, ResNetGenerator +from lightly.models import ZOO as model_zoo +from lightly.models import ResNetGenerator from lightly.models.batchnorm import get_norm_layer - - - - +from lightly.utils.version_compare import version_compare def cpu_count(): @@ -32,87 +27,73 @@ def cpu_count(): def fix_input_path(path): - """Fix broken relative paths. - - """ + """Fix broken relative paths.""" if not os.path.isabs(path): path = utils.to_absolute_path(path) return path -def fix_hydra_arguments(config_path: str = 'config', config_name: str = 'config'): +def fix_hydra_arguments(config_path: str = "config", config_name: str = "config"): """Helper to make hydra arugments adaptive to installed hydra version - + Hydra introduced the `version_base` argument in version 1.2.0 - We use this helper to provide backwards compatibility to older hydra verisons. + We use this helper to provide backwards compatibility to older hydra verisons. 
""" - hydra_args = {'config_path': config_path, 'config_name': config_name} + hydra_args = {"config_path": config_path, "config_name": config_name} try: - if version_compare(hydra.__version__, '1.1.2') > 0: - hydra_args['version_base'] = '1.1' + if version_compare(hydra.__version__, "1.1.2") > 0: + hydra_args["version_base"] = "1.1" except ValueError: pass - + return hydra_args def is_url(checkpoint): - """Check whether the checkpoint is a url or not. - - """ - is_url = ('https://storage.googleapis.com' in checkpoint) + """Check whether the checkpoint is a url or not.""" + is_url = "https://storage.googleapis.com" in checkpoint return is_url def get_ptmodel_from_config(model): - """Get a pre-trained model from the lightly model zoo. - - """ - key = model['name'] - key += '/simclr' - key += '/d' + str(model['num_ftrs']) - key += '/w' + str(float(model['width'])) + """Get a pre-trained model from the lightly model zoo.""" + key = model["name"] + key += "/simclr" + key += "/d" + str(model["num_ftrs"]) + key += "/w" + str(float(model["width"])) if key in model_zoo.keys(): return model_zoo[key], key else: - return '', key + return "", key def load_state_dict_from_url(url, map_location=None): - """Try to load the checkopint from the given url. - - """ + """Try to load the checkopint from the given url.""" try: - state_dict = torch.hub.load_state_dict_from_url( - url, map_location=map_location - ) + state_dict = torch.hub.load_state_dict_from_url(url, map_location=map_location) return state_dict except Exception: - print('Not able to load state dict from %s' % (url)) - print('Retrying with http:// prefix') + print("Not able to load state dict from %s" % (url)) + print("Retrying with http:// prefix") try: - url = url.replace('https', 'http') - state_dict = torch.hub.load_state_dict_from_url( - url, map_location=map_location - ) + url = url.replace("https", "http") + state_dict = torch.hub.load_state_dict_from_url(url, map_location=map_location) return state_dict except Exception: - print('Not able to load state dict from %s' % (url)) + print("Not able to load state dict from %s" % (url)) # in this case downloading the pre-trained model was not possible # notify the user and return - return {'state_dict': None} + return {"state_dict": None} def _maybe_expand_batchnorm_weights(model_dict, state_dict, num_splits): - """Expands the weights of the BatchNorm2d to the size of SplitBatchNorm. - - """ - running_mean = 'running_mean' - running_var = 'running_var' + """Expands the weights of the BatchNorm2d to the size of SplitBatchNorm.""" + running_mean = "running_mean" + running_var = "running_var" for key, item in model_dict.items(): # not batchnorm -> continue @@ -135,25 +116,24 @@ def _maybe_expand_batchnorm_weights(model_dict, state_dict, num_splits): def _filter_state_dict(state_dict, remove_model_prefix_offset: int = 1): """Makes the state_dict compatible with the model. - + Prevents unexpected key error when loading PyTorch-Lightning checkpoints. Allows backwards compatability to checkpoints before v1.0.6. """ - prev_backbone = 'features' - curr_backbone = 'backbone' + prev_backbone = "features" + curr_backbone = "backbone" new_state_dict = {} for key, item in state_dict.items(): # remove the "model." 
prefix from the state dict key - key_parts = key.split('.')[remove_model_prefix_offset:] + key_parts = key.split(".")[remove_model_prefix_offset:] # with v1.0.6 the backbone of the models will be renamed from # "features" to "backbone", ensure compatibility with old ckpts - key_parts = \ - [k if k != prev_backbone else curr_backbone for k in key_parts] + key_parts = [k if k != prev_backbone else curr_backbone for k in key_parts] - new_key = '.'.join(key_parts) + new_key = ".".join(key_parts) new_state_dict[new_key] = item return new_state_dict @@ -166,22 +146,22 @@ def _fix_projection_head_keys(state_dict): replaced! Relevant issue: https://github.com/lightly-ai/lightly/issues/379 Prevents unexpected key error when loading old checkpoints. - + """ - projection_head_identifier = 'projection_head' - prediction_head_identifier = 'prediction_head' - projection_head_insert = 'layers' + projection_head_identifier = "projection_head" + prediction_head_identifier = "prediction_head" + projection_head_insert = "layers" new_state_dict = {} for key, item in state_dict.items(): - if (projection_head_identifier in key or \ - prediction_head_identifier in key) and \ - projection_head_insert not in key: + if ( + projection_head_identifier in key or prediction_head_identifier in key + ) and projection_head_insert not in key: # insert layers if it's not part of the key yet - key_parts = key.split('.') + key_parts = key.split(".") key_parts.insert(1, projection_head_insert) - new_key = '.'.join(key_parts) + new_key = ".".join(key_parts) else: new_key = key @@ -190,14 +170,14 @@ def _fix_projection_head_keys(state_dict): return new_state_dict -def load_from_state_dict(model, - state_dict, - strict: bool = True, - apply_filter: bool = True, - num_splits: int = 0): - """Loads the model weights from the state dictionary. - - """ +def load_from_state_dict( + model, + state_dict, + strict: bool = True, + apply_filter: bool = True, + num_splits: int = 0, +): + """Loads the model weights from the state dictionary.""" # step 1: filter state dict if apply_filter: @@ -206,45 +186,46 @@ def load_from_state_dict(model, state_dict = _fix_projection_head_keys(state_dict) # step 2: expand batchnorm weights - state_dict = \ - _maybe_expand_batchnorm_weights(model.state_dict(), state_dict, num_splits) + state_dict = _maybe_expand_batchnorm_weights( + model.state_dict(), state_dict, num_splits + ) # step 3: load from checkpoint model.load_state_dict(state_dict, strict=strict) def get_model_from_config(cfg, is_cli_call: bool = False) -> SelfSupervisedEmbedding: - checkpoint = cfg['checkpoint'] + checkpoint = cfg["checkpoint"] if torch.cuda.is_available(): - device = torch.device('cuda') + device = torch.device("cuda") else: - device = torch.device('cpu') + device = torch.device("cpu") if not checkpoint: - checkpoint, key = get_ptmodel_from_config(cfg['model']) + checkpoint, key = get_ptmodel_from_config(cfg["model"]) if not checkpoint: - msg = 'Cannot download checkpoint for key {} '.format(key) - msg += 'because it does not exist!' + msg = "Cannot download checkpoint for key {} ".format(key) + msg += "because it does not exist!"
raise RuntimeError(msg) state_dict = load_state_dict_from_url(checkpoint, map_location=device)[ - 'state_dict' + "state_dict" ] else: checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint - state_dict = torch.load(checkpoint, map_location=device)['state_dict'] + state_dict = torch.load(checkpoint, map_location=device)["state_dict"] # load model - resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width']) + resnet = ResNetGenerator(cfg["model"]["name"], cfg["model"]["width"]) last_conv_channels = list(resnet.children())[-1].in_features features = nn.Sequential( get_norm_layer(3, 0), *list(resnet.children())[:-1], - nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1), + nn.Conv2d(last_conv_channels, cfg["model"]["num_ftrs"], 1), nn.AdaptiveAvgPool2d(1), ) model = _SimCLR( - features, num_ftrs=cfg['model']['num_ftrs'], out_dim=cfg['model']['out_dim'] + features, num_ftrs=cfg["model"]["num_ftrs"], out_dim=cfg["model"]["out_dim"] ).to(device) if state_dict is not None: diff --git a/lightly/cli/config/get_config.py b/lightly/cli/config/get_config.py index a1a6f6b08..e494ac9d7 100644 --- a/lightly/cli/config/get_config.py +++ b/lightly/cli/config/get_config.py @@ -1,6 +1,6 @@ from pathlib import Path -from omegaconf import OmegaConf, DictConfig +from omegaconf import DictConfig, OmegaConf def get_lightly_config() -> DictConfig: diff --git a/lightly/cli/crop_cli.py b/lightly/cli/crop_cli.py index 5ddf57bc0..bea498a1f 100644 --- a/lightly/cli/crop_cli.py +++ b/lightly/cli/crop_cli.py @@ -13,58 +13,66 @@ import hydra import yaml -from lightly.utils.hipify import bcolors from lightly.active_learning.utils import BoundingBox -from lightly.cli._helpers import fix_input_path -from lightly.cli._helpers import fix_hydra_arguments +from lightly.cli._helpers import fix_hydra_arguments, fix_input_path from lightly.data import LightlyDataset -from lightly.utils.cropping.crop_image_by_bounding_boxes import crop_dataset_by_bounding_boxes_and_save +from lightly.utils.cropping.crop_image_by_bounding_boxes import ( + crop_dataset_by_bounding_boxes_and_save, +) from lightly.utils.cropping.read_yolo_label_file import read_yolo_label_file +from lightly.utils.hipify import bcolors def _crop_cli(cfg, is_cli_call=True): - input_dir = cfg['input_dir'] + input_dir = cfg["input_dir"] if input_dir and is_cli_call: input_dir = fix_input_path(input_dir) - output_dir = cfg['output_dir'] + output_dir = cfg["output_dir"] if output_dir and is_cli_call: output_dir = fix_input_path(output_dir) - label_dir = cfg['label_dir'] + label_dir = cfg["label_dir"] if label_dir and is_cli_call: label_dir = fix_input_path(label_dir) - label_names_file = cfg['label_names_file'] + label_names_file = cfg["label_names_file"] if label_names_file and len(label_names_file) > 0: if is_cli_call: label_names_file = fix_input_path(label_names_file) - with open(label_names_file, 'r') as file: + with open(label_names_file, "r") as file: label_names_file_dict = yaml.full_load(file) - class_names = label_names_file_dict['names'] + class_names = label_names_file_dict["names"] else: class_names = None - dataset = LightlyDataset(input_dir) - class_indices_list_list: List[List[int]] = [] bounding_boxes_list_list: List[List[BoundingBox]] = [] # YOLO-Specific for filename_image in dataset.get_filenames(): filepath_image_base, image_extension = os.path.splitext(filename_image) - filepath_label = os.path.join(label_dir, filename_image).replace(image_extension, '.txt') - class_indices, bounding_boxes = 
read_yolo_label_file(filepath_label, float(cfg['crop_padding'])) + filepath_label = os.path.join(label_dir, filename_image).replace( + image_extension, ".txt" + ) + class_indices, bounding_boxes = read_yolo_label_file( + filepath_label, float(cfg["crop_padding"]) + ) class_indices_list_list.append(class_indices) bounding_boxes_list_list.append(bounding_boxes) - cropped_images_list_list = \ - crop_dataset_by_bounding_boxes_and_save(dataset, output_dir, bounding_boxes_list_list, class_indices_list_list, class_names) + cropped_images_list_list = crop_dataset_by_bounding_boxes_and_save( + dataset, + output_dir, + bounding_boxes_list_list, + class_indices_list_list, + class_names, + ) - print(f'Cropped images are stored at: {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}') + print(f"Cropped images are stored at: {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}") return cropped_images_list_list -@hydra.main(**fix_hydra_arguments(config_path = 'config', config_name = 'config')) +@hydra.main(**fix_hydra_arguments(config_path="config", config_name="config")) def crop_cli(cfg): """Crops images into one sub-image for each object. diff --git a/lightly/cli/download_cli.py b/lightly/cli/download_cli.py index dd4f6b332..d6650772e 100644 --- a/lightly/cli/download_cli.py +++ b/lightly/cli/download_cli.py @@ -11,36 +11,33 @@ import os import hydra -from lightly.utils.hipify import bcolors import lightly.data as data -from lightly.cli._helpers import fix_input_path -from lightly.cli._helpers import fix_hydra_arguments -from lightly.cli._helpers import cpu_count -from lightly.utils.hipify import print_as_warning - from lightly.api.api_workflow_client import ApiWorkflowClient +from lightly.cli._helpers import cpu_count, fix_hydra_arguments, fix_input_path from lightly.openapi_generated.swagger_client import Creator +from lightly.utils.hipify import bcolors, print_as_warning def _download_cli(cfg, is_cli_call=True): - - tag_name = str(cfg['tag_name']) - dataset_id = str(cfg['dataset_id']) - token = str(cfg['token']) + tag_name = str(cfg["tag_name"]) + dataset_id = str(cfg["dataset_id"]) + token = str(cfg["token"]) if not tag_name or not token or not dataset_id: - print_as_warning('Please specify all of the parameters tag_name, token and dataset_id') - print_as_warning('For help, try: lightly-download --help') + print_as_warning( + "Please specify all of the parameters tag_name, token and dataset_id" + ) + print_as_warning("For help, try: lightly-download --help") return # set the number of workers if unset - if cfg['loader']['num_workers'] < 0: + if cfg["loader"]["num_workers"] < 0: # set the number of workers to the number of CPUs available, # but minimum of 8 num_workers = max(8, cpu_count()) num_workers = min(32, num_workers) - cfg['loader']['num_workers'] = num_workers + cfg["loader"]["num_workers"] = num_workers api_workflow_client = ApiWorkflowClient( token=token, dataset_id=dataset_id, creator=Creator.USER_PIP_LIGHTLY_MAGIC @@ -50,12 +47,12 @@ def _download_cli(cfg, is_cli_call=True): tag_data = api_workflow_client.get_tag_by_name(tag_name) filenames_tag = api_workflow_client.get_filenames_in_tag( tag_data, - exclude_parent_tag=cfg['exclude_parent_tag'], + exclude_parent_tag=cfg["exclude_parent_tag"], ) # store sample names in a .txt file - filename = tag_name + '.txt' - with open(filename, 'w') as f: + filename = tag_name + ".txt" + with open(filename, "w") as f: for item in filenames_tag: f.write("%s\n" % item) @@ -63,19 +60,19 @@ def _download_cli(cfg, is_cli_call=True): msg = f'The list of samples in tag 
{cfg["tag_name"]} is stored at: {bcolors.OKBLUE}{filepath}{bcolors.ENDC}' print(msg, flush=True) - if not cfg['input_dir'] and cfg['output_dir']: + if not cfg["input_dir"] and cfg["output_dir"]: # download full images from api - output_dir = fix_input_path(cfg['output_dir']) + output_dir = fix_input_path(cfg["output_dir"]) api_workflow_client.download_dataset( - output_dir, - tag_name=tag_name, - max_workers=cfg['loader']['num_workers'] + output_dir, tag_name=tag_name, max_workers=cfg["loader"]["num_workers"] ) - elif cfg['input_dir'] and cfg['output_dir']: - input_dir = fix_input_path(cfg['input_dir']) - output_dir = fix_input_path(cfg['output_dir']) - print(f'Copying files from {input_dir} to {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}.') + elif cfg["input_dir"] and cfg["output_dir"]: + input_dir = fix_input_path(cfg["input_dir"]) + output_dir = fix_input_path(cfg["output_dir"]) + print( + f"Copying files from {input_dir} to {bcolors.OKBLUE}{output_dir}{bcolors.ENDC}." + ) # create a dataset from the input directory dataset = data.LightlyDataset(input_dir=input_dir) @@ -84,41 +81,41 @@ def _download_cli(cfg, is_cli_call=True): dataset.dump(output_dir, filenames_tag) -@hydra.main(**fix_hydra_arguments(config_path = 'config', config_name = 'config')) +@hydra.main(**fix_hydra_arguments(config_path="config", config_name="config")) def download_cli(cfg): """Download images from the Lightly platform. Args: cfg: The default configs are loaded from the config file. - To overwrite them please see the section on the config file + To overwrite them please see the section on the config file (.config.config.yaml). - + Command-Line Args: tag_name: Download all images from the requested tag. Use initial-tag to get all images from the dataset. token: User access token to the Lightly platform. If dataset_id - and token are specified, the images and embeddings are + and token are specified, the images and embeddings are uploaded to the platform. dataset_id: - Identifier of the dataset on the Lightly platform. If - dataset_id and token are specified, the images and + Identifier of the dataset on the Lightly platform. If + dataset_id and token are specified, the images and embeddings are uploaded to the platform. input_dir: If input_dir and output_dir are specified, lightly will copy - all images belonging to the tag from the input_dir to the + all images belonging to the tag from the input_dir to the output_dir. output_dir: If input_dir and output_dir are specified, lightly will copy - all images belonging to the tag from the input_dir to the + all images belonging to the tag from the input_dir to the output_dir. 
Examples: >>> # download list of all files in the dataset from the Lightly platform >>> lightly-download token='123' dataset_id='XYZ' - >>> + >>> >>> # download list of all files in tag 'my-tag' from the Lightly platform >>> lightly-download token='123' dataset_id='XYZ' tag_name='my-tag' >>> diff --git a/lightly/cli/embed_cli.py b/lightly/cli/embed_cli.py index 53dff03dc..12891827b 100644 --- a/lightly/cli/embed_cli.py +++ b/lightly/cli/embed_cli.py @@ -9,38 +9,38 @@ # All Rights Reserved import os -from typing import Union, Tuple, List +from typing import List, Tuple, Union import hydra import numpy as np import torch import torchvision -from lightly.utils.hipify import bcolors +from lightly.cli._helpers import ( + cpu_count, + fix_hydra_arguments, + fix_input_path, + get_model_from_config, +) from lightly.data import LightlyDataset +from lightly.utils.hipify import bcolors from lightly.utils.io import save_embeddings -from lightly.cli._helpers import fix_hydra_arguments -from lightly.cli._helpers import get_model_from_config -from lightly.cli._helpers import fix_input_path -from lightly.cli._helpers import cpu_count - - -def _embed_cli(cfg, is_cli_call=True) -> \ - Union[ - Tuple[np.ndarray, List[int], List[str]], - str - ]: - """ See embed_cli() for usage documentation - - is_cli_call: - If True: - Saves the embeddings as file and returns the filepath. - If False: - Returns the embeddings, labels, filenames as tuple. - Embeddings are of shape (n_samples, embedding_size) - len(labels) = len(filenames) = n_samples + + +def _embed_cli( + cfg, is_cli_call=True +) -> Union[Tuple[np.ndarray, List[int], List[str]], str]: + """See embed_cli() for usage documentation + + is_cli_call: + If True: + Saves the embeddings as file and returns the filepath. + If False: + Returns the embeddings, labels, filenames as tuple. 
+ Embeddings are of shape (n_samples, embedding_size) + len(labels) = len(filenames) = n_samples """ - input_dir = cfg['input_dir'] + input_dir = cfg["input_dir"] if input_dir and is_cli_call: input_dir = fix_input_path(input_dir) @@ -48,14 +48,14 @@ def _embed_cli(cfg, is_cli_call=True) -> \ torch.backends.cudnn.benchmark = False if torch.cuda.is_available(): - device = torch.device('cuda') + device = torch.device("cuda") else: - device = torch.device('cpu') + device = torch.device("cpu") transform = torchvision.transforms.Compose( [ torchvision.transforms.Resize( - (cfg['collate']['input_size'], cfg['collate']['input_size']) + (cfg["collate"]["input_size"], cfg["collate"]["input_size"]) ), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( @@ -67,33 +67,33 @@ def _embed_cli(cfg, is_cli_call=True) -> \ dataset = LightlyDataset(input_dir, transform=transform) # disable drop_last and shuffle - cfg['loader']['drop_last'] = False - cfg['loader']['shuffle'] = False - cfg['loader']['batch_size'] = min(cfg['loader']['batch_size'], len(dataset)) + cfg["loader"]["drop_last"] = False + cfg["loader"]["shuffle"] = False + cfg["loader"]["batch_size"] = min(cfg["loader"]["batch_size"], len(dataset)) # determine the number of available cores - if cfg['loader']['num_workers'] < 0: - cfg['loader']['num_workers'] = cpu_count() + if cfg["loader"]["num_workers"] < 0: + cfg["loader"]["num_workers"] = cpu_count() - dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader']) + dataloader = torch.utils.data.DataLoader(dataset, **cfg["loader"]) encoder = get_model_from_config(cfg, is_cli_call) embeddings, labels, filenames = encoder.embed(dataloader, device=device) if is_cli_call: - path = os.path.join(os.getcwd(), 'embeddings.csv') + path = os.path.join(os.getcwd(), "embeddings.csv") save_embeddings(path, embeddings, labels, filenames) - print(f'Embeddings are stored at {bcolors.OKBLUE}{path}{bcolors.ENDC}') + print(f"Embeddings are stored at {bcolors.OKBLUE}{path}{bcolors.ENDC}") os.environ[ - cfg['environment_variable_names']['lightly_last_embedding_path'] + cfg["environment_variable_names"]["lightly_last_embedding_path"] ] = path return path return embeddings, labels, filenames -@hydra.main(**fix_hydra_arguments(config_path = 'config', config_name = 'config')) +@hydra.main(**fix_hydra_arguments(config_path="config", config_name="config")) def embed_cli(cfg) -> str: """Embed images from the command-line. diff --git a/lightly/cli/lightly_cli.py b/lightly/cli/lightly_cli.py index b2076916a..68827421e 100644 --- a/lightly/cli/lightly_cli.py +++ b/lightly/cli/lightly_cli.py @@ -11,15 +11,15 @@ import hydra from omegaconf import DictConfig -from lightly.utils.hipify import print_as_warning from lightly.cli._helpers import fix_hydra_arguments -from lightly.cli.train_cli import _train_cli from lightly.cli.embed_cli import _embed_cli +from lightly.cli.train_cli import _train_cli from lightly.cli.upload_cli import _upload_cli +from lightly.utils.hipify import print_as_warning def validate_cfg(cfg: DictConfig) -> bool: - """ Validates a config + """Validates a config Prints warnings if it is not valid. Args: @@ -31,64 +31,66 @@ def validate_cfg(cfg: DictConfig) -> bool: """ valid = True - if cfg['trainer']['max_epochs'] > 0 and cfg['append']: - print_as_warning('When appending to an existing dataset you must ' - 'use the same embedding model. Thus specify ' - 'trainer.max_epochs=0. 
If you had trained your own model, ' - 'you can use it with checkpoint="path/to/model.ckp".') + if cfg["trainer"]["max_epochs"] > 0 and cfg["append"]: + print_as_warning( + "When appending to an existing dataset you must " + "use the same embedding model. Thus specify " + "trainer.max_epochs=0. If you had trained your own model, " + 'you can use it with checkpoint="path/to/model.ckp".' + ) valid = False return valid def _lightly_cli(cfg, is_cli_call=True): - cfg['loader']['shuffle'] = True - cfg['loader']['drop_last'] = True + cfg["loader"]["shuffle"] = True + cfg["loader"]["drop_last"] = True if not validate_cfg(cfg): return - if cfg['trainer']['max_epochs'] > 0: - print('#' * 10 + ' Starting to train an embedding model.') + if cfg["trainer"]["max_epochs"] > 0: + print("#" * 10 + " Starting to train an embedding model.") checkpoint = _train_cli(cfg, is_cli_call) else: - checkpoint = '' + checkpoint = "" - cfg['loader']['shuffle'] = False - cfg['loader']['drop_last'] = False - cfg['checkpoint'] = checkpoint + cfg["loader"]["shuffle"] = False + cfg["loader"]["drop_last"] = False + cfg["checkpoint"] = checkpoint - print('#' * 10 + ' Starting to embed your dataset.') + print("#" * 10 + " Starting to embed your dataset.") embeddings = _embed_cli(cfg, is_cli_call) - cfg['embeddings'] = embeddings + cfg["embeddings"] = embeddings - if cfg['token'] and (cfg['dataset_id'] or cfg['new_dataset_name']): - print('#' * 10 + ' Starting to upload your dataset to the Lightly platform.') + if cfg["token"] and (cfg["dataset_id"] or cfg["new_dataset_name"]): + print("#" * 10 + " Starting to upload your dataset to the Lightly platform.") _upload_cli(cfg) - print('#' * 10 + ' Finished') + print("#" * 10 + " Finished") -@hydra.main(**fix_hydra_arguments(config_path = 'config', config_name = 'config')) +@hydra.main(**fix_hydra_arguments(config_path="config", config_name="config")) def lightly_cli(cfg): """Train a self-supervised model and use it to embed your dataset. Args: cfg: The default configs are loaded from the config file. - To overwrite them please see the section on the config file + To overwrite them please see the section on the config file (.config.config.yaml). - + Command-Line Args: input_dir: Path to the input directory where images are stored. token: User access token to the Lightly platform. If dataset_id - and token are specified, the images and embeddings are + and token are specified, the images and embeddings are uploaded to the platform. (Required for upload) dataset_id: - Identifier of the dataset on the Lightly platform. If - dataset_id and token are specified, the images and + Identifier of the dataset on the Lightly platform. If + dataset_id and token are specified, the images and embeddings are uploaded to the platform. (Required for upload) custom_metadata: diff --git a/lightly/cli/train_cli.py b/lightly/cli/train_cli.py index f8b7b6231..d2046fb57 100644 --- a/lightly/cli/train_cli.py +++ b/lightly/cli/train_cli.py @@ -8,156 +8,148 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved import os +import warnings import hydra import torch import torch.nn as nn -import warnings - from omegaconf import OmegaConf -from lightly.utils.hipify import bcolors from lightly.cli._cli_simclr import _SimCLR -from lightly.data import ImageCollateFunction -from lightly.data import LightlyDataset +from lightly.cli._helpers import ( + cpu_count, + fix_hydra_arguments, + fix_input_path, + get_ptmodel_from_config, + is_url, + load_from_state_dict, + load_state_dict_from_url, +) +from lightly.data import ImageCollateFunction, LightlyDataset from lightly.embedding import SelfSupervisedEmbedding from lightly.loss import NTXentLoss - from lightly.models import ResNetGenerator from lightly.models.batchnorm import get_norm_layer - -from lightly.cli._helpers import is_url -from lightly.cli._helpers import get_ptmodel_from_config -from lightly.cli._helpers import fix_input_path -from lightly.cli._helpers import load_state_dict_from_url -from lightly.cli._helpers import load_from_state_dict -from lightly.cli._helpers import cpu_count -from lightly.cli._helpers import fix_hydra_arguments +from lightly.utils.hipify import bcolors def _train_cli(cfg, is_cli_call=True): - - input_dir = cfg['input_dir'] + input_dir = cfg["input_dir"] if input_dir and is_cli_call: input_dir = fix_input_path(input_dir) - if 'seed' in cfg.keys(): - seed = cfg['seed'] + if "seed" in cfg.keys(): + seed = cfg["seed"] torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False if torch.cuda.is_available(): - device = 'cuda' - elif cfg['trainer'] and cfg['trainer']['gpus']: - device = 'cpu' - cfg['trainer']['gpus'] = 0 + device = "cuda" + elif cfg["trainer"] and cfg["trainer"]["gpus"]: + device = "cpu" + cfg["trainer"]["gpus"] = 0 else: - device = 'cpu' - + device = "cpu" + distributed_strategy = None - if cfg['trainer']['gpus'] > 1: - distributed_strategy = 'ddp' - - if cfg['loader']['batch_size'] < 64: - msg = 'Training a self-supervised model with a small batch size: {}! ' - msg = msg.format(cfg['loader']['batch_size']) - msg += 'Small batch size may harm embedding quality. ' - msg += 'You can specify the batch size via the loader key-word: ' - msg += 'loader.batch_size=BSZ' + if cfg["trainer"]["gpus"] > 1: + distributed_strategy = "ddp" + + if cfg["loader"]["batch_size"] < 64: + msg = "Training a self-supervised model with a small batch size: {}! " + msg = msg.format(cfg["loader"]["batch_size"]) + msg += "Small batch size may harm embedding quality. " + msg += "You can specify the batch size via the loader key-word: " + msg += "loader.batch_size=BSZ" warnings.warn(msg) # determine the number of available cores - if cfg['loader']['num_workers'] < 0: - cfg['loader']['num_workers'] = cpu_count() + if cfg["loader"]["num_workers"] < 0: + cfg["loader"]["num_workers"] = cpu_count() state_dict = None - checkpoint = cfg['checkpoint'] - if cfg['pre_trained'] and not checkpoint: + checkpoint = cfg["checkpoint"] + if cfg["pre_trained"] and not checkpoint: # if checkpoint wasn't specified explicitly and pre_trained is True # try to load the checkpoint from the model zoo - checkpoint, key = get_ptmodel_from_config(cfg['model']) + checkpoint, key = get_ptmodel_from_config(cfg["model"]) if not checkpoint: - msg = 'Cannot download checkpoint for key {} '.format(key) - msg += 'because it does not exist! ' - msg += 'Model will be trained from scratch.' + msg = "Cannot download checkpoint for key {} ".format(key) + msg += "because it does not exist! 
" + msg += "Model will be trained from scratch." warnings.warn(msg) elif checkpoint: checkpoint = fix_input_path(checkpoint) if is_cli_call else checkpoint - + if checkpoint: # load the PyTorch state dictionary and map it to the current device if is_url(checkpoint): - state_dict = load_state_dict_from_url( - checkpoint, map_location=device - )['state_dict'] + state_dict = load_state_dict_from_url(checkpoint, map_location=device)[ + "state_dict" + ] else: - state_dict = torch.load( - checkpoint, map_location=device - )['state_dict'] + state_dict = torch.load(checkpoint, map_location=device)["state_dict"] # load model - resnet = ResNetGenerator(cfg['model']['name'], cfg['model']['width']) + resnet = ResNetGenerator(cfg["model"]["name"], cfg["model"]["width"]) last_conv_channels = list(resnet.children())[-1].in_features features = nn.Sequential( get_norm_layer(3, 0), *list(resnet.children())[:-1], - nn.Conv2d(last_conv_channels, cfg['model']['num_ftrs'], 1), + nn.Conv2d(last_conv_channels, cfg["model"]["num_ftrs"], 1), nn.AdaptiveAvgPool2d(1), ) model = _SimCLR( - features, - num_ftrs=cfg['model']['num_ftrs'], - out_dim=cfg['model']['out_dim'] + features, num_ftrs=cfg["model"]["num_ftrs"], out_dim=cfg["model"]["out_dim"] ) if state_dict is not None: load_from_state_dict(model, state_dict) - criterion = NTXentLoss(**cfg['criterion']) - optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer']) + criterion = NTXentLoss(**cfg["criterion"]) + optimizer = torch.optim.SGD(model.parameters(), **cfg["optimizer"]) dataset = LightlyDataset(input_dir) - cfg['loader']['batch_size'] = min( - cfg['loader']['batch_size'], - len(dataset) - ) + cfg["loader"]["batch_size"] = min(cfg["loader"]["batch_size"], len(dataset)) - collate_fn = ImageCollateFunction(**cfg['collate']) - dataloader = torch.utils.data.DataLoader(dataset, - **cfg['loader'], - collate_fn=collate_fn) + collate_fn = ImageCollateFunction(**cfg["collate"]) + dataloader = torch.utils.data.DataLoader( + dataset, **cfg["loader"], collate_fn=collate_fn + ) encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader) # Add strategy field to trainer config trainer_config = OmegaConf.create( - dict(strategy=distributed_strategy, **cfg['trainer']) + dict(strategy=distributed_strategy, **cfg["trainer"]) ) encoder.train_embedding( trainer_config=trainer_config, - checkpoint_callback_config=cfg['checkpoint_callback'], - summary_callback_config=cfg['summary_callback'], + checkpoint_callback_config=cfg["checkpoint_callback"], + summary_callback_config=cfg["summary_callback"], ) - print(f'Best model is stored at: {bcolors.OKBLUE}{encoder.checkpoint}{bcolors.ENDC}') + print( + f"Best model is stored at: {bcolors.OKBLUE}{encoder.checkpoint}{bcolors.ENDC}" + ) os.environ[ - cfg['environment_variable_names']['lightly_last_checkpoint_path'] + cfg["environment_variable_names"]["lightly_last_checkpoint_path"] ] = encoder.checkpoint return encoder.checkpoint -@hydra.main(**fix_hydra_arguments(config_path = 'config', config_name = 'config')) +@hydra.main(**fix_hydra_arguments(config_path="config", config_name="config")) def train_cli(cfg): """Train a self-supervised model from the command-line. Args: cfg: The default configs are loaded from the config file. - To overwrite them please see the section on the config file + To overwrite them please see the section on the config file (.config.config.yaml). - + Command-Line Args: input_dir: Path to the input directory where images are stored. 
diff --git a/lightly/cli/upload_cli.py b/lightly/cli/upload_cli.py index e9357b1eb..b8dcf8062 100644 --- a/lightly/cli/upload_cli.py +++ b/lightly/cli/upload_cli.py @@ -12,21 +12,14 @@ from typing import Union import hydra - import torchvision -from lightly.utils.hipify import bcolors - -from lightly.api.api_workflow_upload_embeddings import \ - EmbeddingDoesNotExistError - -from lightly.cli._helpers import fix_input_path -from lightly.cli._helpers import cpu_count -from lightly.cli._helpers import fix_hydra_arguments -from lightly.utils.hipify import print_as_warning from lightly.api.api_workflow_client import ApiWorkflowClient +from lightly.api.api_workflow_upload_embeddings import EmbeddingDoesNotExistError +from lightly.cli._helpers import cpu_count, fix_hydra_arguments, fix_input_path from lightly.data import LightlyDataset from lightly.openapi_generated.swagger_client import Creator +from lightly.utils.hipify import bcolors, print_as_warning SUCCESS_RETURN_VALUE = "Success" @@ -47,70 +40,76 @@ def _upload_cli(cfg, is_cli_call=True) -> Union[str, None]: "Please use the Lightly Worker instead: https://docs.lightly.ai/docs/install-lightly\n", ) - input_dir = cfg['input_dir'] + input_dir = cfg["input_dir"] if input_dir and is_cli_call: input_dir = fix_input_path(input_dir) - path_to_embeddings = cfg['embeddings'] + path_to_embeddings = cfg["embeddings"] if path_to_embeddings and is_cli_call: path_to_embeddings = fix_input_path(path_to_embeddings) - dataset_id = cfg['dataset_id'] - token = cfg['token'] - new_dataset_name = cfg['new_dataset_name'] + dataset_id = cfg["dataset_id"] + token = cfg["token"] + new_dataset_name = cfg["new_dataset_name"] cli_api_args_wrong = False if not token: - print_as_warning('Please specify your access token.') + print_as_warning("Please specify your access token.") cli_api_args_wrong = True if dataset_id: if new_dataset_name: print_as_warning( - 'Please specify either the dataset_id of an existing dataset ' - 'or a new_dataset_name, but not both.' + "Please specify either the dataset_id of an existing dataset " + "or a new_dataset_name, but not both." ) cli_api_args_wrong = True else: - api_workflow_client = \ - ApiWorkflowClient(token=token, dataset_id=dataset_id, creator=Creator.USER_PIP_LIGHTLY_MAGIC) + api_workflow_client = ApiWorkflowClient( + token=token, + dataset_id=dataset_id, + creator=Creator.USER_PIP_LIGHTLY_MAGIC, + ) else: if new_dataset_name: - api_workflow_client = ApiWorkflowClient(token=token, creator=Creator.USER_PIP_LIGHTLY_MAGIC) + api_workflow_client = ApiWorkflowClient( + token=token, creator=Creator.USER_PIP_LIGHTLY_MAGIC + ) api_workflow_client.create_dataset(dataset_name=new_dataset_name) else: print_as_warning( - 'Please specify either the dataset_id of an existing dataset ' - 'or a new_dataset_name.') + "Please specify either the dataset_id of an existing dataset " + "or a new_dataset_name." 
+ ) cli_api_args_wrong = True # delete the dataset_id as it might be an empty string # Use api_workflow_client.dataset_id instead del dataset_id if cli_api_args_wrong: - print_as_warning('For help, try: lightly-upload --help') + print_as_warning("For help, try: lightly-upload --help") return # potentially load custom metadata custom_metadata = None - if cfg['custom_metadata']: - path_to_custom_metadata = fix_input_path(cfg['custom_metadata']) + if cfg["custom_metadata"]: + path_to_custom_metadata = fix_input_path(cfg["custom_metadata"]) print( - 'Loading custom metadata from ' - f'{bcolors.OKBLUE}{path_to_custom_metadata}{bcolors.ENDC}' + "Loading custom metadata from " + f"{bcolors.OKBLUE}{path_to_custom_metadata}{bcolors.ENDC}" ) - with open(path_to_custom_metadata, 'r') as f: + with open(path_to_custom_metadata, "r") as f: custom_metadata = json.load(f) # set the number of workers if unset - if cfg['loader']['num_workers'] < 0: + if cfg["loader"]["num_workers"] < 0: # set the number of workers to the number of CPUs available, # but minimum of 8 num_workers = max(8, cpu_count()) num_workers = min(32, num_workers) - cfg['loader']['num_workers'] = num_workers + cfg["loader"]["num_workers"] = num_workers - size = cfg['resize'] + size = cfg["resize"] if not isinstance(size, int): size = tuple(size) transform = None @@ -120,31 +119,33 @@ def _upload_cli(cfg, is_cli_call=True) -> Union[str, None]: if input_dir: if not cfg.append and len(api_workflow_client.get_all_tags()) > 0: print_as_warning( - 'The dataset you specified already has samples. ' - 'If you want to add additional samples, you need to specify ' - 'append=True as CLI argument.' + "The dataset you specified already has samples. " + "If you want to add additional samples, you need to specify " + "append=True as CLI argument." ) return - mode = cfg['upload'] + mode = cfg["upload"] dataset = LightlyDataset(input_dir=input_dir, transform=transform) api_workflow_client.upload_dataset( input=dataset, mode=mode, - max_workers=cfg['loader']['num_workers'], + max_workers=cfg["loader"]["num_workers"], custom_metadata=custom_metadata, ) - print('Finished the upload of the dataset.') + print("Finished the upload of the dataset.") if path_to_embeddings: - name = cfg['embedding_name'] + name = cfg["embedding_name"] if not cfg.append: try: - embedding = api_workflow_client.get_embedding_by_name(name=name, ignore_suffix=True) + embedding = api_workflow_client.get_embedding_by_name( + name=name, ignore_suffix=True + ) print_as_warning( - 'The dataset you specified already has an embedding. ' - 'If you want to add additional samples, you need to specify ' - 'append=True as CLI argument.' + "The dataset you specified already has an embedding. " + "If you want to add additional samples, you need to specify " + "append=True as CLI argument." 
) return except EmbeddingDoesNotExistError: @@ -152,28 +153,30 @@ def _upload_cli(cfg, is_cli_call=True) -> Union[str, None]: api_workflow_client.upload_embeddings( path_to_embeddings_csv=path_to_embeddings, name=name ) - print('Finished upload of embeddings.') + print("Finished upload of embeddings.") if custom_metadata is not None and not input_dir: # upload custom metadata separately api_workflow_client.upload_custom_metadata( custom_metadata, verbose=True, - max_workers=cfg['loader']['num_workers'], + max_workers=cfg["loader"]["num_workers"], ) if new_dataset_name: - print(f'The dataset_id of the newly created dataset is ' - f'{bcolors.OKBLUE}{api_workflow_client.dataset_id}{bcolors.ENDC}') + print( + f"The dataset_id of the newly created dataset is " + f"{bcolors.OKBLUE}{api_workflow_client.dataset_id}{bcolors.ENDC}" + ) os.environ[ - cfg['environment_variable_names']['lightly_last_dataset_id'] + cfg["environment_variable_names"]["lightly_last_dataset_id"] ] = api_workflow_client.dataset_id return SUCCESS_RETURN_VALUE -@hydra.main(**fix_hydra_arguments(config_path = 'config', config_name = 'config')) +@hydra.main(**fix_hydra_arguments(config_path="config", config_name="config")) def upload_cli(cfg): """Upload images/embeddings from the command-line to the Lightly platform. diff --git a/lightly/cli/version_cli.py b/lightly/cli/version_cli.py index e96a61b16..f9b537c57 100644 --- a/lightly/cli/version_cli.py +++ b/lightly/cli/version_cli.py @@ -10,21 +10,19 @@ # All Rights Reserved import hydra -import lightly +import lightly from lightly.cli._helpers import fix_hydra_arguments def _version_cli(): version = lightly.__version__ - print(f'lightly version {version}', flush=True) + print(f"lightly version {version}", flush=True) -@hydra.main(**fix_hydra_arguments(config_path = 'config', config_name = 'config')) +@hydra.main(**fix_hydra_arguments(config_path="config", config_name="config")) def version_cli(cfg): - """Prints the version of the used lightly package to the terminal. - - """ + """Prints the version of the used lightly package to the terminal.""" _version_cli() diff --git a/lightly/core.py b/lightly/core.py index 2d3efea80..01412880d 100644 --- a/lightly/core.py +++ b/lightly/core.py @@ -2,17 +2,16 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved -from typing import Tuple, List +import os +from typing import List, Tuple import numpy as np +import yaml -from lightly.cli.train_cli import _train_cli +import lightly.cli as cli from lightly.cli.embed_cli import _embed_cli from lightly.cli.lightly_cli import _lightly_cli -import lightly.cli as cli - -import yaml -import os +from lightly.cli.train_cli import _train_cli def _get_config_path(config_path): @@ -30,7 +29,7 @@ def _get_config_path(config_path): """ if config_path is None: dirname = os.path.dirname(cli.__file__) - config_path = os.path.join(dirname, 'config/config.yaml') + config_path = os.path.join(dirname, "config/config.yaml") if not os.path.exists(config_path): raise ValueError("Config path {} does not exist!".format(config_path)) @@ -48,7 +47,7 @@ def _load_config_file(config_path): """ Loader = yaml.FullLoader - with open(config_path, 'r') as config_file: + with open(config_path, "r") as config_file: cfg = yaml.load(config_file, Loader=Loader) return cfg @@ -76,9 +75,9 @@ def _add_kwargs(cfg, kwargs): return cfg -def train_model_and_embed_images(config_path: str = None, **kwargs) -> Tuple[ - np.ndarray, List[int], List[str] -]: +def train_model_and_embed_images( + config_path: str = None, **kwargs +) -> Tuple[np.ndarray, List[int], List[str]]: """Train a self-supervised model and use it to embed images. First trains a model using the _train_cli(), @@ -120,7 +119,7 @@ def train_model_and_embed_images(config_path: str = None, **kwargs) -> Tuple[ config_args = _add_kwargs(config_args, kwargs) checkpoint = _train_cli(config_args, is_cli_call=False) - config_args['checkpoint'] = checkpoint + config_args["checkpoint"] = checkpoint embeddings, labels, filenames = _embed_cli(config_args, is_cli_call=False) return embeddings, labels, filenames @@ -154,7 +153,7 @@ def train_embedding_model(config_path: str = None, **kwargs): >>> input_dir='path/to/data', config_path=my_config_path) >>> >>> # train a model with default settings and overwrites: large batch - >>> # sizes are benefitial for self-supervised training and more + >>> # sizes are beneficial for self-supervised training and more >>> # workers speed up the dataloading process. >>> my_loader = { >>> batch_size: 100, @@ -216,6 +215,6 @@ def embed_images(checkpoint: str, config_path: str = None, **kwargs): config_args = _load_config_file(config_path) config_args = _add_kwargs(config_args, kwargs) - config_args['checkpoint'] = checkpoint + config_args["checkpoint"] = checkpoint return _embed_cli(config_args, is_cli_call=False) diff --git a/lightly/data/__init__.py b/lightly/data/__init__.py index 872bf15cd..285cd7d8b 100644 --- a/lightly/data/__init__.py +++ b/lightly/data/__init__.py @@ -3,20 +3,24 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved +from lightly.data._video import ( + EmptyVideoError, + NonIncreasingTimestampError, + UnseekableTimestampError, + VideoError, +) +from lightly.data.collate import ( + BaseCollateFunction, + DINOCollateFunction, + ImageCollateFunction, + MAECollateFunction, + MoCoCollateFunction, + MSNCollateFunction, + MultiCropCollateFunction, + PIRLCollateFunction, + SimCLRCollateFunction, + SwaVCollateFunction, + VICRegLCollateFunction, + imagenet_normalize, +) from lightly.data.dataset import LightlyDataset -from lightly.data.collate import BaseCollateFunction -from lightly.data.collate import DINOCollateFunction -from lightly.data.collate import ImageCollateFunction -from lightly.data.collate import MAECollateFunction -from lightly.data.collate import MSNCollateFunction -from lightly.data.collate import PIRLCollateFunction -from lightly.data.collate import SimCLRCollateFunction -from lightly.data.collate import MoCoCollateFunction -from lightly.data.collate import MultiCropCollateFunction -from lightly.data.collate import SwaVCollateFunction -from lightly.data.collate import imagenet_normalize -from lightly.data.collate import VICRegLCollateFunction -from lightly.data._video import VideoError -from lightly.data._video import EmptyVideoError -from lightly.data._video import NonIncreasingTimestampError -from lightly.data._video import UnseekableTimestampError diff --git a/lightly/data/_helpers.py b/lightly/data/_helpers.py index e09e6fae2..00ab4fe09 100644 --- a/lightly/data/_helpers.py +++ b/lightly/data/_helpers.py @@ -4,7 +4,7 @@ # All Rights Reserved import os -from typing import List, Set, Optional, Callable, Dict, Any +from typing import Any, Callable, Dict, List, Optional, Set from torchvision import datasets @@ -12,17 +12,26 @@ try: from lightly.data._video import VideoDataset + VIDEO_DATASET_AVAILABLE = True except Exception as e: VIDEO_DATASET_AVAILABLE = False VIDEO_DATASET_ERRORMSG = e -IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', - '.pgm', '.tif', '.tiff', '.webp') +IMG_EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", +) -VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi', '.mpg', - '.hevc', '.m4v', '.webm', '.mpeg') +VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mpg", ".hevc", ".m4v", ".webm", ".mpeg") def _dir_contains_videos(root: str, extensions: tuple): @@ -68,7 +77,7 @@ def _is_lightly_output_dir(dirname: str): True if dirname is "lightly_outputs" else false. """ - return 'lightly_outputs' in dirname + return "lightly_outputs" in dirname def _contains_subdirs(root: str): @@ -82,15 +91,15 @@ def _contains_subdirs(root: str): """ with os.scandir(root) as scan_dir: - return any(not _is_lightly_output_dir(f.name) for f in scan_dir \ - if f.is_dir()) + return any(not _is_lightly_output_dir(f.name) for f in scan_dir if f.is_dir()) def _load_dataset_from_folder( - root: str, transform, - is_valid_file: Optional[Callable[[str], bool]] = None, - tqdm_args: Dict[str, Any] = None, - num_workers_video_frame_counting: int = 0 + root: str, + transform, + is_valid_file: Optional[Callable[[str], bool]] = None, + tqdm_args: Dict[str, Any] = None, + num_workers_video_frame_counting: int = 0, ): """Initializes dataset from folder. 
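The hunks around `_load_dataset_from_folder` keep its dispatch rule intact: videos take precedence over subdirectories, which take precedence over a flat folder of images. A self-contained sketch of that rule (the helper names mirror the diff, but this is a simplification; the real implementation also validates dependencies and supports `is_valid_file` callbacks):

```python
import os

VIDEO_EXTENSIONS = (".mp4", ".mov", ".avi", ".mpg", ".hevc", ".m4v", ".webm", ".mpeg")


def contains_videos(root: str, extensions: tuple) -> bool:
    # True if any file below root has a video extension
    for _, _, files in os.walk(root):
        if any(f.lower().endswith(extensions) for f in files):
            return True
    return False


def choose_dataset_kind(root: str) -> str:
    if contains_videos(root, VIDEO_EXTENSIONS):
        return "VideoDataset"  # frames are read from the videos
    if any(f.is_dir() for f in os.scandir(root)):
        return "ImageFolder"  # subfolders become class labels
    return "DatasetFolder"  # flat folder of plain images
```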
@@ -106,17 +115,19 @@ def _load_dataset_from_folder( """ if not os.path.exists(root): - raise ValueError(f'The input directory {root} does not exist!') + raise ValueError(f"The input directory {root} does not exist!") # if there is a video in the input directory but we do not have # the right dependencies, raise a ValueError contains_videos = _contains_videos(root, VIDEO_EXTENSIONS) if contains_videos and not VIDEO_DATASET_AVAILABLE: - raise ValueError(f'The input directory {root} contains videos ' 'but the VideoDataset is not available. \n' 'Make sure you have installed the right ' 'dependencies. The error from the imported ' f'module was: {VIDEO_DATASET_ERRORMSG}') + raise ValueError( + f"The input directory {root} contains videos " + "but the VideoDataset is not available. \n" + "Make sure you have installed the right " + "dependencies. The error from the imported " + f"module was: {VIDEO_DATASET_ERRORMSG}" + ) if contains_videos: # root contains videos -> create a video dataset @@ -126,20 +137,20 @@ def _load_dataset_from_folder( transform=transform, is_valid_file=is_valid_file, tqdm_args=tqdm_args, - num_workers=num_workers_video_frame_counting + num_workers=num_workers_video_frame_counting, ) elif _contains_subdirs(root): # root contains subdirectories -> create an image folder dataset - dataset = datasets.ImageFolder(root, - transform=transform, - is_valid_file=is_valid_file - ) + dataset = datasets.ImageFolder( + root, transform=transform, is_valid_file=is_valid_file + ) else: # root contains plain images -> create a folder dataset - dataset = DatasetFolder(root, - extensions=IMG_EXTENSIONS, - transform=transform, - is_valid_file=is_valid_file - ) + dataset = DatasetFolder( + root, + extensions=IMG_EXTENSIONS, + transform=transform, + is_valid_file=is_valid_file, + ) return dataset diff --git a/lightly/data/_image.py b/lightly/data/_image.py index 6983ae506..e266c8d75 100644 --- a/lightly/data/_image.py +++ b/lightly/data/_image.py @@ -3,16 +3,18 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved -from typing import List, Tuple, Set - import os +from typing import List, Set, Tuple + import torchvision.datasets as datasets from torchvision import transforms from lightly.data._image_loaders import default_loader -def _make_dataset(directory, extensions=None, is_valid_file=None) -> List[Tuple[str, int]]: +def _make_dataset( + directory, extensions=None, is_valid_file=None +) -> List[Tuple[str, int]]: """Returns a list of all image files with targets in the directory. Args: @@ -30,21 +32,23 @@ def _make_dataset(directory, extensions=None, is_valid_file=None) -> List[Tuple[ if extensions is None: if is_valid_file is None: - ValueError('Both extensions and is_valid_file cannot be None') + raise ValueError("Both extensions and is_valid_file cannot be None") else: _is_valid_file = is_valid_file else: + def is_valid_file_extension(filepath): return filepath.lower().endswith(extensions) + if is_valid_file is None: _is_valid_file = is_valid_file_extension else: + def _is_valid_file(filepath): return is_valid_file_extension(filepath) and is_valid_file(filepath) instances = [] for f in os.scandir(directory): - if not _is_valid_file(f.path): continue @@ -53,12 +57,12 @@ def _is_valid_file(filepath): item = (f.path, 0) instances.append(item) - return sorted(instances, key=lambda x: x[0]) # sort by path + return sorted(instances, key=lambda x: x[0])  # sort by path class DatasetFolder(datasets.VisionDataset): """Implements a dataset folder. 
- + DatasetFolder based on torchvisions implementation. (https://pytorch.org/docs/stable/torchvision/datasets.html#datasetfolder) @@ -81,25 +85,24 @@ class DatasetFolder(datasets.VisionDataset): """ - def __init__(self, - root: str, - loader=default_loader, - extensions=None, - transform=None, - target_transform=None, - is_valid_file=None, - ): - - super(DatasetFolder, self).__init__(root, - transform=transform, - target_transform=target_transform) + def __init__( + self, + root: str, + loader=default_loader, + extensions=None, + transform=None, + target_transform=None, + is_valid_file=None, + ): + super(DatasetFolder, self).__init__( + root, transform=transform, target_transform=target_transform + ) samples = _make_dataset(self.root, extensions, is_valid_file) if len(samples) == 0: - msg = 'Found 0 files in folder: {}\n'.format(self.root) + msg = "Found 0 files in folder: {}\n".format(self.root) if extensions is not None: - msg += 'Supported extensions are: {}'.format( - ','.join(extensions)) + msg += "Supported extensions are: {}".format(",".join(extensions)) raise RuntimeError(msg) self.loader = loader @@ -130,7 +133,5 @@ def __getitem__(self, index: int): return sample, target def __len__(self): - """Returns the number of samples in the dataset. - - """ + """Returns the number of samples in the dataset.""" return len(self.samples) diff --git a/lightly/data/_image_loaders.py b/lightly/data/_image_loaders.py index 0be296f88..d7680de8e 100644 --- a/lightly/data/_image_loaders.py +++ b/lightly/data/_image_loaders.py @@ -12,14 +12,15 @@ def pil_loader(path): # open path as file to avoid ResourceWarning # (https://github.com/python-pillow/Pillow/issues/835) - with open(path, 'rb') as f: + with open(path, "rb") as f: img = Image.open(f) - return img.convert('RGB') + return img.convert("RGB") def accimage_loader(path): try: import accimage + return accimage.Image(path) except IOError: # Potentially a decoding problem, fall back to PIL.Image @@ -28,7 +29,8 @@ def accimage_loader(path): def default_loader(path): from torchvision import get_image_backend - if get_image_backend() == 'accimage': + + if get_image_backend() == "accimage": return accimage_loader(path) else: return pil_loader(path) diff --git a/lightly/data/_utils.py b/lightly/data/_utils.py index 122d6c0d4..86ea687b0 100644 --- a/lightly/data/_utils.py +++ b/lightly/data/_utils.py @@ -5,29 +5,28 @@ import os from typing import * -from PIL import Image -from PIL import UnidentifiedImageError + import tqdm.contrib.concurrent as concurrent +from PIL import Image, UnidentifiedImageError + from lightly.data import LightlyDataset def check_images(data_dir: str) -> Tuple[List[str], List[str]]: - '''Iterate through a directory of images and find corrupt images + """Iterate through a directory of images and find corrupt images Args: data_dir: Path to the directory containing the images Returns: (healthy_images, corrupt_images) - ''' + """ dataset = LightlyDataset(input_dir=data_dir) filenames = dataset.get_filenames() def _is_corrupt(filename): try: - image = Image.open( - os.path.join(data_dir, filename) - ) + image = Image.open(os.path.join(data_dir, filename)) image.load() except (IOError, UnidentifiedImageError): return True @@ -35,12 +34,8 @@ def _is_corrupt(filename): return False mapped = concurrent.thread_map( - _is_corrupt, - filenames, - chunksize=min(32, len(filenames)) + _is_corrupt, filenames, chunksize=min(32, len(filenames)) ) - healthy_images = [f for f, is_corrupt - in zip(filenames, mapped) if not is_corrupt] - corrupt_images 
= [f for f, is_corrupt - in zip(filenames, mapped) if is_corrupt] + healthy_images = [f for f, is_corrupt in zip(filenames, mapped) if not is_corrupt] + corrupt_images = [f for f, is_corrupt in zip(filenames, mapped) if is_corrupt] return healthy_images, corrupt_images diff --git a/lightly/data/_video.py b/lightly/data/_video.py index a2fc57d45..8961ea9ba 100644 --- a/lightly/data/_video.py +++ b/lightly/data/_video.py @@ -4,51 +4,54 @@ # All Rights Reserved import os -from typing import List, Tuple, Dict, Any -from fractions import Fraction import threading -import weakref import warnings +import weakref +from fractions import Fraction +from typing import Any, Dict, List, Tuple import numpy as np -from PIL import Image - import torch import torchvision +from PIL import Image from torch.utils.data import DataLoader, Dataset -from torchvision import datasets -from torchvision import io +from torchvision import datasets, io from tqdm import tqdm try: import av + AV_AVAILABLE = True except ImportError: AV_AVAILABLE = False if io._HAS_VIDEO_OPT: - torchvision.set_video_backend('video_reader') + torchvision.set_video_backend("video_reader") class VideoError(Exception): """Base exception class for errors during video loading.""" + pass class EmptyVideoError(VideoError): """Exception raised when trying to load a frame from an empty video.""" + pass class FrameShapeError(VideoError): """Exception raised when the loaded frame has an unexpected shape.""" + pass class NonIncreasingTimestampError(VideoError): - """Exception raised when trying to load a frame that has a timestamp + """Exception raised when trying to load a frame that has a timestamp equal or lower than the timestamps of previous frames in the video. """ + pass @@ -56,42 +59,44 @@ class UnseekableTimestampError(VideoError): """Exception raised when trying to load a frame that has a timestamp which cannot be seeked to by the video loader. """ + pass -# @guarin 18.02.2022 -# VideoLoader and VideoDataset multi-thread and multi-processing infos -# -------------------------------------------------------------------- -# The VideoDataset class should be safe to use in multi-thread and +# @guarin 18.02.2022 +# VideoLoader and VideoDataset multi-thread and multi-processing infos +# -------------------------------------------------------------------- +# The VideoDataset class should be safe to use in multi-thread and # multi-processing settings. For the multi-processing setting it is assumed that -# a pytorch DataLoader is used. Multi-threading should not be use with the +# a pytorch DataLoader is used. Multi-threading should not be used with the # torchvision pyav video backend as pyav seems to be limited to a single thread. # You will not see any speedups when using it from multiple threads! -#  -# The VideoLoader class is thread safe because it inherits from threading.local. -# When using it within a pytorch DataLoader a new instance should be created -# in each process when using the torchvision video_reader backend, otherwise +# +# The VideoLoader class is thread safe because it inherits from threading.local. +# When using it within a pytorch DataLoader a new instance should be created +# in each process when using the torchvision video_reader backend, otherwise # decoder errors can happen when iterating multiple times over the dataloader. # This is specific to the video_reader backend and does not happen with the pyav # backend. 
-# -# In the VideoDataset class we avoid sharing VideoLoader instances between -# workers by tracking the worker accessing the dataset. VideoLoaders are reset -# if a new worker accesses the dataset. Note that changes to the dataset class -# by a worker are unique to that worker and not seen by other workers or the +# +# In the VideoDataset class we avoid sharing VideoLoader instances between +# workers by tracking the worker accessing the dataset. VideoLoaders are reset +# if a new worker accesses the dataset. Note that changes to the dataset class +# by a worker are unique to that worker and not seen by other workers or the # main process. + class VideoLoader(threading.local): """Implementation of VideoLoader. The VideoLoader is a wrapper around the torchvision video interface. With the VideoLoader you can read specific frames or the next frames of a video. It automatically switches to the `video_reader` backend if available. Reading - sequential frames is significantly faster since it uses the VideoReader + sequential frames is significantly faster since it uses the VideoReader class from torchvision. The video loader automatically detects if you read out subsequent frames and - will use the fast read method if possible. + will use the fast read method if possible. Attributes: path: @@ -124,37 +129,38 @@ class from torchvision. >>> # get next frame >>> frame = video_loader.read_frame() """ + def __init__( self, path: str, timestamps: List[float], - backend: str = 'video_reader', + backend: str = "video_reader", eps: float = 1e-6, ): self.path = path self.timestamps = timestamps self.current_index = None - self.pts_unit='sec' + self.pts_unit = "sec" self.backend = backend self.eps = eps - has_video_reader = io._HAS_VIDEO_OPT and hasattr(io, 'VideoReader') + has_video_reader = io._HAS_VIDEO_OPT and hasattr(io, "VideoReader") - if has_video_reader and self.backend == 'video_reader': - self.reader = io.VideoReader(path = self.path) + if has_video_reader and self.backend == "video_reader": + self.reader = io.VideoReader(path=self.path) else: self.reader = None - def read_frame(self, timestamp = None): + def read_frame(self, timestamp=None): """Reads the next frame or from timestamp. If no timestamp is provided this method just returns the next frame from - the video. This is significantly (up to 10x) faster if the `video_loader` + the video. This is significantly (up to 10x) faster if the `video_reader` backend is available. If a timestamp is provided we first have to seek to the right position and then load the frame. - + Args: - timestamp: + timestamp: Specific timestamp of frame in seconds or None (default: None) Returns: @@ -163,16 +169,14 @@ def read_frame(self, timestamp = None): Raises: StopIteration: If end of video is reached and timestamp is None. - ValueError: + ValueError: If provided timestamp is not in self.timestamps. VideoError: If the frame could not be loaded. """ if not self.timestamps: - raise EmptyVideoError( - f'Cannot load frame from empty video {self.path}.' - ) + raise EmptyVideoError(f"Cannot load frame from empty video {self.path}.") if timestamp is None: # Try to read next frame. @@ -184,7 +188,7 @@ def read_frame(self, timestamp = None): # Reached end of video. raise StopIteration() else: - # Read next frame. + # Read next frame. 
index = self.current_index + 1 timestamp = self.timestamps[index] elif ( @@ -192,46 +196,48 @@ def read_frame(self, timestamp = None): and self.current_index + 1 < len(self.timestamps) and timestamp == self.timestamps[self.current_index + 1] ): - # Provided timestamp is timestamp of next frame. + # Provided timestamp is timestamp of next frame. index = self.current_index + 1 else: - # Random timestamp, must find corresponding index. + # Random timestamp, must find corresponding index. index = self.timestamps.index(timestamp) if self.reader: - # Only seek if we cannot just call next(self.reader). + # Only seek if we cannot just call next(self.reader). if ( - self.current_index is None and index != 0 - or self.current_index is not None and index != self.current_index + 1 + self.current_index is None + and index != 0 + or self.current_index is not None + and index != self.current_index + 1 ): self.reader.seek(timestamp) - # Find next larger timestamp than the one we seek. Used to verify - # that we did not seek too far in the video and that the correct - # frame is returned. + # Find next larger timestamp than the one we seek. Used to verify + # that we did not seek too far in the video and that the correct + # frame is returned. if index + 1 < len(self.timestamps): try: next_timestamp = next( - ts for ts in self.timestamps[index + 1:] if ts > timestamp + ts for ts in self.timestamps[index + 1 :] if ts > timestamp ) except StopIteration: # All timestamps of future frames are smaller. - next_timestamp = float('inf') + next_timestamp = float("inf") else: - # Want to load last frame in video. - next_timestamp = float('inf') + # Want to load last frame in video. + next_timestamp = float("inf") - # Load the frame. + # Load the frame. try: while True: frame_info = next(self.reader) - if frame_info['pts'] < timestamp - self.eps: - # Did not read far enough, let's continue reading more + if frame_info["pts"] < timestamp - self.eps: + # Did not read far enough, let's continue reading more # frames. This can happen due to decreasing timestamps. frame_info = next(self.reader) - elif frame_info['pts'] >= next_timestamp: - # Accidentally read too far, let's seek back to the - # correct position and exit. This can happen due to + elif frame_info["pts"] >= next_timestamp: + # Accidentally read too far, let's seek back to the + # correct position and exit. This can happen due to # imprecise seek. self.reader.seek(timestamp) frame_info = next(self.reader) @@ -240,47 +246,49 @@ def read_frame(self, timestamp = None): break except StopIteration: # Accidentally reached the end of the video, let's seek back to - # the correction position. This can happen due to imprecise seek. + # the correct position. This can happen due to imprecise seek. self.reader.seek(timestamp) try: frame_info = next(self.reader) except StopIteration as ex: - # Seeking to this timestamp simply doesn't work. + # Seeking to this timestamp simply doesn't work. raise UnseekableTimestampError( - f'Cannot seek to frame with timestamp {float(timestamp)} ' - f'in {self.path}.' + f"Cannot seek to frame with timestamp {float(timestamp)} " + f"in {self.path}." ) from ex if ( - frame_info['pts'] < timestamp - self.eps - or frame_info['pts'] >= next_timestamp + frame_info["pts"] < timestamp - self.eps + or frame_info["pts"] >= next_timestamp ): - # We accidentally loaded the wrong frame. This should only + # We accidentally loaded the wrong frame. This should only # happen if self.reader.seek(timestamp) does not seek to the - # correct timestamp. 
In this case there is nothing we can do to + # correct timestamp. In this case there is nothing we can do to # load the correct frame and we alert the user that something # went wrong. warnings.warn( - f'Loaded wrong frame in {self.path}! Tried to load frame ' - f'with index {index} and timestamp {float(timestamp)} but ' + f"Loaded wrong frame in {self.path}! Tried to load frame " + f"with index {index} and timestamp {float(timestamp)} but " f'could only find frame with timestamp {frame_info["pts"]}.' ) # Make sure we have the tensor in correct shape (we want H x W x C) - frame = frame_info['data'].permute(1,2,0) + frame = frame_info["data"].permute(1, 2, 0) self.current_index = index - else: # fallback on pyav - frame, _, _ = io.read_video(self.path, - start_pts=timestamp, - end_pts=timestamp, - pts_unit=self.pts_unit) + else: # fallback on pyav + frame, _, _ = io.read_video( + self.path, + start_pts=timestamp, + end_pts=timestamp, + pts_unit=self.pts_unit, + ) self.current_index = index if len(frame.shape) < 3: raise FrameShapeError( - f'Loaded frame has unexpected shape {frame.shape}. ' - f'Frames are expected to have 3 dimensions: (H, W, C).' + f"Loaded frame has unexpected shape {frame.shape}. " + f"Frames are expected to have 3 dimensions: (H, W, C)." ) # sometimes torchvision returns multiple frames for one timestamp (bug?) @@ -297,7 +305,6 @@ def read_frame(self, timestamp = None): class _TimestampFpsFromVideosDataset(Dataset): - def __init__(self, video_instances: List[str], pts_unit: str): self.video_instances = video_instances self.pts_unit = pts_unit @@ -312,12 +319,12 @@ def __getitem__(self, index): def _make_dataset( - directory, - extensions=None, - is_valid_file=None, - pts_unit='sec', - tqdm_args=None, - num_workers: int = 0 + directory, + extensions=None, + is_valid_file=None, + pts_unit="sec", + tqdm_args=None, + num_workers: int = 0, ): """Returns a list of all video files, timestamps, and offsets. 
@@ -344,15 +351,18 @@ def _make_dataset( tqdm_args = {} if extensions is None: if is_valid_file is None: - ValueError('Both extensions and is_valid_file cannot be None') + raise ValueError("Both extensions and is_valid_file cannot be None") else: _is_valid_file = is_valid_file else: + def is_valid_file_extension(filepath): return filepath.lower().endswith(extensions) + if is_valid_file is None: _is_valid_file = is_valid_file_extension else: + def _is_valid_file(filepath): return is_valid_file_extension(filepath) and is_valid_file(filepath) @@ -361,8 +371,8 @@ def _is_valid_file(filepath): def on_error(error): raise error - for root, _, files in os.walk(directory, onerror=on_error): + for root, _, files in os.walk(directory, onerror=on_error): for fname in files: # skip invalid files if not _is_valid_file(os.path.join(root, fname)): @@ -380,13 +390,13 @@ def on_error(error): _TimestampFpsFromVideosDataset(video_instances, pts_unit=pts_unit), num_workers=num_workers, batch_size=None, - shuffle=False + shuffle=False, ) # actually load the data tqdm_args = dict(tqdm_args) - tqdm_args.setdefault('unit', ' video') - tqdm_args.setdefault('desc', 'Counting frames in videos') + tqdm_args.setdefault("unit", " video") + tqdm_args.setdefault("desc", "Counting frames in videos") timestamps_fpss = list(tqdm(loader, **tqdm_args)) timestamps, fpss = zip(*timestamps_fpss) @@ -397,9 +407,7 @@ def on_error(error): return video_instances, timestamps, offsets, fpss -def _find_non_increasing_timestamps( - timestamps: List[Fraction] - ) -> List[bool]: +def _find_non_increasing_timestamps(timestamps: List[Fraction]) -> List[bool]: """Finds all non-increasing timestamps. Arguments: @@ -413,8 +421,11 @@ def _find_non_increasing_timestamps( """ if len(timestamps) == 0: return [] - is_non_increasing = np.zeros(shape=len(timestamps), dtype=bool, ) - max_timestamp = timestamps[0]-1 + is_non_increasing = np.zeros( + shape=len(timestamps), + dtype=bool, + ) + max_timestamp = timestamps[0] - 1 for i, timestamp in enumerate(timestamps): if timestamp > max_timestamp: max_timestamp = timestamp @@ -443,55 +454,55 @@ class VideoDataset(datasets.VisionDataset): Used to check corrupt files exception_on_non_increasing_timestamp: If True, a NonIncreasingTimestampError is raised when trying to load - a frame that has a timestamp lower or equal to the timestamps of + a frame that has a timestamp lower or equal to the timestamps of previous frames in the same video. 
""" - def __init__(self, - root, - extensions=None, - transform=None, - target_transform=None, - is_valid_file=None, - exception_on_non_increasing_timestamp=True, - tqdm_args: Dict[str, Any]=None, - num_workers: int = 0, - ): - - super(VideoDataset, self).__init__(root, - transform=transform, - target_transform=target_transform) + def __init__( + self, + root, + extensions=None, + transform=None, + target_transform=None, + is_valid_file=None, + exception_on_non_increasing_timestamp=True, + tqdm_args: Dict[str, Any] = None, + num_workers: int = 0, + ): + super(VideoDataset, self).__init__( + root, transform=transform, target_transform=target_transform + ) videos, video_timestamps, offsets, fps = _make_dataset( self.root, extensions, is_valid_file, tqdm_args=tqdm_args, - num_workers=num_workers + num_workers=num_workers, ) if len(videos) == 0: - msg = 'Found 0 videos in folder: {}\n'.format(self.root) + msg = "Found 0 videos in folder: {}\n".format(self.root) if extensions is not None: - msg += 'Supported extensions are: {}'.format( - ','.join(extensions)) + msg += "Supported extensions are: {}".format(",".join(extensions)) raise RuntimeError(msg) self.extensions = extensions self.backend = torchvision.get_video_backend() - self.exception_on_non_increasing_timestamp = exception_on_non_increasing_timestamp + self.exception_on_non_increasing_timestamp = ( + exception_on_non_increasing_timestamp + ) self.videos = videos self.video_timestamps = video_timestamps - self._length = sum(( - len(ts) for ts in self.video_timestamps - )) - # Boolean value for every timestamp in self.video_timestamps. If True + self._length = sum((len(ts) for ts in self.video_timestamps)) + # Boolean value for every timestamp in self.video_timestamps. If True # the timestamp of the frame is non-increasing compared to timestamps of # previous frames in the video. self.video_timestamps_is_non_increasing = [ - _find_non_increasing_timestamps(timestamps) for timestamps in video_timestamps + _find_non_increasing_timestamps(timestamps) + for timestamps in video_timestamps ] # offsets[i] indicates the index of the first frame of the i-th video. @@ -499,25 +510,25 @@ def __init__(self, self.offsets = offsets self.fps = fps - # Current VideoLoader instance and the corresponding video index. We + # Current VideoLoader instance and the corresponding video index. We # only keep track of the last accessed video as this is a good trade-off # between speed and memory requirements. # See https://github.com/lightly-ai/lightly/pull/702 for details. self._video_loader = None self._video_index = None - # Keep unique reference of dataloader worker. We need this to avoid + # Keep unique reference of dataloader worker. We need this to avoid # accidentaly sharing VideoLoader instances between workers. self._worker_ref = None - # Lock to prevent multiple threads creating a new VideoLoader at the + # Lock to prevent multiple threads creating a new VideoLoader at the # same time. self._video_loader_lock = threading.Lock() def __getitem__(self, index): """Returns item at index. - Finds the video of the frame at index with the help of the frame + Finds the video of the frame at index with the help of the frame offsets. Then, loads the frame from the video, applies the transforms, and returns the frame along with the index of the video (as target). 
@@ -549,14 +560,16 @@ def __getitem__(self, index): """ if index < 0 or index >= self.__len__(): - raise IndexError(f'Index {index} is out of bounds for VideoDataset' f' of size {self.__len__()}.') + raise IndexError( + f"Index {index} is out of bounds for VideoDataset" + f" of size {self.__len__()}." + ) # each sample belongs to a video, to load the sample at index, we need # to find the video to which the sample belongs and then read the frame # from this video on the disk. i = len(self.offsets) - 1 - while (self.offsets[i] > index): + while self.offsets[i] > index: i = i - 1 timestamp_idx = index - self.offsets[i] @@ -566,11 +579,11 @@ def __getitem__(self, index): and self.video_timestamps_is_non_increasing[i][timestamp_idx] ): raise NonIncreasingTimestampError( - f'Frame {timestamp_idx} of video {self.videos[i]} has ' - f'a timestamp that is equal or lower than timestamps of previous ' - f'frames in the video. Trying to load this frame might result ' - f'in the wrong frame being returned. Set the VideoDataset.exception_on_non_increasing_timestamp' - f'attribute to False to allow unsafe frame loading.' + f"Frame {timestamp_idx} of video {self.videos[i]} has " + f"a timestamp that is equal or lower than timestamps of previous " + f"frames in the video. Trying to load this frame might result " + f"in the wrong frame being returned. Set the VideoDataset.exception_on_non_increasing_timestamp " + f"attribute to False to allow unsafe frame loading." ) # find and return the frame as PIL image @@ -598,30 +611,32 @@ def get_filename(self, index): """Returns a filename for the frame at index. The filename is created from the video filename, the frame number, and - the video format. The frame number will be zero padded to make sure + the video format. The frame number will be zero padded to make sure all filenames have the same length and can easily be sorted. E.g. when retrieving a sample from the video `my_video.mp4` at frame 153, the filename will be: >>> my_video-153-mp4.png - Args: index: Index of the frame to retrieve. Returns: The filename of the frame as described above. - """ if index < 0 or index >= self.__len__(): - raise IndexError(f'Index {index} is out of bounds for VideoDataset' f' of size {self.__len__()}.') + raise IndexError( + f"Index {index} is out of bounds for VideoDataset" + f" of size {self.__len__()}." + ) # each sample belongs to a video, to load the sample at index, we need # to find the video to which the sample belongs and then read the frame # from this video on the disk. i = len(self.offsets) - 1 - while (self.offsets[i] > index): + while self.offsets[i] > index: i = i - 1 # get filename of the video file @@ -642,9 +657,7 @@ def get_filename(self, index): ) def get_filenames(self) -> List[str]: - """Returns a list filenames for all frames in the dataset. - - """ + """Returns a list of filenames for all frames in the dataset.""" filenames = [] for i, video in enumerate(self.videos): video_name, video_format = self._video_name_format(video) @@ -663,11 +676,9 @@ def get_filenames(self) -> List[str]: return filenames def _video_frame_count(self, video_index: int) -> int: - """Returns the number of frames in the video with the given index. 
- - """ + """Returns the number of frames in the video with the given index.""" if video_index < len(self.offsets) - 1: - n_frames = self.offsets[video_index+1] - self.offsets[video_index] + n_frames = self.offsets[video_index + 1] - self.offsets[video_index] else: n_frames = len(self) - self.offsets[video_index] return n_frames @@ -677,14 +688,14 @@ def _video_name_format(self, video_filename: str) -> Tuple[str, str]: Returns: A (video_name, video_format) tuple where video_name is the filename - relative to self.root and video_format is the file extension, for + relative to self.root and video_format is the file extension, for example 'mp4'. """ video_filename = os.path.relpath(video_filename, self.root) - splits = video_filename.split('.') + splits = video_filename.split(".") video_format = splits[-1] - video_name = '.'.join(splits[:-1]) + video_name = ".".join(splits[:-1]) return video_name, video_format def _format_filename( @@ -693,9 +704,9 @@ def _format_filename( frame_number: int, video_format: str, zero_padding: int = 8, - extension: str = 'png' + extension: str = "png", ) -> str: - return f'{video_name}-{frame_number:0{zero_padding}}-{video_format}.{extension}' + return f"{video_name}-{frame_number:0{zero_padding}}-{video_format}.{extension}" def _get_video_loader(self, video_index: int) -> VideoLoader: """Returns a video loader unique to the current dataloader worker.""" @@ -715,7 +726,9 @@ def _get_video_loader(self, video_index: int) -> VideoLoader: if video_index != self._video_index: video = self.videos[video_index] timestamps = self.video_timestamps[video_index] - self._video_loader = VideoLoader(video, timestamps, backend=self.backend) + self._video_loader = VideoLoader( + video, timestamps, backend=self.backend + ) self._video_index = video_index return self._video_loader diff --git a/lightly/data/collate.py b/lightly/data/collate.py index cda715906..3215530a5 100644 --- a/lightly/data/collate.py +++ b/lightly/data/collate.py @@ -3,23 +3,23 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved +from typing import List, Optional, Tuple, Union + import torch import torch.nn as nn - -from typing import List, Tuple, Union, Optional - -from PIL import Image import torchvision import torchvision.transforms as T +from PIL import Image from lightly.transforms import GaussianBlur, Jigsaw, RandomRotate, RandomSolarization from lightly.transforms.random_crop_and_flip_with_grid import RandomResizedCropAndFlip -from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.transforms.rotation import random_rotation_transform +from lightly.transforms.utils import IMAGENET_NORMALIZE imagenet_normalize = IMAGENET_NORMALIZE # Kept for backwards compatibility + class BaseCollateFunction(nn.Module): """Base class for other collate implementations. 
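`VideoDataset._format_filename`, reformatted in the dataset diff above, encodes the video name, a zero-padded frame number, and the original container format into the exported image name. A quick, self-contained check of the resulting pattern (the function name mirrors the private helper; values are illustrative):

    def format_filename(video_name, frame_number, video_format,
                        zero_padding=8, extension="png"):
        # mirrors VideoDataset._format_filename after the reformat
        return f"{video_name}-{frame_number:0{zero_padding}}-{video_format}.{extension}"

    print(format_filename("my_video", 153, "mp4"))  # my_video-00000153-mp4.png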
@@ -36,7 +36,6 @@ class BaseCollateFunction(nn.Module): """ def __init__(self, transform: torchvision.transforms.Compose): - super(BaseCollateFunction, self).__init__() self.transform = transform @@ -159,7 +158,6 @@ def __init__( rr_degrees: Union[None, float, Tuple[float, float]] = None, normalize: dict = imagenet_normalize, ): - if isinstance(input_size, tuple): input_size_ = max(input_size) else: @@ -289,7 +287,6 @@ def __init__( rr_degrees: Union[None, float, Tuple[float, float]] = None, normalize: dict = imagenet_normalize, ): - super(SimCLRCollateFunction, self).__init__( input_size=input_size, cj_prob=cj_prob, @@ -377,7 +374,6 @@ def __init__( rr_degrees: Union[None, float, Tuple[float, float]] = None, normalize: dict = imagenet_normalize, ): - super(MoCoCollateFunction, self).__init__( input_size=input_size, cj_prob=cj_prob, @@ -423,7 +419,6 @@ def __init__( crop_max_scales: List[float], transforms: T.Compose, ): - if len(crop_sizes) != len(crop_counts): raise ValueError( "Length of crop_sizes and crop_counts must be equal but are" @@ -442,7 +437,6 @@ def __init__( crop_transforms = [] for i in range(len(crop_sizes)): - random_resized_crop = T.RandomResizedCrop( crop_sizes[i], scale=(crop_min_scales[i], crop_max_scales[i]) ) @@ -533,7 +527,6 @@ def __init__( sigmas: Tuple[float, float] = (0.2, 2), normalize: dict = imagenet_normalize, ): - color_jitter = T.ColorJitter( cj_strength, cj_strength, @@ -654,7 +647,6 @@ def __init__( solarization_prob=0.2, normalize=imagenet_normalize, ): - flip_and_color_jitter = T.Compose( [ T.RandomHorizontalFlip(p=hf_prob), @@ -692,7 +684,10 @@ def __init__( global_crop, flip_and_color_jitter, GaussianBlur( - kernel_size=kernel_size, scale=kernel_scale, sigmas=sigmas, prob=gaussian_blur[0] + kernel_size=kernel_size, + scale=kernel_scale, + sigmas=sigmas, + prob=gaussian_blur[0], ), normalize, ] @@ -704,7 +699,10 @@ def __init__( global_crop, flip_and_color_jitter, GaussianBlur( - kernel_size=kernel_size, scale=kernel_scale, sigmas=sigmas, prob=gaussian_blur[1] + kernel_size=kernel_size, + scale=kernel_scale, + sigmas=sigmas, + prob=gaussian_blur[1], ), RandomSolarization(prob=solarization_prob), normalize, @@ -719,7 +717,10 @@ def __init__( ), flip_and_color_jitter, GaussianBlur( - kernel_size=kernel_size, scale=kernel_scale, sigmas=sigmas, prob=gaussian_blur[2] + kernel_size=kernel_size, + scale=kernel_scale, + sigmas=sigmas, + prob=gaussian_blur[2], ), normalize, ] @@ -1033,10 +1034,8 @@ def __init__( random_gray_scale: float = 0.2, normalize: dict = imagenet_normalize, ): - transforms = [] for i in range(len(crop_sizes)): - random_resized_crop = T.RandomResizedCrop( crop_sizes[i], scale=(crop_min_scales[i], crop_max_scales[i]) ) @@ -1146,7 +1145,6 @@ def __init__( rr_degrees: Union[None, float, Tuple[float, float]] = None, normalize: dict = imagenet_normalize, ): - if isinstance(input_size, tuple): input_size_ = max(input_size) else: @@ -1303,7 +1301,6 @@ def forward( torch.Tensor, torch.Tensor, ]: - """ Applies transforms to images in the input batch. diff --git a/lightly/data/dataset.py b/lightly/data/dataset.py index 6b81117e7..adbdf290e 100644 --- a/lightly/data/dataset.py +++ b/lightly/data/dataset.py @@ -3,28 +3,24 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved -import os import bisect +import os import shutil import tempfile +from typing import Any, Callable, Dict, List, Union +import torchvision.datasets as datasets from PIL import Image -from typing import List, Union, Callable, Dict, Any from torch._C import Value - -import torchvision.datasets as datasets from torchvision import transforms -from lightly.data._helpers import _load_dataset_from_folder -from lightly.data._helpers import DatasetFolder +from lightly.data._helpers import DatasetFolder, _load_dataset_from_folder from lightly.data._video import VideoDataset from lightly.utils.io import check_filenames def _get_filename_by_index(dataset, index): - """Default function which maps the index of an image to a filename. - - """ + """Default function which maps the index of an image to a filename.""" if isinstance(dataset, datasets.ImageFolder): # filename is the path of the image relative to the dataset root full_path = dataset.imgs[index][0] @@ -42,17 +38,13 @@ def _get_filename_by_index(dataset, index): def _ensure_dir(path): - """Makes sure that the directory at path exists. - - """ + """Makes sure that the directory at path exists.""" dirname = os.path.dirname(path) os.makedirs(dirname, exist_ok=True) def _copy_image(input_dir, output_dir, filename): - """Copies an image from the input directory to the output directory. - - """ + """Copies an image from the input directory to the output directory.""" source = os.path.join(input_dir, filename) target = os.path.join(output_dir, filename) _ensure_dir(target) @@ -60,9 +52,7 @@ def _copy_image(input_dir, output_dir, filename): def _save_image(image, output_dir, filename, fmt): - """Saves an image in the output directory. - - """ + """Saves an image in the output directory.""" target = os.path.join(output_dir, filename) _ensure_dir(target) try: @@ -71,7 +61,7 @@ def _save_image(image, output_dir, filename, fmt): image.save(target, format=fmt) except ValueError: # could not determine format from filename - image.save(target, format='png') + image.save(target, format="png") def _dump_image(dataset, output_dir, filename, index, fmt): @@ -149,27 +139,24 @@ class LightlyDataset: >>> # `- ... 
""" - def __init__(self, - input_dir: Union[str, None], - transform: transforms.Compose = None, - index_to_filename: - Callable[[datasets.VisionDataset, int], str] = None, - filenames: List[str] = None, - tqdm_args: Dict[str, Any] = None, - num_workers_video_frame_counting: int = 0 - ): - + def __init__( + self, + input_dir: Union[str, None], + transform: transforms.Compose = None, + index_to_filename: Callable[[datasets.VisionDataset, int], str] = None, + filenames: List[str] = None, + tqdm_args: Dict[str, Any] = None, + num_workers_video_frame_counting: int = 0, + ): # can pass input_dir=None to create an "empty" dataset self.input_dir = input_dir if filenames is not None: - filepaths = [ - os.path.join(input_dir, filename) - for filename in filenames - ] + filepaths = [os.path.join(input_dir, filename) for filename in filenames] filepaths = set(filepaths) def is_valid_file(filepath: str): return filepath in filepaths + else: is_valid_file = None @@ -179,12 +166,11 @@ def is_valid_file(filepath: str): transform, is_valid_file=is_valid_file, tqdm_args=tqdm_args, - num_workers_video_frame_counting=num_workers_video_frame_counting + num_workers_video_frame_counting=num_workers_video_frame_counting, ) elif transform is not None: raise ValueError( - 'transform must be None when input_dir is None but is ' - f'{transform}', + "transform must be None when input_dir is None but is " f"{transform}", ) # initialize function to get filename of image @@ -198,10 +184,7 @@ def is_valid_file(filepath: str): check_filenames(self.get_filenames()) @classmethod - def from_torch_dataset(cls, - dataset, - transform=None, - index_to_filename=None): + def from_torch_dataset(cls, dataset, transform=None, index_to_filename=None): """Builds a LightlyDataset from a PyTorch (or torchvision) dataset. Args: @@ -252,22 +235,16 @@ def __getitem__(self, index: int): return sample, target, fname def __len__(self): - """Returns the length of the dataset. - - """ + """Returns the length of the dataset.""" return len(self.dataset) def __add__(self, other): - """Adds another item to the dataset. - - """ + """Adds another item to the dataset.""" raise NotImplementedError() def get_filenames(self) -> List[str]: - """Returns all filenames in the dataset. - - """ - if hasattr(self.dataset, 'get_filenames'): + """Returns all filenames in the dataset.""" + if hasattr(self.dataset, "get_filenames"): return self.dataset.get_filenames() list_of_filenames = [] @@ -276,10 +253,12 @@ def get_filenames(self) -> List[str]: list_of_filenames.append(fname) return list_of_filenames - def dump(self, - output_dir: str, - filenames: Union[List[str], None] = None, - format: Union[str, None] = None): + def dump( + self, + output_dir: str, + filenames: Union[List[str], None] = None, + format: Union[str, None] = None, + ): """Saves images in the dataset to the output directory. Will copy the images from the input directory to the output directory @@ -294,13 +273,13 @@ def dump(self, format: Image format. Can be any pillow image format (png, jpg, ...). By default we try to use the same format as the input data. If - not possible (e.g. for videos) we dump the image + not possible (e.g. for videos) we dump the image as a png image to prevent compression artifacts. 
""" if self.dataset.transform is not None: - raise RuntimeError('Cannot dump dataset which applies transforms!') + raise RuntimeError("Cannot dump dataset which applies transforms!") # create directory if it doesn't exist yet os.makedirs(output_dir, exist_ok=True) @@ -317,8 +296,10 @@ def dump(self, for index, filename in enumerate(all_filenames): filename_index = bisect.bisect_left(filenames, filename) # make sure the filename exists in filenames - if filename_index < len(filenames) and \ - filenames[filename_index] == filename: + if ( + filename_index < len(filenames) + and filenames[filename_index] == filename + ): indices.append(index) # dump images @@ -347,8 +328,7 @@ def get_filepath_from_filename(self, filename: str, image: Image = None): """ - has_input_dir = hasattr(self, 'input_dir') and \ - isinstance(self.input_dir, str) + has_input_dir = hasattr(self, "input_dir") and isinstance(self.input_dir, str) if has_input_dir: path_to_image = os.path.join(self.input_dir, filename) if os.path.isfile(path_to_image): @@ -357,14 +337,14 @@ def get_filepath_from_filename(self, filename: str, image: Image = None): if image is None: raise ValueError( - 'The parameter image must not be None for' - 'VideoDatasets and TorchDatasets' + "The parameter image must not be None for" + "VideoDatasets and TorchDatasets" ) # the file doesn't exist, save it as a jpg and return filepath folder_path = tempfile.mkdtemp() - filepath = os.path.join(folder_path, filename) + '.jpg' - + filepath = os.path.join(folder_path, filename) + ".jpg" + if os.path.dirname(filepath): os.makedirs(os.path.dirname(filepath), exist_ok=True) @@ -373,14 +353,10 @@ def get_filepath_from_filename(self, filename: str, image: Image = None): @property def transform(self): - """Getter for the transform of the dataset. - - """ + """Getter for the transform of the dataset.""" return self.dataset.transform @transform.setter def transform(self, t): - """Setter for the transform of the dataset. - - """ + """Setter for the transform of the dataset.""" self.dataset.transform = t diff --git a/lightly/data/lightly_subset.py b/lightly/data/lightly_subset.py index ef84bace8..e7d4f08fb 100644 --- a/lightly/data/lightly_subset.py +++ b/lightly/data/lightly_subset.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Tuple +from typing import Dict, List, Tuple from lightly.data.dataset import LightlyDataset @@ -21,8 +21,9 @@ def __init__(self, base_dataset: LightlyDataset, filenames_subset: List[str]): fname = base_dataset.index_to_filename(self.dataset, index) dict_base_dataset_filename_index[fname] = index - self.mapping_subset_index_to_baseset_index = \ - [dict_base_dataset_filename_index[filename] for filename in filenames_subset] + self.mapping_subset_index_to_baseset_index = [ + dict_base_dataset_filename_index[filename] for filename in filenames_subset + ] def __getitem__(self, index_subset: int) -> Tuple[object, object, str]: """An overwrite for indexing. @@ -50,9 +51,7 @@ def __len__(self) -> int: return len(self.filenames_subset) def get_filenames(self) -> List[str]: - """Returns all filenames in the subset. 
- - """ + """Returns all filenames in the subset.""" return self.filenames_subset def index_to_filename(self, dataset, index_subset: int): diff --git a/lightly/data/multi_view_collate.py b/lightly/data/multi_view_collate.py index a38022d19..ed3d8e166 100644 --- a/lightly/data/multi_view_collate.py +++ b/lightly/data/multi_view_collate.py @@ -1,8 +1,9 @@ -import torch -from torch import Tensor from typing import List, Tuple, Union from warnings import warn +import torch +from torch import Tensor + class MultiViewCollate: def __call__( diff --git a/lightly/embedding/_base.py b/lightly/embedding/_base.py index ee2e4f406..0a0e0e7cb 100644 --- a/lightly/embedding/_base.py +++ b/lightly/embedding/_base.py @@ -13,17 +13,10 @@ class BaseEmbedding(LightningModule): - """All trainable embeddings must inherit from BaseEmbedding. + """All trainable embeddings must inherit from BaseEmbedding.""" - """ - - def __init__(self, - model, - criterion, - optimizer, - dataloader, - scheduler=None): - """ Constructor + def __init__(self, model, criterion, optimizer, dataloader, scheduler=None): + """Constructor Args: model: (torch.nn.Module) @@ -46,7 +39,6 @@ def forward(self, x0, x1): return self.model(x0, x1) def training_step(self, batch, batch_idx): - # get the two image transformations (x0, x1), _, _ = batch # forward pass of the transformations @@ -54,7 +46,7 @@ def training_step(self, batch, batch_idx): # calculate loss loss = self.criterion(y0, y1) # log loss and return - self.log('loss', loss) + self.log("loss", loss) return loss def configure_optimizers(self): @@ -72,7 +64,7 @@ def train_embedding( checkpoint_callback_config: DictConfig, summary_callback_config: DictConfig, ): - """ Train the model on the provided dataset. + """Train the model on the provided dataset. Args: trainer_config: pylightning_trainer arguments, examples include: @@ -90,7 +82,9 @@ def train_embedding( """ trainer_callbacks = [] - checkpoint_cb = callbacks.create_checkpoint_callback(**checkpoint_callback_config) + checkpoint_cb = callbacks.create_checkpoint_callback( + **checkpoint_callback_config + ) trainer_callbacks.append(checkpoint_cb) summary_cb = callbacks.create_summary_callback( @@ -115,7 +109,5 @@ def train_embedding( self.checkpoint = os.path.join(self.cwd, checkpoint_cb.best_model_path) def embed(self, *args, **kwargs): - """Must be implemented by classes which inherit from BaseEmbedding. - - """ + """Must be implemented by classes which inherit from BaseEmbedding.""" raise NotImplementedError() diff --git a/lightly/embedding/callbacks.py b/lightly/embedding/callbacks.py index 02c1d5f21..af586c23a 100644 --- a/lightly/embedding/callbacks.py +++ b/lightly/embedding/callbacks.py @@ -9,7 +9,7 @@ def create_checkpoint_callback( save_last=False, save_top_k=0, - monitor='loss', + monitor="loss", dirpath=None, ) -> ModelCheckpoint: """Initializes the checkpoint callback. @@ -27,18 +27,18 @@ def create_checkpoint_callback( """ return ModelCheckpoint( dirpath=os.getcwd() if dirpath is None else dirpath, - filename='lightly_epoch_{epoch:d}', + filename="lightly_epoch_{epoch:d}", save_last=save_last, save_top_k=save_top_k, monitor=monitor, - auto_insert_metric_name=False) + auto_insert_metric_name=False, + ) def create_summary_callback( summary_callback_config: DictConfig, trainer_config: DictConfig ) -> ModelSummary: - """Creates a summary callback. - """ + """Creates a summary callback.""" # TODO: Drop support for the "weights_summary" argument. 
weights_summary = trainer_config.get("weights_summary", None) if weights_summary not in [None, "None"]: diff --git a/lightly/embedding/embedding.py b/lightly/embedding/embedding.py index d4e378e7d..d86e25a3a 100644 --- a/lightly/embedding/embedding.py +++ b/lightly/embedding/embedding.py @@ -4,14 +4,14 @@ # All Rights Reserved import time -from typing import List, Union, Tuple +from typing import List, Tuple, Union import numpy as np import torch -import lightly -from lightly.embedding._base import BaseEmbedding from tqdm import tqdm +import lightly +from lightly.embedding._base import BaseEmbedding from lightly.utils.reordering import sort_items_by_keys if lightly._is_prefetch_generator_available(): @@ -68,14 +68,12 @@ def __init__( dataloader: torch.utils.data.DataLoader, scheduler=None, ): - super(SelfSupervisedEmbedding, self).__init__( model, criterion, optimizer, dataloader, scheduler ) - def embed(self, - dataloader: torch.utils.data.DataLoader, - device: torch.device = None + def embed( + self, dataloader: torch.utils.data.DataLoader, device: torch.device = None ) -> Tuple[np.ndarray, List[int], List[str]]: """Embeds images in a vector space. @@ -109,20 +107,15 @@ def embed(self, dataset = dataloader.dataset if lightly._is_prefetch_generator_available(): dataloader = BackgroundGenerator(dataloader, max_prefetch=3) - - pbar = tqdm( - total=len(dataset), - unit='imgs' - ) + + pbar = tqdm(total=len(dataset), unit="imgs") efficiency = 0.0 embeddings = [] labels = [] with torch.no_grad(): - start_timepoint = time.time() - for (image_batch, label_batch, filename_batch) in dataloader: - + for image_batch, label_batch, filename_batch in dataloader: batch_size = image_batch.shape[0] # the following 2 lines are needed to prevent a file handler leak, @@ -159,12 +152,8 @@ def embed(self, labels = labels.cpu().numpy() sorted_filenames = dataset.get_filenames() - sorted_embeddings = sort_items_by_keys( - filenames, embeddings, sorted_filenames - ) - sorted_labels = sort_items_by_keys( - filenames, labels, sorted_filenames - ) + sorted_embeddings = sort_items_by_keys(filenames, embeddings, sorted_filenames) + sorted_labels = sort_items_by_keys(filenames, labels, sorted_filenames) embeddings = np.stack(sorted_embeddings) labels = np.stack(sorted_labels).tolist() diff --git a/lightly/loss/barlow_twins_loss.py b/lightly/loss/barlow_twins_loss.py index 2eee25b98..a9e3a5739 100644 --- a/lightly/loss/barlow_twins_loss.py +++ b/lightly/loss/barlow_twins_loss.py @@ -1,10 +1,11 @@ import torch import torch.distributed as dist + class BarlowTwinsLoss(torch.nn.Module): """Implementation of the Barlow Twins Loss from Barlow Twins[0] paper. This code specifically implements the Figure Algorithm 1 from [0]. - + [0] Zbontar,J. et.al, 2021, Barlow Twins... https://arxiv.org/abs/2103.03230 Examples: @@ -24,19 +25,15 @@ class BarlowTwinsLoss(torch.nn.Module): """ - def __init__( - self, - lambda_param: float = 5e-3, - gather_distributed : bool = False - ): + def __init__(self, lambda_param: float = 5e-3, gather_distributed: bool = False): """Lambda param configuration with default value like in [0] Args: - lambda_param: - Parameter for importance of redundancy reduction term. + lambda_param: + Parameter for importance of redundancy reduction term. Defaults to 5e-3 [0]. gather_distributed: - If True then the cross-correlation matrices from all gpus are + If True then the cross-correlation matrices from all gpus are gathered and summed before the loss calculation. 
""" super(BarlowTwinsLoss, self).__init__() @@ -44,18 +41,17 @@ def __init__( self.gather_distributed = gather_distributed def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: - device = z_a.device # normalize repr. along the batch dimension - z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0) # NxD - z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0) # NxD + z_a_norm = (z_a - z_a.mean(0)) / z_a.std(0) # NxD + z_b_norm = (z_b - z_b.mean(0)) / z_b.std(0) # NxD N = z_a.size(0) D = z_a.size(1) # cross-correlation matrix - c = torch.mm(z_a_norm.T, z_b_norm) / N # DxD + c = torch.mm(z_a_norm.T, z_b_norm) / N # DxD # sum cross-correlation matrix between multiple gpus if self.gather_distributed and dist.is_initialized(): @@ -65,7 +61,7 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: dist.all_reduce(c) # loss - c_diff = (c - torch.eye(D, device=device)).pow(2) # DxD + c_diff = (c - torch.eye(D, device=device)).pow(2) # DxD # multiply off-diagonal elems of c_diff by lambda c_diff[~torch.eye(D, dtype=bool)] *= self.lambda_param loss = c_diff.sum() diff --git a/lightly/loss/dcl_loss.py b/lightly/loss/dcl_loss.py index c3bac9b69..933191a3d 100644 --- a/lightly/loss/dcl_loss.py +++ b/lightly/loss/dcl_loss.py @@ -1,22 +1,20 @@ -from typing import Callable, Optional from functools import partial +from typing import Callable, Optional import torch -from torch import nn -from torch import Tensor +from torch import Tensor, nn from lightly.utils import dist + def negative_mises_fisher_weights( - out0: Tensor, - out1: Tensor, - sigma: float=0.5 + out0: Tensor, out1: Tensor, sigma: float = 0.5 ) -> torch.Tensor: """Negative Mises-Fisher weighting function as presented in Decoupled Contrastive Learning [0]. The implementation was inspired by [1]. - + - [0] Chun-Hsiao Y. et. al., 2021, Decoupled Contrastive Learning https://arxiv.org/abs/2110.06848 - [1] https://github.com/raminnakhli/Decoupled-Contrastive-Learning @@ -32,15 +30,16 @@ def negative_mises_fisher_weights( Returns: A tensor with shape (batch_size,) where each entry is the weight for one of the input images. - + """ - similarity = torch.einsum('nm,nm->n', out0.detach(), out1.detach()) / sigma + similarity = torch.einsum("nm,nm->n", out0.detach(), out1.detach()) / sigma return 2 - out0.shape[0] * nn.functional.softmax(similarity, dim=0) + class DCLLoss(nn.Module): """Implementation of the Decoupled Contrastive Learning Loss from Decoupled Contrastive Learning [0]. - + This code implements Equation 6 in [0], including the sum over all images `i` and views `k`. The loss is reduced to a mean loss over the mini-batch. The implementation was inspired by [1]. @@ -53,14 +52,14 @@ class DCLLoss(nn.Module): Similarities are scaled by inverse temperature. weight_fn: Weighting function `w` from the paper. Scales the loss between the - positive views (views from the same image). No weighting is performed + positive views (views from the same image). No weighting is performed if weight_fn is None. The function must take the two input tensors passed to the forward call as input and return a weight tensor. The returned weight tensor must have the same length as the input tensors. gather_distributed: - If True then negatives from all gpus are gathered before the + If True then negatives from all gpus are gathered before the loss calculation. 
- + Examples: >>> loss_fn = DCLLoss(temperature=0.07) @@ -79,8 +78,9 @@ class DCLLoss(nn.Module): >>> # you can also add a custom weighting function >>> weight_fn = lambda out0, out1: torch.sum((out0 - out1) ** 2, dim=1) >>> loss_fn = DCLLoss(weight_fn=weight_fn) - + """ + def __init__( self, temperature: float = 0.1, @@ -98,7 +98,7 @@ def forward( out1: Tensor, ) -> Tensor: """Forward pass of the DCL loss. - + Args: out0: Output projections of the first set of transformed images. @@ -106,7 +106,7 @@ def forward( out1: Output projections of the second set of transformed images. Shape: (batch_size, embedding_size) - + Returns: Mean loss over the mini-batch. """ @@ -130,10 +130,10 @@ def forward( def _loss(self, out0, out1, out0_all, out1_all): """Calculates DCL loss for out0 with respect to its positives in out1 and the negatives in out1, out0_all, and out1_all. - + This code implements Equation 6 in [0], including the sum over all images `i` but with `k` fixed at 0. - + Args: out0: Output projections of the first set of transformed images. @@ -143,12 +143,12 @@ def _loss(self, out0, out1, out0_all, out1_all): Shape: (batch_size, embedding_size) out0_all: Output projections of the first set of transformed images from - all distributed processes/gpus. Should be equal to out0 in an + all distributed processes/gpus. Should be equal to out0 in an undistributed setting. Shape (batch_size * world_size, embedding_size) out1_all: Output projections of the second set of transformed images from - all distributed processes/gpus. Should be equal to out1 in an + all distributed processes/gpus. Should be equal to out1 in an undistributed setting. Shape (batch_size * world_size, embedding_size) @@ -165,8 +165,8 @@ def _loss(self, out0, out1, out0_all, out1_all): # calculate similarities # here n = batch_size and m = batch_size * world_size. - sim_00 = torch.einsum('nc,mc->nm', out0, out0_all) / self.temperature - sim_01 = torch.einsum('nc,mc->nm', out0, out1_all) / self.temperature + sim_00 = torch.einsum("nc,mc->nm", out0, out0_all) / self.temperature + sim_01 = torch.einsum("nc,mc->nm", out0, out1_all) / self.temperature positive_loss = -sim_01[diag_mask] if self.weight_fn: @@ -174,7 +174,7 @@ def _loss(self, out0, out1, out0_all, out1_all): # remove simliarities between same views of the same image sim_00 = sim_00[~diag_mask].view(batch_size, -1) - # remove similarities between different views of the same images + # remove similarities between different views of the same images # this is the key difference compared to NTXentLoss sim_01 = sim_01[~diag_mask].view(batch_size, -1) @@ -182,12 +182,13 @@ def _loss(self, out0, out1, out0_all, out1_all): negative_loss_01 = torch.logsumexp(sim_01, dim=1) return (positive_loss + negative_loss_00 + negative_loss_01).mean() + class DCLWLoss(DCLLoss): """Implementation of the Weighted Decoupled Contrastive Learning Loss from Decoupled Contrastive Learning [0]. - - This code implements Equation 6 in [0] with a negative Mises-Fisher - weighting function. The loss returns the mean over all images `i` and + + This code implements Equation 6 in [0] with a negative Mises-Fisher + weighting function. The loss returns the mean over all images `i` and views `k` in the mini-batch. The implementation was inspired by [1]. - [0] Chun-Hsiao Y. et. al., 2021, Decoupled Contrastive Learning https://arxiv.org/abs/2110.06848 @@ -200,7 +201,7 @@ class DCLWLoss(DCLLoss): Similar to temperature but applies the inverse scaling in the weighting function. 
gather_distributed: - If True then negatives from all gpus are gathered before the + If True then negatives from all gpus are gathered before the loss calculation. Examples: @@ -217,8 +218,9 @@ class DCLWLoss(DCLLoss): >>> >>> # calculate loss >>> loss = loss_fn(out0, out1) - + """ + def __init__( self, temperature: float = 0.1, diff --git a/lightly/loss/dino_loss.py b/lightly/loss/dino_loss.py index 0a810619b..9c1ba17df 100644 --- a/lightly/loss/dino_loss.py +++ b/lightly/loss/dino_loss.py @@ -8,7 +8,7 @@ class DINOLoss(nn.Module): """ - Implementation of the loss described in 'Emerging Properties in + Implementation of the loss described in 'Emerging Properties in Self-Supervised Vision Transformers'. [0] This implementation follows the code published by the authors. [1] @@ -47,17 +47,18 @@ class DINOLoss(nn.Module): >>> # embed the view with a student and teacher model >>> teacher_out = teacher(view) >>> student_out = student(view) - >>> + >>> >>> # calculate loss >>> loss = loss_fn([teacher_out], [student_out], epoch=0) """ + def __init__( - self, + self, output_dim: int, - warmup_teacher_temp: float = 0.04, + warmup_teacher_temp: float = 0.04, teacher_temp: float = 0.04, - warmup_teacher_temp_epochs: int = 30, + warmup_teacher_temp_epochs: int = 30, student_temp: float = 0.1, center_momentum: float = 0.9, ): @@ -66,23 +67,23 @@ def __init__( self.teacher_temp = teacher_temp self.student_temp = student_temp self.center_momentum = center_momentum - + self.register_buffer("center", torch.zeros(1, 1, output_dim)) # we apply a warm up for the teacher temperature because # a too high temperature makes the training instable at the beginning self.teacher_temp_schedule = torch.linspace( - start=warmup_teacher_temp, + start=warmup_teacher_temp, end=teacher_temp, steps=warmup_teacher_temp_epochs, ) def forward( - self, + self, teacher_out: List[torch.Tensor], student_out: List[torch.Tensor], epoch: int, ) -> torch.Tensor: - """Cross-entropy between softmax outputs of the teacher and student + """Cross-entropy between softmax outputs of the teacher and student networks. Args: @@ -91,8 +92,8 @@ def forward( tensor is assumed to contain features from one view of the batch and have length batch_size. student_out: - List of view feature tensors from the student model. Each tensor - is assumed to contain features from one view of the batch and + List of view feature tensors from the student model. Each tensor + is assumed to contain features from one view of the batch and have length batch_size. epoch: The current training epoch. 
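The `teacher_temp_schedule` built in `DINOLoss.__init__` implements the warmup described in the comment above: the teacher temperature ramps linearly for `warmup_teacher_temp_epochs` epochs and stays at `teacher_temp` afterwards. A small sketch of how `forward` selects the temperature per epoch (the 0.04 to 0.07 ramp is illustrative, not the diff's defaults):

    import torch

    warmup_teacher_temp, teacher_temp, warmup_epochs = 0.04, 0.07, 30
    schedule = torch.linspace(warmup_teacher_temp, teacher_temp, warmup_epochs)

    def teacher_temp_at(epoch: int) -> float:
        # mirrors the branch at the start of DINOLoss.forward
        if epoch < warmup_epochs:
            return schedule[epoch].item()
        return teacher_temp

    assert abs(teacher_temp_at(0) - 0.04) < 1e-6  # start of warmup
    assert teacher_temp_at(100) == 0.07           # constant after warmup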
@@ -106,7 +107,7 @@ def forward( teacher_temp = self.teacher_temp_schedule[epoch] else: teacher_temp = self.teacher_temp - + teacher_out = torch.stack(teacher_out) t_out = F.softmax((teacher_out - self.center) / teacher_temp, dim=-1) @@ -115,11 +116,11 @@ def forward( # calculate feature similarities where: # b = batch_size, t = n_views_teacher, s = n_views_student, d = output_dim - # the diagonal is ignored as it contains features from the same views - loss = -torch.einsum('tbd,sbd->ts', t_out, s_out) + # the diagonal is ignored as it contains features from the same views + loss = -torch.einsum("tbd,sbd->ts", t_out, s_out) loss.fill_diagonal_(0) - # number of loss terms, ignoring the diagonal + # number of loss terms, ignoring the diagonal n_terms = loss.numel() - loss.diagonal().numel() batch_size = teacher_out.shape[1] loss = loss.sum() / (n_terms * batch_size) @@ -142,4 +143,6 @@ def update_center(self, teacher_out: torch.Tensor) -> None: batch_center = batch_center / dist.get_world_size() # ema update - self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) + self.center = self.center * self.center_momentum + batch_center * ( + 1 - self.center_momentum + ) diff --git a/lightly/loss/hypersphere_loss.py b/lightly/loss/hypersphere_loss.py index bf9e74b75..3d1d248a6 100644 --- a/lightly/loss/hypersphere_loss.py +++ b/lightly/loss/hypersphere_loss.py @@ -11,7 +11,7 @@ class HypersphereLoss(torch.nn.Module): """ Implementation of the loss described in 'Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere.' [0] - + [0] Tongzhou Wang. et.al, 2020, ... https://arxiv.org/abs/2005.10242 Note: @@ -46,7 +46,7 @@ class HypersphereLoss(torch.nn.Module): """ - def __init__(self, t=1., lam=1., alpha=2.): + def __init__(self, t=1.0, lam=1.0, alpha=2.0): """Parameters as described in [0] Args: @@ -81,7 +81,9 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: def lalign(x, y): return (x - y).norm(dim=1).pow(self.alpha).mean() + def lunif(x): sq_pdist = torch.pdist(x, p=2).pow(2) return sq_pdist.mul(-self.t).exp().mean().log() + return lalign(x, y) + self.lam * (lunif(x) + lunif(y)) / 2 diff --git a/lightly/loss/memory_bank.py b/lightly/loss/memory_bank.py index f3ecfc265..7bc972b74 100644 --- a/lightly/loss/memory_bank.py +++ b/lightly/loss/memory_bank.py @@ -3,14 +3,16 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved -import torch import functools +import torch + + class MemoryBankModule(torch.nn.Module): """Memory bank implementation This is a parent class to all loss functions implemented by the lightly - Python package. This way, any loss can be used with a memory bank if + Python package. This way, any loss can be used with a memory bank if desired. Attributes: @@ -37,17 +39,20 @@ class MemoryBankModule(torch.nn.Module): """ - def __init__(self, size: int = 2 ** 16): - + def __init__(self, size: int = 2**16): super(MemoryBankModule, self).__init__() if size < 0: - msg = f'Illegal memory bank size {size}, must be non-negative.' + msg = f"Illegal memory bank size {size}, must be non-negative." 
raise ValueError(msg) self.size = size - self.register_buffer("bank", tensor=torch.empty(0, dtype=torch.float), persistent=False) - self.register_buffer("bank_ptr", tensor=torch.empty(0, dtype=torch.long), persistent=False) + self.register_buffer( + "bank", tensor=torch.empty(0, dtype=torch.float), persistent=False + ) + self.register_buffer( + "bank_ptr", tensor=torch.empty(0, dtype=torch.long), persistent=False + ) @torch.no_grad() def _init_memory_bank(self, dim: int): @@ -79,16 +84,15 @@ def _dequeue_and_enqueue(self, batch: torch.Tensor): ptr = int(self.bank_ptr) if ptr + batch_size >= self.size: - self.bank[:, ptr:] = batch[:self.size - ptr].T.detach() + self.bank[:, ptr:] = batch[: self.size - ptr].T.detach() self.bank_ptr[0] = 0 else: - self.bank[:, ptr:ptr + batch_size] = batch.T.detach() + self.bank[:, ptr : ptr + batch_size] = batch.T.detach() self.bank_ptr[0] = ptr + batch_size - def forward(self, - output: torch.Tensor, - labels: torch.Tensor = None, - update: bool = False): + def forward( + self, output: torch.Tensor, labels: torch.Tensor = None, update: bool = False + ): """Query memory bank for additional negative samples Args: diff --git a/lightly/loss/msn_loss.py b/lightly/loss/msn_loss.py index 5d8b64182..daa52745e 100644 --- a/lightly/loss/msn_loss.py +++ b/lightly/loss/msn_loss.py @@ -1,9 +1,9 @@ import math import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -import torch.distributed as dist def prototype_probabilities( @@ -28,6 +28,7 @@ def prototype_probabilities( """ return F.softmax(torch.matmul(queries, prototypes.T) / temperature, dim=1) + def sharpen(probabilities: torch.Tensor, temperature: float) -> torch.Tensor: """Sharpens the probabilities with the given temperature. 
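`MemoryBankModule._dequeue_and_enqueue`, reformatted above, treats the bank as a ring buffer over columns: new features overwrite the oldest ones, and when a batch would run past the end, only the part that fits is written and the pointer wraps to zero. A stripped-down sketch of the same pointer logic (standalone, with illustrative sizes):

    import torch

    size, dim = 8, 2
    bank = torch.zeros(dim, size)  # features are stored column-wise (dim x size)
    ptr = 0

    def dequeue_and_enqueue(batch: torch.Tensor) -> None:
        global ptr
        batch_size = batch.shape[0]
        if ptr + batch_size >= size:
            # write what still fits at the end, then restart at the front
            bank[:, ptr:] = batch[: size - ptr].T.detach()
            ptr = 0
        else:
            bank[:, ptr : ptr + batch_size] = batch.T.detach()
            ptr = ptr + batch_size

    dequeue_and_enqueue(torch.randn(5, dim))  # ptr -> 5
    dequeue_and_enqueue(torch.randn(5, dim))  # wraps: ptr -> 0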
@@ -45,6 +46,7 @@ def sharpen(probabilities: torch.Tensor, temperature: float) -> torch.Tensor: probabilities /= torch.sum(probabilities, dim=1, keepdim=True) return probabilities + @torch.no_grad() def sinkhorn( probabilities: torch.Tensor, @@ -134,6 +136,7 @@ class MSNLoss(nn.Module): >>> loss = loss_fn(anchors_out, targets_out, prototypes=model.prototypes) """ + def __init__( self, temperature: float = 0.1, @@ -176,27 +179,33 @@ def forward( prototypes = F.normalize(prototypes, dim=1) # anchor predictions - anchor_probs = prototype_probabilities(anchors, prototypes, temperature=self.temperature) + anchor_probs = prototype_probabilities( + anchors, prototypes, temperature=self.temperature + ) # target predictions with torch.no_grad(): - target_probs = prototype_probabilities(targets, prototypes, temperature=self.temperature) + target_probs = prototype_probabilities( + targets, prototypes, temperature=self.temperature + ) target_probs = sharpen(target_probs, temperature=target_sharpen_temperature) if self.sinkhorn_iterations > 0: target_probs = sinkhorn( probabilities=target_probs, iterations=self.sinkhorn_iterations, - gather_distributed=self.gather_distributed + gather_distributed=self.gather_distributed, ) target_probs = target_probs.repeat((num_views, 1)) # cross entropy loss - loss = torch.mean(torch.sum(torch.log(anchor_probs**(-target_probs)), dim=1)) + loss = torch.mean(torch.sum(torch.log(anchor_probs ** (-target_probs)), dim=1)) - # mean entropy maximization regularization + # mean entropy maximization regularization if self.me_max_weight > 0: mean_anchor_probs = torch.mean(anchor_probs, dim=0) - me_max_loss = torch.sum(torch.log(mean_anchor_probs**(-mean_anchor_probs))) + me_max_loss = torch.sum( + torch.log(mean_anchor_probs ** (-mean_anchor_probs)) + ) me_max_loss += math.log(float(len(mean_anchor_probs))) loss -= self.me_max_weight * me_max_loss diff --git a/lightly/loss/ntx_ent_loss.py b/lightly/loss/ntx_ent_loss.py index 14c2bfe6c..05c157f3c 100644 --- a/lightly/loss/ntx_ent_loss.py +++ b/lightly/loss/ntx_ent_loss.py @@ -14,20 +14,20 @@ class NTXentLoss(MemoryBankModule): """Implementation of the Contrastive Cross Entropy Loss. This implementation follows the SimCLR[0] paper. If you enable the memory - bank by setting the `memory_bank_size` value > 0 the loss behaves like + bank by setting the `memory_bank_size` value > 0 the loss behaves like the one described in the MoCo[1] paper. - [0] SimCLR, 2020, https://arxiv.org/abs/2002.05709 - [1] MoCo, 2020, https://arxiv.org/abs/1911.05722 - + Attributes: temperature: Scale logits by the inverse of the temperature. memory_bank_size: - Number of negative samples to store in the memory bank. + Number of negative samples to store in the memory bank. Use 0 for SimCLR. For MoCo we typically use numbers like 4096 or 65536. gather_distributed: - If True then negatives from all gpus are gathered before the + If True then negatives from all gpus are gathered before the loss calculation. This flag has no effect if memory_bank_size > 0. 
Raises: @@ -51,10 +51,12 @@ class NTXentLoss(MemoryBankModule): """ - def __init__(self, - temperature: float = 0.5, - memory_bank_size: int = 0, - gather_distributed: bool = False): + def __init__( + self, + temperature: float = 0.5, + memory_bank_size: int = 0, + gather_distributed: bool = False, + ): super(NTXentLoss, self).__init__(size=memory_bank_size) self.temperature = temperature self.gather_distributed = gather_distributed @@ -62,16 +64,15 @@ def __init__(self, self.eps = 1e-8 if abs(self.temperature) < self.eps: - raise ValueError('Illegal temperature: abs({}) < 1e-8' - .format(self.temperature)) + raise ValueError( + "Illegal temperature: abs({}) < 1e-8".format(self.temperature) + ) - def forward(self, - out0: torch.Tensor, - out1: torch.Tensor): + def forward(self, out0: torch.Tensor, out1: torch.Tensor): """Forward pass through Contrastive Cross-Entropy Loss. If used with a memory bank, the samples from the memory bank are used - as negative examples. Otherwise, within-batch samples are used as + as negative examples. Otherwise, within-batch samples are used as negative samples. Args: @@ -94,14 +95,15 @@ def forward(self, out0 = nn.functional.normalize(out0, dim=1) out1 = nn.functional.normalize(out1, dim=1) - # ask memory bank for negative samples and extend it with out1 if - # out1 requires a gradient, otherwise keep the same vectors in the + # ask memory bank for negative samples and extend it with out1 if + # out1 requires a gradient, otherwise keep the same vectors in the # memory bank (this allows for keeping the memory bank constant e.g. # for evaluating the loss on the test set) # out1: shape: (batch_size, embedding_size) # negatives: shape: (embedding_size, memory_bank_size) - out1, negatives = \ - super(NTXentLoss, self).forward(out1, update=out0.requires_grad) + out1, negatives = super(NTXentLoss, self).forward( + out1, update=out0.requires_grad + ) # We use the cosine similarity, which is a dot product (einsum) here, # as all vectors are already normalized to unit length. @@ -113,11 +115,11 @@ def forward(self, # sim_pos is of shape (batch_size, 1) and sim_pos[i] denotes the similarity # of the i-th sample in the batch to its positive pair - sim_pos = torch.einsum('nc,nc->n', out0, out1).unsqueeze(-1) + sim_pos = torch.einsum("nc,nc->n", out0, out1).unsqueeze(-1) # sim_neg is of shape (batch_size, memory_bank_size) and sim_neg[i,j] denotes the similarity # of the i-th sample to the j-th negative sample - sim_neg = torch.einsum('nc,ck->nk', out0, negatives) + sim_neg = torch.einsum("nc,ck->nk", out0, negatives) # set the labels to the first "class", i.e. 
sim_pos, # so that it is maximized in relation to sim_neg @@ -126,7 +128,7 @@ def forward(self, else: # user other samples from batch as negatives - # and create diagonal mask that only selects similarities between + # and create diagonal mask that only selects similarities between # views of the same image if self.gather_distributed and dist.world_size() > 1: # gather hidden representations from other processes @@ -142,11 +144,11 @@ def forward(self, # calculate similiarities # here n = batch_size and m = batch_size * world_size # the resulting vectors have shape (n, m) - logits_00 = torch.einsum('nc,mc->nm', out0, out0_large) / self.temperature - logits_01 = torch.einsum('nc,mc->nm', out0, out1_large) / self.temperature - logits_10 = torch.einsum('nc,mc->nm', out1, out0_large) / self.temperature - logits_11 = torch.einsum('nc,mc->nm', out1, out1_large) / self.temperature - + logits_00 = torch.einsum("nc,mc->nm", out0, out0_large) / self.temperature + logits_01 = torch.einsum("nc,mc->nm", out0, out1_large) / self.temperature + logits_10 = torch.einsum("nc,mc->nm", out1, out0_large) / self.temperature + logits_11 = torch.einsum("nc,mc->nm", out1, out1_large) / self.temperature + # remove simliarities between same views of the same image logits_00 = logits_00[~diag_mask].view(batch_size, -1) logits_11 = logits_11[~diag_mask].view(batch_size, -1) diff --git a/lightly/loss/regularizer/co2.py b/lightly/loss/regularizer/co2.py index 8daf1cc7c..2145cf04a 100644 --- a/lightly/loss/regularizer/co2.py +++ b/lightly/loss/regularizer/co2.py @@ -4,6 +4,7 @@ # All Rights Reserved import torch + from lightly.loss.memory_bank import MemoryBankModule @@ -34,38 +35,30 @@ class CO2Regularizer(MemoryBankModule): >>> >>> # feed through the MoCo model >>> out0, out1 = model(t0, t1) - >>> + >>> >>> # calculate loss and apply regularizer >>> loss = loss_fn(out0, out1) + co2(out0, out1) """ - def __init__(self, - alpha: float = 1, - t_consistency: float = 0.05, - memory_bank_size: int = 0): - + def __init__( + self, alpha: float = 1, t_consistency: float = 0.05, memory_bank_size: int = 0 + ): super(CO2Regularizer, self).__init__(size=memory_bank_size) # try-catch the KLDivLoss construction for backwards compatability self.log_target = True try: - self.kl_div = torch.nn.KLDivLoss( - reduction='batchmean', - log_target=True - ) + self.kl_div = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) except TypeError: self.log_target = False - self.kl_div = torch.nn.KLDivLoss( - reduction='batchmean' - ) + self.kl_div = torch.nn.KLDivLoss(reduction="batchmean") self.t_consistency = t_consistency self.alpha = alpha - def _get_pseudo_labels(self, - out0: torch.Tensor, - out1: torch.Tensor, - negatives: torch.Tensor = None): + def _get_pseudo_labels( + self, out0: torch.Tensor, out1: torch.Tensor, negatives: torch.Tensor = None + ): """Computes the soft pseudo labels across negative samples. 
Args: @@ -90,8 +83,8 @@ def _get_pseudo_labels(self, if negatives is None: # use second batch as negative samples # l_pos has shape bsz x 1 and l_neg has shape bsz x bsz - l_pos = torch.einsum('nc,nc->n', [out0, out1]).unsqueeze(-1) - l_neg = torch.einsum('nc,ck->nk', [out0, out1.t()]) + l_pos = torch.einsum("nc,nc->n", [out0, out1]).unsqueeze(-1) + l_neg = torch.einsum("nc,ck->nk", [out0, out1.t()]) # remove elements on the diagonal # l_neg has shape bsz x (bsz - 1) l_neg = l_neg.masked_select( @@ -101,21 +94,18 @@ def _get_pseudo_labels(self, # use memory bank as negative samples # l_pos has shape bsz x 1 and l_neg has shape bsz x memory_bank_size negatives = negatives.to(out0.device) - l_pos = torch.einsum('nc,nc->n', [out0, out1]).unsqueeze(-1) - l_neg = torch.einsum('nc,ck->nk', [out0, negatives.clone().detach()]) - + l_pos = torch.einsum("nc,nc->n", [out0, out1]).unsqueeze(-1) + l_neg = torch.einsum("nc,ck->nk", [out0, negatives.clone().detach()]) + # concatenate such that positive samples are at index 0 logits = torch.cat([l_pos, l_neg], dim=1) # divide by temperature logits = logits / self.t_consistency - # the input to kl_div is expected to be log(p) + # the input to kl_div is expected to be log(p) return torch.nn.functional.log_softmax(logits, dim=-1) - - def forward(self, - out0: torch.Tensor, - out1: torch.Tensor): + def forward(self, out0: torch.Tensor, out1: torch.Tensor): """Computes the CO2 regularization term for two model outputs. Args: @@ -133,18 +123,17 @@ def forward(self, out0 = torch.nn.functional.normalize(out0, dim=1) out1 = torch.nn.functional.normalize(out1, dim=1) - # ask memory bank for negative samples and extend it with out1 if - # out1 requires a gradient, otherwise keep the same vectors in the + # ask memory bank for negative samples and extend it with out1 if + # out1 requires a gradient, otherwise keep the same vectors in the # memory bank (this allows for keeping the memory bank constant e.g. # for evaluating the loss on the test set) # if the memory_bank size is 0, negatives will be None - out1, negatives = \ - super(CO2Regularizer, self).forward(out1, update=True) - + out1, negatives = super(CO2Regularizer, self).forward(out1, update=True) + # get log probabilities p = self._get_pseudo_labels(out0, out1, negatives) q = self._get_pseudo_labels(out1, out0, negatives) - + # calculate symmetrized kullback leibler divergence if self.log_target: div = self.kl_div(p, q) + self.kl_div(q, p) diff --git a/lightly/loss/swav_loss.py b/lightly/loss/swav_loss.py index 9664c4905..c7a33cf2e 100644 --- a/lightly/loss/swav_loss.py +++ b/lightly/loss/swav_loss.py @@ -1,24 +1,24 @@ from typing import List import torch +import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F -import torch.distributed as dist @torch.no_grad() def sinkhorn( - out: torch.Tensor, - iterations: int = 3, + out: torch.Tensor, + iterations: int = 3, epsilon: float = 0.05, gather_distributed: bool = False, ) -> torch.Tensor: """Distributed sinkhorn algorithm. As outlined in [0] and implemented in [1]. - + [0]: SwaV, 2020, https://arxiv.org/abs/2006.09882 - [1]: https://github.com/facebookresearch/swav/ + [1]: https://github.com/facebookresearch/swav/ Args: out: @@ -29,11 +29,11 @@ def sinkhorn( Temperature parameter. gather_distributed: If True then features from all gpus are gathered to calculate the - soft codes Q. + soft codes Q. Returns: Soft codes Q assigning each feature to a prototype. 
- + """ world_size = 1 if gather_distributed and dist.is_initialized(): @@ -74,22 +74,23 @@ class SwaVLoss(nn.Module): Temperature parameter used in the sinkhorn algorithm. sinkhorn_gather_distributed: If True then features from all gpus are gathered to calculate the - soft codes in the sinkhorn algorithm. - + soft codes in the sinkhorn algorithm. + """ - def __init__(self, - temperature: float = 0.1, - sinkhorn_iterations: int = 3, - sinkhorn_epsilon: float = 0.05, - sinkhorn_gather_distributed: bool = False): + def __init__( + self, + temperature: float = 0.1, + sinkhorn_iterations: int = 3, + sinkhorn_epsilon: float = 0.05, + sinkhorn_gather_distributed: bool = False, + ): super(SwaVLoss, self).__init__() self.temperature = temperature self.sinkhorn_iterations = sinkhorn_iterations self.sinkhorn_epsilon = sinkhorn_epsilon self.sinkhorn_gather_distributed = sinkhorn_gather_distributed - def subloss(self, z: torch.Tensor, q: torch.Tensor): """Calculates the cross entropy for the SwaV prediction problem. @@ -103,15 +104,16 @@ def subloss(self, z: torch.Tensor, q: torch.Tensor): Cross entropy between predictions z and codes q. """ - return - torch.mean( + return -torch.mean( torch.sum(q * F.log_softmax(z / self.temperature, dim=1), dim=1) ) - - def forward(self, - high_resolution_outputs: List[torch.Tensor], - low_resolution_outputs: List[torch.Tensor], - queue_outputs: List[torch.Tensor]=None): + def forward( + self, + high_resolution_outputs: List[torch.Tensor], + low_resolution_outputs: List[torch.Tensor], + queue_outputs: List[torch.Tensor] = None, + ): """Computes the SwaV loss for a set of high and low resolution outputs. Args: @@ -134,7 +136,7 @@ def forward(self, n_crops = len(high_resolution_outputs) + len(low_resolution_outputs) # multi-crop iterations - loss = 0. + loss = 0.0 for i in range(len(high_resolution_outputs)): # compute codes of i-th high resolution crop with torch.no_grad(): @@ -154,10 +156,10 @@ def forward(self, # Drop queue similarities if queue_outputs is not None: - q = q[:len(high_resolution_outputs[i])] + q = q[: len(high_resolution_outputs[i])] # compute subloss for each pair of crops - subloss = 0. + subloss = 0.0 for v in range(len(high_resolution_outputs)): if v != i: subloss += self.subloss(high_resolution_outputs[v], q) diff --git a/lightly/loss/sym_neg_cos_sim_loss.py b/lightly/loss/sym_neg_cos_sim_loss.py index d906159ad..19af40930 100644 --- a/lightly/loss/sym_neg_cos_sim_loss.py +++ b/lightly/loss/sym_neg_cos_sim_loss.py @@ -3,14 +3,16 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved -import torch import warnings +import torch + + class SymNegCosineSimilarityLoss(torch.nn.Module): """Implementation of the Symmetrized Loss used in the SimSiam[0] paper. [0] SimSiam, 2020, https://arxiv.org/abs/2011.10566 - + Examples: >>> # initialize loss function @@ -39,37 +41,36 @@ def __init__(self) -> None: ) def _neg_cosine_simililarity(self, x, y): - v = - torch.nn.functional.cosine_similarity(x, y.detach(), dim=-1).mean() + v = -torch.nn.functional.cosine_similarity(x, y.detach(), dim=-1).mean() return v - def forward(self, - out0: torch.Tensor, - out1: torch.Tensor): + def forward(self, out0: torch.Tensor, out1: torch.Tensor): """Forward pass through Symmetric Loss. - Args: - out0: - Output projections of the first set of transformed images. - Expects the tuple to be of the form (z0, p0), where z0 is - the output of the backbone and projection mlp, and p0 is the - output of the prediction head. 
- out1: - Output projections of the second set of transformed images. - Expects the tuple to be of the form (z1, p1), where z1 is - the output of the backbone and projection mlp, and p1 is the - output of the prediction head. - - Returns: - Contrastive Cross Entropy Loss value. - - Raises: - ValueError if shape of output is not multiple of batch_size. + Args: + out0: + Output projections of the first set of transformed images. + Expects the tuple to be of the form (z0, p0), where z0 is + the output of the backbone and projection mlp, and p0 is the + output of the prediction head. + out1: + Output projections of the second set of transformed images. + Expects the tuple to be of the form (z1, p1), where z1 is + the output of the backbone and projection mlp, and p1 is the + output of the prediction head. + + Returns: + Contrastive Cross Entropy Loss value. + + Raises: + ValueError if shape of output is not multiple of batch_size. """ z0, p0 = out0 z1, p1 = out1 - loss = self._neg_cosine_simililarity(p0, z1) / 2 + \ - self._neg_cosine_simililarity(p1, z0) / 2 + loss = ( + self._neg_cosine_simililarity(p0, z1) / 2 + + self._neg_cosine_simililarity(p1, z0) / 2 + ) return loss - \ No newline at end of file diff --git a/lightly/loss/tico_loss.py b/lightly/loss/tico_loss.py index 29bc4f118..f08047fca 100644 --- a/lightly/loss/tico_loss.py +++ b/lightly/loss/tico_loss.py @@ -5,16 +5,17 @@ from lightly.utils.dist import gather + class TiCoLoss(torch.nn.Module): """Implementation of the Tico Loss from Tico[0] paper. - This implementation takes inspiration from the code published + This implementation takes inspiration from the code published by sayannag using Lightly. [1] [0] Jiachen Zhu et. al, 2022, Tico... https://arxiv.org/abs/2206.10698 [1] https://github.com/sayannag/TiCo-pytorch - + Attributes: - + Args: beta: Coefficient for the EMA update of the covariance @@ -25,9 +26,9 @@ class TiCoLoss(torch.nn.Module): gather_distributed: If True then the cross-correlation matrices from all gpus are gathered and summed before the loss calculation. - + Examples: - + >>> # initialize loss function >>> loss_fn = TiCoLoss() >>> @@ -54,7 +55,12 @@ def __init__( self.C = None self.gather_distributed = gather_distributed - def forward(self, z_a: torch.Tensor, z_b: torch.Tensor, update_covariance_matrix: bool = True) -> torch.Tensor: + def forward( + self, + z_a: torch.Tensor, + z_b: torch.Tensor, + update_covariance_matrix: bool = True, + ) -> torch.Tensor: """Tico Loss computation. It maximize the agreement among embeddings of different distorted versions of the same image while avoiding collapse using Covariance matrix. @@ -71,8 +77,12 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor, update_covariance_matrix """ - assert z_a.shape[0] > 1 and z_b.shape[0] > 1, f"z_a and z_b must have batch size > 1 but found {z_a.shape[0]} and {z_b.shape[0]}" - assert z_a.shape == z_b.shape, f"z_a and z_b must have same shape but found {z_a.shape} and {z_b.shape}." + assert ( + z_a.shape[0] > 1 and z_b.shape[0] > 1 + ), f"z_a and z_b must have batch size > 1 but found {z_a.shape[0]} and {z_b.shape[0]}" + assert ( + z_a.shape == z_b.shape + ), f"z_a and z_b must have same shape but found {z_a.shape} and {z_b.shape}." 
# gather all batches if self.gather_distributed and dist.is_initialized(): @@ -82,22 +92,26 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor, update_covariance_matrix z_b = torch.cat(gather(z_b), dim=0) # normalize image - z_a = torch.nn.functional.normalize(z_a, dim = 1) - z_b = torch.nn.functional.normalize(z_b, dim = 1) - + z_a = torch.nn.functional.normalize(z_a, dim=1) + z_b = torch.nn.functional.normalize(z_b, dim=1) + # compute auxiliary matrix B - B = torch.mm(z_a.T, z_a)/z_a.shape[0] + B = torch.mm(z_a.T, z_a) / z_a.shape[0] # init covariance matrix if self.C is None: - self.C = B.new_zeros(B.shape).detach() + self.C = B.new_zeros(B.shape).detach() # compute loss C = self.beta * self.C + (1 - self.beta) * B - loss = 1 - (z_a * z_b).sum(dim=1).mean() + self.rho * (torch.mm(z_a, C) * z_a).sum(dim=1).mean() + loss = ( + 1 + - (z_a * z_b).sum(dim=1).mean() + + self.rho * (torch.mm(z_a, C) * z_a).sum(dim=1).mean() + ) # update covariance matrix if update_covariance_matrix: self.C = C.detach() - + return loss diff --git a/lightly/loss/vicreg_loss.py b/lightly/loss/vicreg_loss.py index 6faac2ec2..bfbd6a436 100644 --- a/lightly/loss/vicreg_loss.py +++ b/lightly/loss/vicreg_loss.py @@ -4,15 +4,16 @@ from lightly.utils.dist import gather + class VICRegLoss(torch.nn.Module): """Implementation of the VICReg Loss from VICReg[0] paper. This implementation follows the code published by the authors. [1] [0] Bardes, A. et. al, 2022, VICReg... https://arxiv.org/abs/2105.04906 [1] https://github.com/facebookresearch/vicreg/ - + Examples: - + >>> # initialize loss function >>> loss_fn = VICRegLoss() >>> @@ -32,8 +33,8 @@ def __init__( lambda_param: float = 25.0, mu_param: float = 25.0, nu_param: float = 1.0, - gather_distributed : bool = False, - eps = 0.0001 + gather_distributed: bool = False, + eps=0.0001, ): """Lambda, mu and nu params configuration with default value like in [0] Args: @@ -62,8 +63,12 @@ def __init__( self.eps = eps def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: - assert z_a.shape[0] > 1 and z_b.shape[0] > 1, f"z_a and z_b must have batch size > 1 but found {z_a.shape[0]} and {z_b.shape[0]}" - assert z_a.shape == z_b.shape, f"z_a and z_b must have same shape but found {z_a.shape} and {z_b.shape}." + assert ( + z_a.shape[0] > 1 and z_b.shape[0] > 1 + ), f"z_a and z_b must have batch size > 1 but found {z_a.shape[0]} and {z_b.shape[0]}" + assert ( + z_a.shape == z_b.shape + ), f"z_a and z_b must have same shape but found {z_a.shape} and {z_b.shape}." # invariance term of the loss repr_loss = F.mse_loss(z_a, z_b) @@ -76,8 +81,8 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: z_b = torch.cat(gather(z_b), dim=0) # normalize repr. 
along the batch dimension - z_a = z_a - z_a.mean(0) # NxD - z_b = z_b - z_b.mean(0) # NxD + z_a = z_a - z_a.mean(0) # NxD + z_b = z_b - z_b.mean(0) # NxD N = z_a.size(0) D = z_a.size(1) @@ -96,9 +101,15 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor: off_diag_cov_x = cov_x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() off_diag_cov_y = cov_y.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() - cov_loss = off_diag_cov_x.pow_(2).sum().div(D) + off_diag_cov_y.pow_(2).sum().div(D) + cov_loss = off_diag_cov_x.pow_(2).sum().div(D) + off_diag_cov_y.pow_( + 2 + ).sum().div(D) # loss - loss = self.lambda_param * repr_loss + self.mu_param * std_loss + self.nu_param * cov_loss + loss = ( + self.lambda_param * repr_loss + + self.mu_param * std_loss + + self.nu_param * cov_loss + ) return loss diff --git a/lightly/loss/vicregl_loss.py b/lightly/loss/vicregl_loss.py index c075a739d..0b504d7da 100644 --- a/lightly/loss/vicregl_loss.py +++ b/lightly/loss/vicregl_loss.py @@ -1,12 +1,13 @@ +import copy +from typing import List, Tuple + import torch import torch.distributed as dist import torch.nn.functional as F -from lightly.utils.dist import gather from lightly.loss.vicreg_loss import VICRegLoss from lightly.models.utils import nearest_neighbors -from typing import List, Tuple -import copy +from lightly.utils.dist import gather class VICRegLLoss(torch.nn.Module): @@ -127,7 +128,8 @@ def local_loss( A tensor of grids for the local maps. It has size: [batch_size, grid_size, grid_size, 2] Returns: - A tensor of the local loss between the two sets of maps. It has size: [batch_size]""" + A tensor of the local loss between the two sets of maps. It has size: [batch_size] + """ inv_loss = 0.0 diff --git a/lightly/models/__init__.py b/lightly/models/__init__.py index 5a500e940..ed2f4613e 100644 --- a/lightly/models/__init__.py +++ b/lightly/models/__init__.py @@ -18,14 +18,12 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved -from lightly.models.resnet import ResNetGenerator +from lightly.models import utils from lightly.models.barlowtwins import BarlowTwins -from lightly.models.simclr import SimCLR -from lightly.models.simsiam import SimSiam from lightly.models.byol import BYOL from lightly.models.moco import MoCo from lightly.models.nnclr import NNCLR -from lightly.models.zoo import ZOO -from lightly.models.zoo import checkpoints - -from lightly.models import utils +from lightly.models.resnet import ResNetGenerator +from lightly.models.simclr import SimCLR +from lightly.models.simsiam import SimSiam +from lightly.models.zoo import ZOO, checkpoints diff --git a/lightly/models/_momentum.py b/lightly/models/_momentum.py index 3a9f944eb..f7cdec507 100644 --- a/lightly/models/_momentum.py +++ b/lightly/models/_momentum.py @@ -10,19 +10,15 @@ def _deactivate_requires_grad(params): - """Deactivates the requires_grad flag for all parameters. - - """ + """Deactivates the requires_grad flag for all parameters.""" for param in params: param.requires_grad = False def _do_momentum_update(prev_params, params, m): - """Updates the weights of the previous parameters. - - """ + """Updates the weights of the previous parameters.""" for prev_param, param in zip(prev_params, params): - prev_param.data = prev_param.data * m + param.data * (1. 
- m) + prev_param.data = prev_param.data * m + param.data * (1.0 - m) class _MomentumEncoderMixin: @@ -64,23 +60,19 @@ class _MomentumEncoderMixin: momentum_projection_head: nn.Module def _init_momentum_encoder(self): - """Initializes momentum backbone and a momentum projection head. - - """ + """Initializes momentum backbone and a momentum projection head.""" assert self.backbone is not None assert self.projection_head is not None self.momentum_backbone = copy.deepcopy(self.backbone) self.momentum_projection_head = copy.deepcopy(self.projection_head) - + _deactivate_requires_grad(self.momentum_backbone.parameters()) _deactivate_requires_grad(self.momentum_projection_head.parameters()) @torch.no_grad() def _momentum_update(self, m: float = 0.999): - """Performs the momentum update for the backbone and projection head. - - """ + """Performs the momentum update for the backbone and projection head.""" _do_momentum_update( self.momentum_backbone.parameters(), self.backbone.parameters(), @@ -94,17 +86,13 @@ def _momentum_update(self, m: float = 0.999): @torch.no_grad() def _batch_shuffle(self, batch: torch.Tensor): - """Returns the shuffled batch and the indices to undo. - - """ + """Returns the shuffled batch and the indices to undo.""" batch_size = batch.shape[0] shuffle = torch.randperm(batch_size, device=batch.device) return batch[shuffle], shuffle @torch.no_grad() def _batch_unshuffle(self, batch: torch.Tensor, shuffle: torch.Tensor): - """Returns the unshuffled batch. - - """ + """Returns the unshuffled batch.""" unshuffle = torch.argsort(shuffle) return batch[unshuffle] diff --git a/lightly/models/barlowtwins.py b/lightly/models/barlowtwins.py index 32bfad783..7b3bf90ab 100644 --- a/lightly/models/barlowtwins.py +++ b/lightly/models/barlowtwins.py @@ -42,7 +42,6 @@ def __init__( proj_hidden_dim: int = 8192, out_dim: int = 8192, ): - super(BarlowTwins, self).__init__() self.backbone = backbone @@ -66,7 +65,6 @@ def __init__( def forward( self, x0: torch.Tensor, x1: torch.Tensor = None, return_features: bool = False ): - """Forward pass through BarlowTwins. Extracts features with the backbone and applies the projection diff --git a/lightly/models/batchnorm.py b/lightly/models/batchnorm.py index 0e8906d13..e4424bbad 100644 --- a/lightly/models/batchnorm.py +++ b/lightly/models/batchnorm.py @@ -20,36 +20,29 @@ class SplitBatchNorm(nn.BatchNorm2d): Number of splits. 
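Reviewer note: the `_batch_shuffle`/`_batch_unshuffle` pair reformatted above works because `torch.argsort` of a permutation is its inverse; a quick round-trip check:

```python
# Round-trip sketch for the shuffle/unshuffle helpers above: indexing with
# argsort(shuffle) undoes indexing with shuffle.
import torch

batch = torch.randn(8, 3)
shuffle = torch.randperm(batch.shape[0])
shuffled = batch[shuffle]
restored = shuffled[torch.argsort(shuffle)]
assert torch.equal(restored, batch)
```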
""" + def __init__(self, num_features, num_splits, **kw): super().__init__(num_features, **kw) self.num_splits = num_splits self.register_buffer( - 'running_mean', torch.zeros(num_features*self.num_splits) - ) - self.register_buffer( - 'running_var', torch.ones(num_features*self.num_splits) + "running_mean", torch.zeros(num_features * self.num_splits) ) + self.register_buffer("running_var", torch.ones(num_features * self.num_splits)) def train(self, mode=True): # lazily collate stats when we are going to use them if (self.training is True) and (mode is False): - self.running_mean = \ - torch.mean( - self.running_mean.view(self.num_splits, self.num_features), - dim=0 - ).repeat(self.num_splits) - self.running_var = \ - torch.mean( - self.running_var.view(self.num_splits, self.num_features), - dim=0 - ).repeat(self.num_splits) + self.running_mean = torch.mean( + self.running_mean.view(self.num_splits, self.num_features), dim=0 + ).repeat(self.num_splits) + self.running_var = torch.mean( + self.running_var.view(self.num_splits, self.num_features), dim=0 + ).repeat(self.num_splits) return super().train(mode) def forward(self, input): - """Computes the SplitBatchNorm on the input. - - """ + """Computes the SplitBatchNorm on the input.""" # get input shape N, C, H, W = input.shape @@ -57,33 +50,32 @@ def forward(self, input): # use the stats from the first split if self.training or not self.track_running_stats: result = nn.functional.batch_norm( - input.view(-1, C*self.num_splits, H, W), - self.running_mean, self.running_var, + input.view(-1, C * self.num_splits, H, W), + self.running_mean, + self.running_var, self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits), True, self.momentum, - self.eps + self.eps, ).view(N, C, H, W) else: result = nn.functional.batch_norm( input, - self.running_mean[:self.num_features], - self.running_var[:self.num_features], + self.running_mean[: self.num_features], + self.running_var[: self.num_features], self.weight, self.bias, False, - self.momentum, - self.eps + self.momentum, + self.eps, ) - + return result def get_norm_layer(num_features: int, num_splits: int, **kw): - """Utility to switch between BatchNorm2d and SplitBatchNorm. 
- - """ + """Utility to switch between BatchNorm2d and SplitBatchNorm.""" if num_splits > 0: return SplitBatchNorm(num_features, num_splits) else: diff --git a/lightly/models/byol.py b/lightly/models/byol.py index 6dd74cb1a..c860e608f 100644 --- a/lightly/models/byol.py +++ b/lightly/models/byol.py @@ -8,8 +8,8 @@ import torch import torch.nn as nn -from lightly.models.modules import BYOLProjectionHead from lightly.models._momentum import _MomentumEncoderMixin +from lightly.models.modules import BYOLProjectionHead def _get_byol_mlp(num_ftrs: int, hidden_dim: int, out_dim: int): @@ -52,7 +52,6 @@ def __init__( out_dim: int = 256, m: float = 0.9, ): - super(BYOL, self).__init__() self.backbone = backbone @@ -113,7 +112,6 @@ def _forward(self, x0: torch.Tensor, x1: torch.Tensor = None): # forward pass of second input x1 with torch.no_grad(): - f1 = self.momentum_backbone(x1).flatten(start_dim=1) out1 = self.momentum_projection_head(f1) diff --git a/lightly/models/moco.py b/lightly/models/moco.py index e8148d178..fde1f0234 100644 --- a/lightly/models/moco.py +++ b/lightly/models/moco.py @@ -40,7 +40,6 @@ def __init__( m: float = 0.999, batch_shuffle: bool = False, ): - super(MoCo, self).__init__() self.backbone = backbone @@ -117,7 +116,6 @@ def forward( # forward pass of second input x1 with torch.no_grad(): - # shuffle for batchnorm if self.batch_shuffle: x1, shuffle = self._batch_shuffle(x1) diff --git a/lightly/models/modules/__init__.py b/lightly/models/modules/__init__.py index 4d500c40a..3d8fe3d7c 100644 --- a/lightly/models/modules/__init__.py +++ b/lightly/models/modules/__init__.py @@ -8,27 +8,30 @@ # Copyright (c) 2021. Lightly AG and its affiliates. # All Rights Reserved -from lightly.models.modules.heads import BarlowTwinsProjectionHead -from lightly.models.modules.heads import BYOLProjectionHead -from lightly.models.modules.heads import BYOLPredictionHead -from lightly.models.modules.heads import DINOProjectionHead -from lightly.models.modules.heads import MoCoProjectionHead -from lightly.models.modules.heads import NNCLRProjectionHead -from lightly.models.modules.heads import NNCLRPredictionHead -from lightly.models.modules.heads import SimCLRProjectionHead -from lightly.models.modules.heads import SimSiamProjectionHead -from lightly.models.modules.heads import SimSiamPredictionHead -from lightly.models.modules.heads import SMoGPrototypes -from lightly.models.modules.heads import SMoGProjectionHead -from lightly.models.modules.heads import SMoGPredictionHead -from lightly.models.modules.heads import SwaVProjectionHead -from lightly.models.modules.heads import SwaVPrototypes -from lightly.models.modules.nn_memory_bank import NNMemoryBankModule - from lightly import _torchvision_vit_available +from lightly.models.modules.heads import ( + BarlowTwinsProjectionHead, + BYOLPredictionHead, + BYOLProjectionHead, + DINOProjectionHead, + MoCoProjectionHead, + NNCLRPredictionHead, + NNCLRProjectionHead, + SimCLRProjectionHead, + SimSiamPredictionHead, + SimSiamProjectionHead, + SMoGPredictionHead, + SMoGProjectionHead, + SMoGPrototypes, + SwaVProjectionHead, + SwaVPrototypes, +) +from lightly.models.modules.nn_memory_bank import NNMemoryBankModule if _torchvision_vit_available: # Requires torchvision >=0.12 - from lightly.models.modules.masked_autoencoder import MAEBackbone - from lightly.models.modules.masked_autoencoder import MAEDecoder - from lightly.models.modules.masked_autoencoder import MAEEncoder + from lightly.models.modules.masked_autoencoder import ( + MAEBackbone, + 
MAEDecoder, + MAEEncoder, + ) diff --git a/lightly/models/modules/heads.py b/lightly/models/modules/heads.py index 4a9746d07..a6b346e4d 100644 --- a/lightly/models/modules/heads.py +++ b/lightly/models/modules/heads.py @@ -32,8 +32,7 @@ class ProjectionHead(nn.Module): """ def __init__( - self, - blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]] + self, blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]] ): super(ProjectionHead, self).__init__() @@ -69,15 +68,16 @@ class BarlowTwinsProjectionHead(ProjectionHead): """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 8192, - output_dim: int = 8192): - super(BarlowTwinsProjectionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, None, None), - ]) + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 8192, output_dim: int = 8192 + ): + super(BarlowTwinsProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) class BYOLProjectionHead(ProjectionHead): @@ -90,14 +90,16 @@ class BYOLProjectionHead(ProjectionHead): [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733 """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 4096, - output_dim: int = 256): - super(BYOLProjectionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, None, None), - ]) + + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256 + ): + super(BYOLProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) class BYOLPredictionHead(ProjectionHead): @@ -110,14 +112,16 @@ class BYOLPredictionHead(ProjectionHead): [0]: BYOL, 2020, https://arxiv.org/abs/2006.07733 """ - def __init__(self, - input_dim: int = 256, - hidden_dim: int = 4096, - output_dim: int = 256): - super(BYOLPredictionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, None, None), - ]) + + def __init__( + self, input_dim: int = 256, hidden_dim: int = 4096, output_dim: int = 256 + ): + super(BYOLPredictionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) class MoCoProjectionHead(ProjectionHead): @@ -130,14 +134,15 @@ class MoCoProjectionHead(ProjectionHead): """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 2048, - output_dim: int = 128): - super(MoCoProjectionHead, self).__init__([ - (input_dim, hidden_dim, None, nn.ReLU()), - (hidden_dim, output_dim, None, None), - ]) + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128 + ): + super(MoCoProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, None, nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) class NNCLRProjectionHead(ProjectionHead): @@ -152,15 +157,17 @@ class NNCLRProjectionHead(ProjectionHead): [0]: NNCLR, 2021, https://arxiv.org/abs/2104.14548 """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 2048, - output_dim: int = 256): - super(NNCLRProjectionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, 
hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, nn.BatchNorm1d(output_dim), None), - ]) + + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 256 + ): + super(NNCLRProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, output_dim, nn.BatchNorm1d(output_dim), None), + ] + ) class NNCLRPredictionHead(ProjectionHead): @@ -173,14 +180,16 @@ class NNCLRPredictionHead(ProjectionHead): [0]: NNCLR, 2021, https://arxiv.org/abs/2104.14548 """ - def __init__(self, - input_dim: int = 256, - hidden_dim: int = 4096, - output_dim: int = 256): - super(NNCLRPredictionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, None, None), - ]) + + def __init__( + self, input_dim: int = 256, hidden_dim: int = 4096, output_dim: int = 256 + ): + super(NNCLRPredictionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) class SimCLRProjectionHead(ProjectionHead): @@ -192,14 +201,16 @@ class SimCLRProjectionHead(ProjectionHead): [0] SimCLR, 2020, https://arxiv.org/abs/2002.05709 """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 2048, - output_dim: int = 128): - super(SimCLRProjectionHead, self).__init__([ - (input_dim, hidden_dim, None, nn.ReLU()), - (hidden_dim, output_dim, None, None), - ]) + + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128 + ): + super(SimCLRProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, None, nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) class SimSiamProjectionHead(ProjectionHead): @@ -212,30 +223,39 @@ class SimSiamProjectionHead(ProjectionHead): [0]: SimSiam, 2020, https://arxiv.org/abs/2011.10566 """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 2048, - output_dim: int = 2048): - super(SimSiamProjectionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, nn.BatchNorm1d(output_dim, affine=False), None), - ]) + + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 2048 + ): + super(SimSiamProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + ( + hidden_dim, + output_dim, + nn.BatchNorm1d(output_dim, affine=False), + None, + ), + ] + ) class SMoGPrototypes(nn.Module): - """SMoG prototypes module for synchronous momentum grouping. - - """ + """SMoG prototypes module for synchronous momentum grouping.""" def __init__( - self, group_features: torch.Tensor, beta: float, + self, + group_features: torch.Tensor, + beta: float, ): super(SMoGPrototypes, self).__init__() self.group_features = nn.Parameter(group_features, requires_grad=False) self.beta = beta - def forward(self, x: torch.Tensor, group_features: torch.Tensor, temperature: float = 0.1) -> torch.Tensor: + def forward( + self, x: torch.Tensor, group_features: torch.Tensor, temperature: float = 0.1 + ) -> torch.Tensor: """Computes the logits for given model outputs and group features. 
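Reviewer note: all of the heads touched in this file are thin wrappers that pass block tuples to `ProjectionHead`; a hypothetical custom head built the same way (dimensions invented for illustration):

```python
# Hypothetical custom head using the (in_dim, out_dim, norm, nonlinearity)
# block tuples that ProjectionHead accepts, mirroring the wrappers above.
import torch
import torch.nn as nn
from lightly.models.modules.heads import ProjectionHead

head = ProjectionHead(
    [
        (512, 1024, nn.BatchNorm1d(1024), nn.ReLU()),
        (1024, 128, None, None),  # no norm or activation on the output layer
    ]
)
z = head(torch.randn(4, 512))  # -> shape (4, 128)
```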
Args: @@ -255,7 +275,7 @@ def forward(self, x: torch.Tensor, group_features: torch.Tensor, temperature: fl logits = torch.mm(x, group_features.t()) return logits / temperature - def get_updated_group_features(self, x: torch.Tensor) -> None: + def get_updated_group_features(self, x: torch.Tensor) -> None: """Performs the synchronous momentum update of the group vectors. Args: @@ -268,14 +288,16 @@ def get_updated_group_features(self, x: torch.Tensor) -> None: """ assignments = self.assign_groups(x) group_features = torch.clone(self.group_features.data) - for assigned_class in torch.unique(assignments): + for assigned_class in torch.unique(assignments): mask = assignments == assigned_class - group_features[assigned_class] = self.beta * self.group_features[assigned_class] + (1 - self.beta) * x[mask].mean(axis=0) + group_features[assigned_class] = self.beta * self.group_features[ + assigned_class + ] + (1 - self.beta) * x[mask].mean(axis=0) return group_features def set_group_features(self, x: torch.Tensor) -> None: - """Sets the group features and asserts they don't require gradient. """ + """Sets the group features and asserts they don't require gradient.""" self.group_features.data = x.to(self.group_features.device) @torch.no_grad() @@ -287,7 +309,7 @@ def assign_groups(self, x: torch.Tensor) -> torch.LongTensor: Returns: LongTensor of shape bsz indicating group assignments. - + """ return torch.argmax(self.forward(x, self.group_features), dim=-1) @@ -300,16 +322,23 @@ class SMoGProjectionHead(ProjectionHead): layer of projection head also has BN" [0] [0]: SMoG, 2022, https://arxiv.org/pdf/2207.06167.pdf - + """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 2048, - output_dim: int = 128): - super(SMoGProjectionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, nn.BatchNorm1d(output_dim, affine=False), None) - ]) + + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128 + ): + super(SMoGProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + ( + hidden_dim, + output_dim, + nn.BatchNorm1d(output_dim, affine=False), + None, + ), + ] + ) class SMoGPredictionHead(ProjectionHead): @@ -320,17 +349,18 @@ class SMoGPredictionHead(ProjectionHead): layer of projection head also has BN" [0] [0]: SMoG, 2022, https://arxiv.org/pdf/2207.06167.pdf - + """ - def __init__(self, - input_dim: int = 128, - hidden_dim: int = 2048, - output_dim: int = 128): - super(SMoGPredictionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, None, None) - ]) + def __init__( + self, input_dim: int = 128, hidden_dim: int = 2048, output_dim: int = 128 + ): + super(SMoGPredictionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) class SimSiamPredictionHead(ProjectionHead): @@ -342,14 +372,16 @@ class SimSiamPredictionHead(ProjectionHead): [0]: SimSiam, 2020, https://arxiv.org/abs/2011.10566 """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 512, - output_dim: int = 2048): - super(SimSiamPredictionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, None, None), - ]) + + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 512, output_dim: int = 2048 + ): + super(SimSiamPredictionHead, self).__init__( + [ + (input_dim, 
hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) class SwaVProjectionHead(ProjectionHead): @@ -357,14 +389,16 @@ class SwaVProjectionHead(ProjectionHead): [0]: SwAV, 2020, https://arxiv.org/abs/2006.09882 """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 2048, - output_dim: int = 128): - super(SwaVProjectionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, None, None), - ]) + + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128 + ): + super(SwaVProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) class SwaVPrototypes(nn.Module): @@ -394,15 +428,22 @@ class SwaVPrototypes(nn.Module): >>> logits = prototypes(features) """ - def __init__(self, - input_dim: int = 128, - n_prototypes: Union[List[int], int] = 3000, - n_steps_frozen_prototypes: int = 0): + + def __init__( + self, + input_dim: int = 128, + n_prototypes: Union[List[int], int] = 3000, + n_steps_frozen_prototypes: int = 0, + ): super(SwaVPrototypes, self).__init__() # Default to a list of 1 if n_prototypes is an int. - self.n_prototypes = n_prototypes if isinstance(n_prototypes, list) else [n_prototypes] + self.n_prototypes = ( + n_prototypes if isinstance(n_prototypes, list) else [n_prototypes] + ) self._is_single_prototype = True if isinstance(n_prototypes, int) else False - self.heads = nn.ModuleList([nn.Linear(input_dim, prototypes) for prototypes in self.n_prototypes]) + self.heads = nn.ModuleList( + [nn.Linear(input_dim, prototypes) for prototypes in self.n_prototypes] + ) self.n_steps_frozen_prototypes = n_steps_frozen_prototypes def forward(self, x, step=None) -> Union[torch.Tensor, List[torch.Tensor]]: @@ -411,24 +452,26 @@ def forward(self, x, step=None) -> Union[torch.Tensor, List[torch.Tensor]]: for layer in self.heads: out.append(layer(x)) return out[0] if self._is_single_prototype else out - + def normalize(self): """Normalizes the prototypes so that they are on the unit sphere.""" for layer in self.heads: utils.normalize_weight(layer.weight) - + def _freeze_prototypes_if_required(self, step): if self.n_steps_frozen_prototypes > 0: if step is None: - raise ValueError("`n_steps_frozen_prototypes` is greater than 0, please" - " provide the `step` argument to the `forward()` method.") + raise ValueError( + "`n_steps_frozen_prototypes` is greater than 0, please" + " provide the `step` argument to the `forward()` method." + ) self.requires_grad_(step >= self.n_steps_frozen_prototypes) class DINOProjectionHead(ProjectionHead): """Projection head used in DINO. - "The projection head consists of a 3-layer multi-layer perceptron (MLP) + "The projection head consists of a 3-layer multi-layer perceptron (MLP) with hidden dimension 2048 followed by l2 normalization and a weight normalized fully connected layer with K dimensions, which is similar to the design from SwAV [1]." [0] @@ -449,18 +492,19 @@ class DINOProjectionHead(ProjectionHead): Whether to use batch norm or not. Should be set to False when using a vision transformer backbone. freeze_last_layer: - Number of epochs during which we keep the output layer fixed. - Typically doing so during the first epoch helps training. Try + Number of epochs during which we keep the output layer fixed. + Typically doing so during the first epoch helps training. Try increasing this value if the loss does not decrease. 
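Reviewer note: for the SwaVPrototypes freezing logic reformatted above, a short sketch (sizes assumed). While `step < n_steps_frozen_prototypes`, gradients to the prototype heads are switched off, and `forward()` must be given `step`.

```python
# Sketch of SwaVPrototypes above (sizes assumed). With frozen steps
# configured, forward() requires `step` and disables prototype gradients
# until step >= n_steps_frozen_prototypes.
import torch
from lightly.models.modules.heads import SwaVPrototypes

prototypes = SwaVPrototypes(
    input_dim=128, n_prototypes=512, n_steps_frozen_prototypes=100
)
x = torch.nn.functional.normalize(torch.randn(4, 128), dim=1)
logits = prototypes(x, step=0)  # (4, 512); prototypes frozen for step < 100
prototypes.normalize()          # keep prototype vectors on the unit sphere
```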
norm_last_layer: Whether or not to weight normalize the last layer of the DINO head. - Not normalizing leads to better performance but can make the + Not normalizing leads to better performance but can make the training unstable. - + """ + def __init__( - self, - input_dim: int = 2048, + self, + input_dim: int = 2048, hidden_dim: int = 2048, bottleneck_dim: int = 256, output_dim: int = 65536, @@ -470,57 +514,56 @@ def __init__( ): bn = nn.BatchNorm1d(hidden_dim) if batch_norm else None - super().__init__([ - (input_dim, hidden_dim, bn, nn.GELU()), - (hidden_dim, hidden_dim, bn, nn.GELU()), - (hidden_dim, bottleneck_dim, None, None), - ]) + super().__init__( + [ + (input_dim, hidden_dim, bn, nn.GELU()), + (hidden_dim, hidden_dim, bn, nn.GELU()), + (hidden_dim, bottleneck_dim, None, None), + ] + ) self.apply(self._init_weights) self.freeze_last_layer = freeze_last_layer - self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, output_dim, bias=False)) + self.last_layer = nn.utils.weight_norm( + nn.Linear(bottleneck_dim, output_dim, bias=False) + ) self.last_layer.weight_g.data.fill_(1) - # Option to normalize last layer. + # Option to normalize last layer. if norm_last_layer: self.last_layer.weight_g.requires_grad = False - + def cancel_last_layer_gradients(self, current_epoch: int): - """Cancel last layer gradients to stabilize the training. - - """ + """Cancel last layer gradients to stabilize the training.""" if current_epoch >= self.freeze_last_layer: return for param in self.last_layer.parameters(): param.grad = None def _init_weights(self, module): - """Initializes layers with a truncated normal distribution. - - """ + """Initializes layers with a truncated normal distribution.""" if isinstance(module, nn.Linear): utils._no_grad_trunc_normal( - module.weight, - mean=0, - std=0.2, - a=-2, + module.weight, + mean=0, + std=0.2, + a=-2, b=2, ) if module.bias is not None: nn.init.constant_(module.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Computes one forward pass through the head. - - """ + """Computes one forward pass through the head.""" x = self.layers(x) # l2 normalization x = nn.functional.normalize(x, dim=-1, p=2) x = self.last_layer(x) return x + class MSNProjectionHead(ProjectionHead): """Projection head for MSN [0]. - "We train with a 3-layer projection head with output dimension 256 and + "We train with a 3-layer projection head with output dimension 256 and batch-normalization at the input and hidden layers.." [0] Code inspired by [1]. @@ -528,60 +571,68 @@ class MSNProjectionHead(ProjectionHead): - [1]: https://github.com/facebookresearch/msn Attributes: - input_dim: + input_dim: Input dimension, default value 768 is for a ViT base model. - hidden_dim: + hidden_dim: Hidden dimension. - output_dim: + output_dim: Output dimension. """ + def __init__( self, input_dim: int = 768, hidden_dim: int = 2048, output_dim: int = 256, ): - super().__init__(blocks=[ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.GELU()), - (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.GELU()), - (hidden_dim, output_dim, None, None), - ]) + super().__init__( + blocks=[ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.GELU()), + (hidden_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.GELU()), + (hidden_dim, output_dim, None, None), + ] + ) + class TiCoProjectionHead(ProjectionHead): - """Projection head used for TiCo. + """Projection head used for TiCo. 
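Reviewer note: hypothetical training-loop usage of the DINO head utilities reformatted above (dimensions are assumptions, not values from this diff):

```python
# Hypothetical usage of DINOProjectionHead above. With freeze_last_layer=1,
# gradients of the weight-normalized output layer are cleared in epoch 0.
import torch
from lightly.models.modules.heads import DINOProjectionHead

head = DINOProjectionHead(input_dim=384, batch_norm=False, freeze_last_layer=1)
out = head(torch.randn(4, 384))  # l2-normalized bottleneck, then last layer
out.sum().backward()
head.cancel_last_layer_gradients(current_epoch=0)  # last-layer grads -> None
```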
+ + "This MLP consists in a linear layer with output size 4096 followed by + batch normalization, rectified linear units (ReLU), and a final + linear layer with output dimension 256." [0] - "This MLP consists in a linear layer with output size 4096 followed by - batch normalization, rectified linear units (ReLU), and a final - linear layer with output dimension 256." [0] + [0]: TiCo, 2022, https://arxiv.org/pdf/2206.10698.pdf + + """ - [0]: TiCo, 2022, https://arxiv.org/pdf/2206.10698.pdf + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 4096, output_dim: int = 256 + ): + super(TiCoProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) - """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 4096, - output_dim: int = 256): - super(TiCoProjectionHead, self).__init__([ - (input_dim, hidden_dim, nn.BatchNorm1d(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, None, None), - ]) class VicRegLLocalProjectionHead(ProjectionHead): """Projection head used for the local head of VICRegL. - The projector network has three linear layers. The first two layers of the projector + The projector network has three linear layers. The first two layers of the projector are followed by a batch normalization layer and rectified linear units. 2022, VICRegL, https://arxiv.org/abs/2210.01571 """ - def __init__(self, - input_dim: int = 2048, - hidden_dim: int = 8192, - output_dim: int = 8192): - super(VicRegLLocalProjectionHead, self).__init__([ - (input_dim, hidden_dim, nn.LayerNorm(hidden_dim), nn.ReLU()), - (hidden_dim, hidden_dim, nn.LayerNorm(hidden_dim), nn.ReLU()), - (hidden_dim, output_dim, None, None), - ]) + def __init__( + self, input_dim: int = 2048, hidden_dim: int = 8192, output_dim: int = 8192 + ): + super(VicRegLLocalProjectionHead, self).__init__( + [ + (input_dim, hidden_dim, nn.LayerNorm(hidden_dim), nn.ReLU()), + (hidden_dim, hidden_dim, nn.LayerNorm(hidden_dim), nn.ReLU()), + (hidden_dim, output_dim, None, None), + ] + ) diff --git a/lightly/models/modules/masked_autoencoder.py b/lightly/models/modules/masked_autoencoder.py index 7a34e6e24..f1695c4ce 100644 --- a/lightly/models/modules/masked_autoencoder.py +++ b/lightly/models/modules/masked_autoencoder.py @@ -1,16 +1,18 @@ from __future__ import annotations -from functools import partial + import math +from functools import partial from typing import Callable, List, Optional import torch import torch.nn as nn -from lightly.models import utils -# vision_transformer requires torchvision >= 0.12 +# vision_transformer requires torchvision >= 0.12 from torchvision.models import vision_transformer from torchvision.models.vision_transformer import ConvStemConfig +from lightly.models import utils + class MAEEncoder(vision_transformer.Encoder): """Encoder for the Masked Autoencoder model [0]. @@ -37,15 +39,16 @@ class MAEEncoder(vision_transformer.Encoder): Percentage of elements set to zero after the attention head. 
""" + def __init__( - self, - seq_length: int, - num_layers: int, - num_heads: int, + self, + seq_length: int, + num_layers: int, + num_heads: int, hidden_dim: int, - mlp_dim: int, - dropout: float, - attention_dropout: float, + mlp_dim: int, + dropout: float, + attention_dropout: float, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), ): super().__init__( @@ -63,7 +66,7 @@ def __init__( def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder) -> MAEEncoder: """Creates a MAEEncoder from a torchvision ViT encoder.""" # Create a new instance with dummy values as they will be overwritten - # by the copied vit_encoder attributes + # by the copied vit_encoder attributes encoder = cls( seq_length=1, num_layers=1, @@ -80,9 +83,7 @@ def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder) -> MAEEncoder return encoder def forward( - self, - input: torch.Tensor, - idx_keep: Optional[torch.Tensor] = None + self, input: torch.Tensor, idx_keep: Optional[torch.Tensor] = None ) -> torch.Tensor: """Encode input tokens. @@ -109,7 +110,7 @@ def interpolate_pos_encoding(self, input: torch.Tensor): ignoring the class token. This allows encoding variable sized images. Args: - input: + input: Input tensor with shape (batch_size, num_sequences). """ @@ -123,9 +124,11 @@ def interpolate_pos_encoding(self, input: torch.Tensor): pos_embedding = self.pos_embedding[:, 1:] dim = input.shape[-1] pos_embedding = nn.functional.interpolate( - pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute( + 0, 3, 1, 2 + ), scale_factor=math.sqrt(npatch / N), - mode='bicubic', + mode="bicubic", ) pos_embedding = pos_embedding.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_emb.unsqueeze(0), pos_embedding), dim=1) @@ -134,7 +137,7 @@ def interpolate_pos_encoding(self, input: torch.Tensor): class MAEBackbone(vision_transformer.VisionTransformer): """Backbone for the Masked Autoencoder model [0]. - Converts images into patches and encodes them. Code inspired by [1]. + Converts images into patches and encodes them. Code inspired by [1]. Note that this implementation uses a learned positional embedding while [0] uses a fixed positional embedding. @@ -163,7 +166,7 @@ class MAEBackbone(vision_transformer.VisionTransformer): num_classes: Number of classes for the classification head. Currently not used. representation_size: - If specified, an additional linear layer is added before the + If specified, an additional linear layer is added before the classification head to change the token dimension from hidden_dim to representation_size. Currently not used. norm_layer: @@ -174,20 +177,21 @@ class MAEBackbone(vision_transformer.VisionTransformer): paper [0]. 
""" + def __init__( - self, - image_size: int, - patch_size: int, - num_layers: int, - num_heads: int, - hidden_dim: int, - mlp_dim: int, - dropout: float = 0, - attention_dropout: float = 0, - num_classes: int = 1000, - representation_size: Optional[int] = None, + self, + image_size: int, + patch_size: int, + num_layers: int, + num_heads: int, + hidden_dim: int, + mlp_dim: int, + dropout: float = 0, + attention_dropout: float = 0, + num_classes: int = 1000, + representation_size: Optional[int] = None, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), - conv_stem_configs: Optional[List[ConvStemConfig]] = None + conv_stem_configs: Optional[List[ConvStemConfig]] = None, ): super().__init__( image_size=image_size, @@ -218,7 +222,7 @@ def __init__( def from_vit(cls, vit: vision_transformer.VisionTransformer) -> MAEBackbone: """Creates a MAEBackbone from a torchvision ViT model.""" # Create a new instance with dummy values as they will be overwritten - # by the copied vit_encoder attributes + # by the copied vit_encoder attributes backbone = cls( image_size=vit.image_size, patch_size=vit.patch_size, @@ -240,9 +244,7 @@ def from_vit(cls, vit: vision_transformer.VisionTransformer) -> MAEBackbone: return backbone def forward( - self, - images: torch.Tensor, - idx_keep: Optional[torch.Tensor] = None + self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None ) -> torch.Tensor: """Returns encoded class tokens from a batch of images. @@ -254,9 +256,9 @@ def forward( entry is an index of the token to keep in the respective batch. If specified, only the indexed tokens will be passed to the encoder. - + Returns: - Tensor with shape (batch_size, hidden_dim) containing the + Tensor with shape (batch_size, hidden_dim) containing the encoded class token for every image. """ @@ -265,9 +267,7 @@ def forward( return class_token def encode( - self, - images: torch.Tensor, - idx_keep: Optional[torch.Tensor] = None + self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None ) -> torch.Tensor: """Returns encoded class and patch tokens from images. @@ -279,29 +279,30 @@ def encode( entry is an index of the token to keep in the respective batch. If specified, only the indexed tokens will be passed to the encoder. - + Returns: - Tensor with shape (batch_size, sequence_length, hidden_dim) + Tensor with shape (batch_size, sequence_length, hidden_dim) containing the encoded class and patch tokens for every image. """ out = self.images_to_tokens(images, prepend_class_token=True) return self.encoder(out, idx_keep) - - def images_to_tokens(self, images: torch.Tensor, prepend_class_token: bool) -> torch.Tensor: + def images_to_tokens( + self, images: torch.Tensor, prepend_class_token: bool + ) -> torch.Tensor: """Converts images into patch tokens. - + Args: images: Tensor with shape (batch_size, channels, image_size, image_size). - + Returns: Tensor with shape (batch_size, sequence_length - 1, hidden_dim) containing the patch tokens. """ x = self.conv_proj(images) - tokens = x.flatten(2).transpose(1, 2) + tokens = x.flatten(2).transpose(1, 2) if prepend_class_token: tokens = utils.prepend_class_token(tokens, self.class_token) return tokens @@ -339,6 +340,7 @@ class MAEDecoder(vision_transformer.Encoder): Percentage of elements set to zero after the attention head. 
""" + def __init__( self, seq_length: int, diff --git a/lightly/models/modules/nn_memory_bank.py b/lightly/models/modules/nn_memory_bank.py index 7997d8eca..1cedc12d4 100644 --- a/lightly/models/modules/nn_memory_bank.py +++ b/lightly/models/modules/nn_memory_bank.py @@ -4,6 +4,7 @@ # All Rights Reserved import torch + from lightly.loss.memory_bank import MemoryBankModule @@ -35,12 +36,11 @@ class NNMemoryBankModule(MemoryBankModule): >>> loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0)) """ - def __init__(self, size: int = 2 ** 16): + + def __init__(self, size: int = 2**16): super(NNMemoryBankModule, self).__init__(size) - def forward(self, - output: torch.Tensor, - update: bool = False): + def forward(self, output: torch.Tensor, update: bool = False): """Returns nearest neighbour of output tensor from memory bank Args: @@ -49,17 +49,16 @@ def forward(self, """ - output, bank = \ - super(NNMemoryBankModule, self).forward(output, update=update) + output, bank = super(NNMemoryBankModule, self).forward(output, update=update) bank = bank.to(output.device).t() output_normed = torch.nn.functional.normalize(output, dim=1) bank_normed = torch.nn.functional.normalize(bank, dim=1) - similarity_matrix = \ - torch.einsum("nd,md->nm", output_normed, bank_normed) + similarity_matrix = torch.einsum("nd,md->nm", output_normed, bank_normed) index_nearest_neighbours = torch.argmax(similarity_matrix, dim=1) - nearest_neighbours = \ - torch.index_select(bank, dim=0, index=index_nearest_neighbours) + nearest_neighbours = torch.index_select( + bank, dim=0, index=index_nearest_neighbours + ) return nearest_neighbours diff --git a/lightly/models/nnclr.py b/lightly/models/nnclr.py index 5fb4a42b9..eb1111191 100644 --- a/lightly/models/nnclr.py +++ b/lightly/models/nnclr.py @@ -8,8 +8,7 @@ import torch import torch.nn as nn -from lightly.models.modules import NNCLRProjectionHead -from lightly.models.modules import NNCLRPredictionHead +from lightly.models.modules import NNCLRPredictionHead, NNCLRProjectionHead def _prediction_mlp(in_dims: int, h_dims: int, out_dims: int) -> nn.Sequential: @@ -130,7 +129,6 @@ def __init__( pred_hidden_dim: int = 4096, out_dim: int = 256, ): - super(NNCLR, self).__init__() self.backbone = backbone diff --git a/lightly/models/resnet.py b/lightly/models/resnet.py index 179823b37..44c6c5829 100644 --- a/lightly/models/resnet.py +++ b/lightly/models/resnet.py @@ -19,48 +19,47 @@ from lightly.models.batchnorm import get_norm_layer + class BasicBlock(nn.Module): - """ Implementation of the ResNet Basic Block. + """Implementation of the ResNet Basic Block. - Attributes: - in_planes: - Number of input channels. - planes: - Number of channels. - stride: - Stride of the first convolutional. + Attributes: + in_planes: + Number of input channels. + planes: + Number of channels. + stride: + Stride of the first convolutional. 
""" - expansion = 1 - def __init__(self, in_planes: int, planes: int, stride: int = 1, num_splits: int = 0): + expansion = 1 + def __init__( + self, in_planes: int, planes: int, stride: int = 1, num_splits: int = 0 + ): super(BasicBlock, self).__init__() - self.conv1 = nn.Conv2d(in_planes, - planes, - kernel_size=3, - stride=stride, - padding=1, - bias=False) + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) self.bn1 = get_norm_layer(planes, num_splits) - self.conv2 = nn.Conv2d(planes, - planes, - kernel_size=3, - stride=1, - padding=1, - bias=False) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn2 = get_norm_layer(planes, num_splits) self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion*planes: + if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, - self.expansion*planes, - kernel_size=1, - stride=stride, - bias=False), - get_norm_layer(self.expansion * planes, num_splits) + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False, + ), + get_norm_layer(self.expansion * planes, num_splits), ) def forward(self, x: torch.Tensor): @@ -88,7 +87,7 @@ def forward(self, x: torch.Tensor): class Bottleneck(nn.Module): - """ Implementation of the ResNet Bottleneck Block. + """Implementation of the ResNet Bottleneck Block. Attributes: in_planes: @@ -99,38 +98,38 @@ class Bottleneck(nn.Module): Stride of the first convolutional. """ - expansion = 4 - def __init__(self, in_planes: int, planes: int, stride: int = 1, num_splits: int = 0): + expansion = 4 + def __init__( + self, in_planes: int, planes: int, stride: int = 1, num_splits: int = 0 + ): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) self.bn1 = get_norm_layer(planes, num_splits) - self.conv2 = nn.Conv2d(planes, - planes, - kernel_size=3, - stride=stride, - padding=1, - bias=False) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) self.bn2 = get_norm_layer(planes, num_splits) - self.conv3 = nn.Conv2d(planes, - self.expansion*planes, - kernel_size=1, - bias=False) + self.conv3 = nn.Conv2d( + planes, self.expansion * planes, kernel_size=1, bias=False + ) self.bn3 = get_norm_layer(self.expansion * planes, num_splits) self.shortcut = nn.Sequential() - if stride != 1 or in_planes != self.expansion*planes: + if stride != 1 or in_planes != self.expansion * planes: self.shortcut = nn.Sequential( - nn.Conv2d(in_planes, - self.expansion*planes, - kernel_size=1, - stride=stride, - bias=False), - get_norm_layer(self.expansion * planes, num_splits) + nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False, + ), + get_norm_layer(self.expansion * planes, num_splits), ) def forward(self, x): @@ -178,33 +177,39 @@ class ResNet(nn.Module): Multiplier for ResNet width. 
""" - def __init__(self, - block: nn.Module = BasicBlock, - layers: List[int] = [2, 2, 2, 2], - num_classes: int = 10, - width: float = 1., - num_splits: int = 0): - + def __init__( + self, + block: nn.Module = BasicBlock, + layers: List[int] = [2, 2, 2, 2], + num_classes: int = 10, + width: float = 1.0, + num_splits: int = 0, + ): super(ResNet, self).__init__() self.in_planes = int(64 * width) self.base = int(64 * width) - self.conv1 = nn.Conv2d(3, - self.base, - kernel_size=3, - stride=1, - padding=1, - bias=False) + self.conv1 = nn.Conv2d( + 3, self.base, kernel_size=3, stride=1, padding=1, bias=False + ) self.bn1 = get_norm_layer(self.base, num_splits) - self.layer1 = self._make_layer(block, self.base, layers[0], stride=1, num_splits=num_splits) - self.layer2 = self._make_layer(block, self.base*2, layers[1], stride=2, num_splits=num_splits) - self.layer3 = self._make_layer(block, self.base*4, layers[2], stride=2, num_splits=num_splits) - self.layer4 = self._make_layer(block, self.base*8, layers[3], stride=2, num_splits=num_splits) - self.linear = nn.Linear(self.base*8*block.expansion, num_classes) + self.layer1 = self._make_layer( + block, self.base, layers[0], stride=1, num_splits=num_splits + ) + self.layer2 = self._make_layer( + block, self.base * 2, layers[1], stride=2, num_splits=num_splits + ) + self.layer3 = self._make_layer( + block, self.base * 4, layers[2], stride=2, num_splits=num_splits + ) + self.layer4 = self._make_layer( + block, self.base * 8, layers[3], stride=2, num_splits=num_splits + ) + self.linear = nn.Linear(self.base * 8 * block.expansion, num_classes) def _make_layer(self, block, planes, layers, stride, num_splits): - strides = [stride] + [1]*(layers-1) + strides = [stride] + [1] * (layers - 1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride, num_splits)) @@ -217,7 +222,7 @@ def forward(self, x: torch.Tensor): Args: x: Tensor of shape bsz x channels x W x H - + Returns: Output tensor of shape bsz x num_classes @@ -233,10 +238,12 @@ def forward(self, x: torch.Tensor): return out -def ResNetGenerator(name: str = 'resnet-18', - width: float = 1, - num_classes: int = 10, - num_splits: int = 0): +def ResNetGenerator( + name: str = "resnet-18", + width: float = 1, + num_classes: int = 10, + num_splits: int = 0, +): """Builds and returns the specified ResNet. Args: @@ -263,16 +270,24 @@ def ResNetGenerator(name: str = 'resnet-18', """ model_params = { - 'resnet-9': {'block': BasicBlock, 'layers': [1, 1, 1, 1]}, - 'resnet-18': {'block': BasicBlock, 'layers': [2, 2, 2, 2]}, - 'resnet-34': {'block': BasicBlock, 'layers': [3, 4, 6, 3]}, - 'resnet-50': {'block': Bottleneck, 'layers': [3, 4, 6, 3]}, - 'resnet-101': {'block': Bottleneck, 'layers': [3, 4, 23, 3]}, - 'resnet-152': {'block': Bottleneck, 'layers': [3, 8, 36, 3]}, + "resnet-9": {"block": BasicBlock, "layers": [1, 1, 1, 1]}, + "resnet-18": {"block": BasicBlock, "layers": [2, 2, 2, 2]}, + "resnet-34": {"block": BasicBlock, "layers": [3, 4, 6, 3]}, + "resnet-50": {"block": Bottleneck, "layers": [3, 4, 6, 3]}, + "resnet-101": {"block": Bottleneck, "layers": [3, 4, 23, 3]}, + "resnet-152": {"block": Bottleneck, "layers": [3, 8, 36, 3]}, } if name not in model_params.keys(): - raise ValueError('Illegal name: {%s}. \ - Try resnet-9, resnet-18, resnet-34, resnet-50, resnet-101, resnet-152.' % (name)) - - return ResNet(**model_params[name], width=width, num_classes=num_classes, num_splits=num_splits) + raise ValueError( + "Illegal name: {%s}. 
\ + Try resnet-9, resnet-18, resnet-34, resnet-50, resnet-101, resnet-152." + % (name) + ) + + return ResNet( + **model_params[name], + width=width, + num_classes=num_classes, + num_splits=num_splits + ) diff --git a/lightly/models/simclr.py b/lightly/models/simclr.py index 14f8ec40b..5ed379a97 100644 --- a/lightly/models/simclr.py +++ b/lightly/models/simclr.py @@ -29,7 +29,6 @@ class SimCLR(nn.Module): """ def __init__(self, backbone: nn.Module, num_ftrs: int = 32, out_dim: int = 128): - super(SimCLR, self).__init__() self.backbone = backbone diff --git a/lightly/models/simsiam.py b/lightly/models/simsiam.py index 83516dee7..f439a4363 100644 --- a/lightly/models/simsiam.py +++ b/lightly/models/simsiam.py @@ -8,8 +8,7 @@ import torch import torch.nn as nn -from lightly.models.modules import SimSiamProjectionHead -from lightly.models.modules import SimSiamPredictionHead +from lightly.models.modules import SimSiamPredictionHead, SimSiamProjectionHead class SimSiam(nn.Module): @@ -43,7 +42,6 @@ def __init__( pred_hidden_dim: int = 512, out_dim: int = 2048, ): - super(SimSiam, self).__init__() self.backbone = backbone diff --git a/lightly/models/utils.py b/lightly/models/utils.py index 9a562301f..046f20832 100644 --- a/lightly/models/utils.py +++ b/lightly/models/utils.py @@ -4,13 +4,13 @@ # All Rights Reserved import math -from typing import Optional, Tuple, Union import warnings +from typing import Optional, Tuple, Union +import numpy as np import torch import torch.distributed as dist import torch.nn as nn -import numpy as np @torch.no_grad() diff --git a/lightly/models/zoo.py b/lightly/models/zoo.py index 43eb88a5d..34d7dbede 100644 --- a/lightly/models/zoo.py +++ b/lightly/models/zoo.py @@ -4,27 +4,15 @@ # All Rights Reserved ZOO = { - - 'resnet-9/simclr/d16/w0.0625': - 'https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.0625-i-ce0d6bd9.pth', - - 'resnet-9/simclr/d16/w0.125': - 'https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.125-i-7269c38d.pth', - - 'resnet-18/simclr/d16/w1.0': - 'https://storage.googleapis.com/models_boris/whattolabel-resnet18-simclr-d16-w1.0-i-58852cb9.pth', - - 'resnet-18/simclr/d32/w1.0': - 'https://storage.googleapis.com/models_boris/whattolabel-resnet18-simclr-d32-w1.0-i-085d0693.pth', - - 'resnet-34/simclr/d16/w1.0': - 'https://storage.googleapis.com/models_boris/whattolabel-resnet34-simclr-d16-w1.0-i-6e80d963.pth', - - 'resnet-34/simclr/d32/w1.0': - 'https://storage.googleapis.com/models_boris/whattolabel-resnet34-simclr-d32-w1.0-i-9f185b45.pth' - + "resnet-9/simclr/d16/w0.0625": "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.0625-i-ce0d6bd9.pth", + "resnet-9/simclr/d16/w0.125": "https://storage.googleapis.com/models_boris/whattolabel-resnet9-simclr-d16-w0.125-i-7269c38d.pth", + "resnet-18/simclr/d16/w1.0": "https://storage.googleapis.com/models_boris/whattolabel-resnet18-simclr-d16-w1.0-i-58852cb9.pth", + "resnet-18/simclr/d32/w1.0": "https://storage.googleapis.com/models_boris/whattolabel-resnet18-simclr-d32-w1.0-i-085d0693.pth", + "resnet-34/simclr/d16/w1.0": "https://storage.googleapis.com/models_boris/whattolabel-resnet34-simclr-d16-w1.0-i-6e80d963.pth", + "resnet-34/simclr/d32/w1.0": "https://storage.googleapis.com/models_boris/whattolabel-resnet34-simclr-d32-w1.0-i-9f185b45.pth", } + def checkpoints(): """Returns the Lightly model zoo as a list of checkpoints. 
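Reviewer note: a hypothetical call into the ResNetGenerator above; any name outside the `model_params` table raises the ValueError shown.

```python
# Hypothetical usage of ResNetGenerator above; "resnet-50" selects the
# Bottleneck variant from the model_params table.
from lightly.models.resnet import ResNetGenerator

resnet = ResNetGenerator("resnet-50", width=1.0, num_classes=10)
```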
@@ -47,4 +35,3 @@ def checkpoints(): """ return [item for key, item in ZOO.items()] - diff --git a/lightly/transforms/__init__.py b/lightly/transforms/__init__.py index 0230182a9..1833631d9 100644 --- a/lightly/transforms/__init__.py +++ b/lightly/transforms/__init__.py @@ -8,19 +8,21 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved +from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform from lightly.transforms.gaussian_blur import GaussianBlur -from lightly.transforms.rotation import RandomRotate -from lightly.transforms.rotation import RandomRotateDegrees -from lightly.transforms.rotation import random_rotation_transform -from lightly.transforms.solarize import RandomSolarization from lightly.transforms.jigsaw import Jigsaw -from lightly.transforms.dino_transform import DINOTransform, DINOViewTransform from lightly.transforms.mae_transform import MAETransform from lightly.transforms.moco_transform import MoCoV1Transform, MoCoV2Transform from lightly.transforms.msn_transform import MSNTransform, MSNViewTransform from lightly.transforms.pirl_transform import PIRLTransform +from lightly.transforms.rotation import ( + RandomRotate, + RandomRotateDegrees, + random_rotation_transform, +) from lightly.transforms.simclr_transform import SimCLRTransform, SimCLRViewTransform from lightly.transforms.smog_transform import SMoGTransform, SmoGViewTransform +from lightly.transforms.solarize import RandomSolarization from lightly.transforms.swav_transform import SwaVTransform, SwaVViewTransform from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform diff --git a/lightly/transforms/dino_transform.py b/lightly/transforms/dino_transform.py index ed432f668..d5edd590b 100644 --- a/lightly/transforms/dino_transform.py +++ b/lightly/transforms/dino_transform.py @@ -1,13 +1,15 @@ +from typing import Optional, Tuple, Union + +import PIL +import torchvision.transforms as T +from PIL.Image import Image from torch import Tensor + +from lightly.transforms.gaussian_blur import GaussianBlur from lightly.transforms.multi_view_transform import MultiViewTransform -from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.transforms.rotation import random_rotation_transform -from lightly.transforms.gaussian_blur import GaussianBlur from lightly.transforms.solarize import RandomSolarization -from typing import Optional, Tuple, Union -from PIL.Image import Image -import PIL -import torchvision.transforms as T +from lightly.transforms.utils import IMAGENET_NORMALIZE class DINOTransform(MultiViewTransform): @@ -98,7 +100,6 @@ def __init__( solarization_prob: float = 0.2, normalize: Union[None, dict] = IMAGENET_NORMALIZE, ): - # first global crop global_transform_0 = DINOViewTransform( crop_size=global_crop_size, @@ -192,7 +193,6 @@ def __init__( solarization_prob: float = 0.2, normalize: Union[None, dict] = IMAGENET_NORMALIZE, ): - transform = [ T.RandomResizedCrop( size=crop_size, @@ -233,7 +233,7 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: Applies the transforms to the input image. Args: - image: + image: The input image to apply the transforms to. Returns: diff --git a/lightly/transforms/gaussian_blur.py b/lightly/transforms/gaussian_blur.py index 89b46ac65..571cc4af7 100644 --- a/lightly/transforms/gaussian_blur.py +++ b/lightly/transforms/gaussian_blur.py @@ -3,11 +3,12 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved -import numpy as np -from PIL import ImageFilter from typing import Tuple, Union from warnings import warn +import numpy as np +from PIL import ImageFilter + class GaussianBlur: """Implementation of random Gaussian blur. diff --git a/lightly/transforms/image_grid_transform.py b/lightly/transforms/image_grid_transform.py index f2cd79932..30854c910 100644 --- a/lightly/transforms/image_grid_transform.py +++ b/lightly/transforms/image_grid_transform.py @@ -1,7 +1,8 @@ -from torch import Tensor -from PIL.Image import Image from typing import List, Tuple, Union + import torchvision.transforms as T +from PIL.Image import Image +from torch import Tensor class ImageGridTransform: @@ -11,7 +12,7 @@ class ImageGridTransform: Attributes: transforms: - A sequence of (image_grid_transform, view_transform) tuples. + A sequence of (image_grid_transform, view_transform) tuples. The image_grid_transform creates a new view and grid from the image. The view_transform further augments the view. Every transform tuple is applied once to the image, creating len(transforms) views and diff --git a/lightly/transforms/jigsaw.py b/lightly/transforms/jigsaw.py index 63294fb49..16d3898ce 100644 --- a/lightly/transforms/jigsaw.py +++ b/lightly/transforms/jigsaw.py @@ -3,10 +3,10 @@ # Copyright (c) 2021. Lightly AG and its affiliates. # All Rights Reserved -import torch -from torchvision import transforms import numpy as np +import torch from PIL import Image +from torchvision import transforms class Jigsaw(object): diff --git a/lightly/transforms/mae_transform.py b/lightly/transforms/mae_transform.py index ba1e40b82..0510ca8d5 100644 --- a/lightly/transforms/mae_transform.py +++ b/lightly/transforms/mae_transform.py @@ -1,12 +1,14 @@ +from typing import List, Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image from torch import Tensor + from lightly.transforms.multi_view_transform import MultiViewTransform from lightly.transforms.utils import IMAGENET_NORMALIZE -from typing import List, Tuple, Union -from PIL.Image import Image -import torchvision.transforms as T -class MAETransform(): +class MAETransform: """Implements the view augmentation for MAE [0]. - [0]: Masked Autoencoder, 2021, https://arxiv.org/abs/2111.06377 @@ -44,7 +46,7 @@ def __call__(self, image: Union[Tensor, Image]) -> List[Tensor]: Applies the transforms to the input image. Args: - image: + image: The input image to apply the transforms to. 
Returns: diff --git a/lightly/transforms/moco_transform.py b/lightly/transforms/moco_transform.py index bfe819bcc..c21befe11 100644 --- a/lightly/transforms/moco_transform.py +++ b/lightly/transforms/moco_transform.py @@ -1,6 +1,7 @@ +from typing import Optional, Tuple, Union + from lightly.transforms.simclr_transform import SimCLRTransform from lightly.transforms.utils import IMAGENET_NORMALIZE -from typing import Optional, Tuple, Union class MoCoV1Transform(SimCLRTransform): @@ -58,7 +59,6 @@ def __init__( rr_degrees: Union[None, float, Tuple[float, float]] = None, normalize: dict = IMAGENET_NORMALIZE, ): - super().__init__( input_size=input_size, cj_prob=cj_prob, @@ -77,5 +77,6 @@ def __init__( rr_degrees=rr_degrees, normalize=normalize, ) - -MoCoV2Transform = SimCLRTransform # MoCo v2 uses the same transform as SimCLR + + +MoCoV2Transform = SimCLRTransform # MoCo v2 uses the same transform as SimCLR diff --git a/lightly/transforms/msn_transform.py b/lightly/transforms/msn_transform.py index 20f2a377f..521b321c8 100644 --- a/lightly/transforms/msn_transform.py +++ b/lightly/transforms/msn_transform.py @@ -1,10 +1,12 @@ +from typing import Optional, Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image from torch import Tensor + +from lightly.transforms.gaussian_blur import GaussianBlur from lightly.transforms.multi_view_transform import MultiViewTransform from lightly.transforms.utils import IMAGENET_NORMALIZE -from lightly.transforms.gaussian_blur import GaussianBlur -from typing import Optional, Tuple, Union -from PIL.Image import Image -import torchvision.transforms as T class MSNTransform(MultiViewTransform): @@ -139,7 +141,7 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: Applies the transforms to the input image. Args: - image: + image: The input image to apply the transforms to. Returns: diff --git a/lightly/transforms/multi_crop_transform.py b/lightly/transforms/multi_crop_transform.py index 311fbdcc7..76390b589 100644 --- a/lightly/transforms/multi_crop_transform.py +++ b/lightly/transforms/multi_crop_transform.py @@ -1,8 +1,9 @@ -from lightly.transforms.multi_view_transform import MultiViewTransform from typing import Tuple import torchvision.transforms as T +from lightly.transforms.multi_view_transform import MultiViewTransform + class MultiCropTranform(MultiViewTransform): """Implements the multi-crop transformations. Used by Swav. @@ -29,7 +30,6 @@ def __init__( crop_max_scales: Tuple[float], transforms, ): - if len(crop_sizes) != len(crop_counts): raise ValueError( "Length of crop_sizes and crop_counts must be equal but are" @@ -48,7 +48,6 @@ def __init__( crop_transforms = [] for i in range(len(crop_sizes)): - random_resized_crop = T.RandomResizedCrop( crop_sizes[i], scale=(crop_min_scales[i], crop_max_scales[i]) ) diff --git a/lightly/transforms/multi_view_transform.py b/lightly/transforms/multi_view_transform.py index 9632e723d..a00f4fc00 100644 --- a/lightly/transforms/multi_view_transform.py +++ b/lightly/transforms/multi_view_transform.py @@ -1,7 +1,8 @@ -from torch import Tensor -from PIL.Image import Image from typing import List, Union +from PIL.Image import Image +from torch import Tensor + class MultiViewTransform: """Transforms an image into multiple views. 
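Note on the transform files reordered above: every *Transform class in lightly.transforms ultimately delegates to MultiViewTransform, whose whole job is to apply each configured view transform to one input image and return the resulting list of views. A minimal usage sketch, assuming the constructor takes the list of view transforms; the toy pipelines below are illustrative and not taken from this diff:

import torchvision.transforms as T
from PIL import Image

from lightly.transforms.multi_view_transform import MultiViewTransform

# Two toy view pipelines; the real recipes (SimCLR, DINO, ...) wrap much
# richer augmentation stacks, as the hunks above show.
view_a = T.Compose([T.RandomResizedCrop(32), T.ToTensor()])
view_b = T.Compose([T.RandomHorizontalFlip(p=1.0), T.Resize(32), T.ToTensor()])

transform = MultiViewTransform(transforms=[view_a, view_b])
views = transform(Image.new("RGB", (64, 64)))  # placeholder input image
assert len(views) == 2  # one view per configured transform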
diff --git a/lightly/transforms/pirl_transform.py b/lightly/transforms/pirl_transform.py index babd3fb35..81c7ea8a3 100644 --- a/lightly/transforms/pirl_transform.py +++ b/lightly/transforms/pirl_transform.py @@ -1,11 +1,13 @@ +from typing import Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image from torch import Tensor + +from lightly.transforms.jigsaw import Jigsaw from lightly.transforms.multi_view_transform import MultiViewTransform -from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.transforms.rotation import random_rotation_transform -from lightly.transforms.jigsaw import Jigsaw -from typing import Tuple, Union -from PIL.Image import Image -import torchvision.transforms as T +from lightly.transforms.utils import IMAGENET_NORMALIZE class PIRLTransform(MultiViewTransform): @@ -54,7 +56,6 @@ def __init__( n_grid: int = 3, normalize: Union[None, dict] = IMAGENET_NORMALIZE, ): - if isinstance(input_size, tuple): input_size_ = max(input_size) else: @@ -88,5 +89,5 @@ def __init__( crop_size=int(input_size_ // n_grid), transform=T.Compose(transform), ) - + super().__init__([no_augment, jigsaw]) diff --git a/lightly/transforms/random_crop_and_flip_with_grid.py b/lightly/transforms/random_crop_and_flip_with_grid.py index a5f09e993..7ef388404 100644 --- a/lightly/transforms/random_crop_and_flip_with_grid.py +++ b/lightly/transforms/random_crop_and_flip_with_grid.py @@ -1,10 +1,11 @@ -import torch -from torch import nn -from typing import List, Tuple, Dict from dataclasses import dataclass +from typing import Dict, List, Tuple + +import torch import torchvision.transforms as T import torchvision.transforms.functional as F from PIL import Image +from torch import nn @dataclass diff --git a/lightly/transforms/rotation.py b/lightly/transforms/rotation.py index ae1226c08..ea48231e3 100644 --- a/lightly/transforms/rotation.py +++ b/lightly/transforms/rotation.py @@ -3,12 +3,13 @@ # Copyright (c) 2020. Lightly AG and its affiliates. 
# All Rights Reserved +from typing import Tuple, Union + import numpy as np -from torchvision.transforms import functional as TF import torchvision.transforms as T -from typing import Tuple, Union -from torch import Tensor from PIL.Image import Image +from torch import Tensor +from torchvision.transforms import functional as TF class RandomRotate: diff --git a/lightly/transforms/simclr_transform.py b/lightly/transforms/simclr_transform.py index 1726559e3..02fd587ea 100644 --- a/lightly/transforms/simclr_transform.py +++ b/lightly/transforms/simclr_transform.py @@ -1,11 +1,13 @@ +from typing import Optional, Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image from torch import Tensor + +from lightly.transforms.gaussian_blur import GaussianBlur from lightly.transforms.multi_view_transform import MultiViewTransform -from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.transforms.rotation import random_rotation_transform -from lightly.transforms.gaussian_blur import GaussianBlur -from typing import Optional, Tuple, Union -from PIL.Image import Image -import torchvision.transforms as T +from lightly.transforms.utils import IMAGENET_NORMALIZE class SimCLRTransform(MultiViewTransform): @@ -66,7 +68,6 @@ def __init__( rr_degrees: Union[None, float, Tuple[float, float]] = None, normalize: Union[None, dict] = IMAGENET_NORMALIZE, ): - view_transform = SimCLRViewTransform( input_size=input_size, cj_prob=cj_prob, @@ -129,7 +130,7 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: Applies the transforms to the input image. Args: - image: + image: The input image to apply the transforms to. Returns: diff --git a/lightly/transforms/smog_transform.py b/lightly/transforms/smog_transform.py index b54abe50d..6a5735b69 100644 --- a/lightly/transforms/smog_transform.py +++ b/lightly/transforms/smog_transform.py @@ -1,11 +1,13 @@ +from typing import Optional, Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image from torch import Tensor + +from lightly.transforms.gaussian_blur import GaussianBlur from lightly.transforms.multi_view_transform import MultiViewTransform -from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.transforms.solarize import RandomSolarization -from lightly.transforms.gaussian_blur import GaussianBlur -from typing import Optional, Tuple, Union -from PIL.Image import Image -import torchvision.transforms as T +from lightly.transforms.utils import IMAGENET_NORMALIZE class SMoGTransform(MultiViewTransform): @@ -48,7 +50,10 @@ def __init__( crop_min_scales: Tuple[float, float] = (0.2, 0.05), crop_max_scales: Tuple[float, float] = (1.0, 0.2), gaussian_blur_probs: Tuple[float, float] = (0.5, 0.1), - gaussian_blur_kernel_sizes: Tuple[Optional[float], Optional[float]] = (None, None), + gaussian_blur_kernel_sizes: Tuple[Optional[float], Optional[float]] = ( + None, + None, + ), gaussian_blur_sigmas: Tuple[float, float] = (0.2, 2), solarize_probs: Tuple[float, float] = (0.0, 0.2), hf_prob: float = 0.5, @@ -57,10 +62,8 @@ def __init__( random_gray_scale: float = 0.2, normalize: Union[None, dict] = IMAGENET_NORMALIZE, ): - transforms = [] for i in range(len(crop_sizes)): - transforms.extend( [ SmoGViewTransform( @@ -131,7 +134,7 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: Applies the transforms to the input image. Args: - image: + image: The input image to apply the transforms to. 
Returns: diff --git a/lightly/transforms/swav_transform.py b/lightly/transforms/swav_transform.py index 85fd8d207..22f1760ea 100644 --- a/lightly/transforms/swav_transform.py +++ b/lightly/transforms/swav_transform.py @@ -1,11 +1,13 @@ +from typing import Optional, Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image from torch import Tensor + +from lightly.transforms.gaussian_blur import GaussianBlur from lightly.transforms.multi_crop_transform import MultiCropTranform -from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.transforms.rotation import random_rotation_transform -from lightly.transforms.gaussian_blur import GaussianBlur -from typing import Optional, Tuple, Union -from PIL.Image import Image -import torchvision.transforms as T +from lightly.transforms.utils import IMAGENET_NORMALIZE class SwaVTransform(MultiCropTranform): @@ -69,7 +71,6 @@ def __init__( sigmas: Tuple[float, float] = (0.2, 2), normalize: Union[None, dict] = IMAGENET_NORMALIZE, ): - transforms = SwaVViewTransform( hf_prob=hf_prob, vf_prob=vf_prob, @@ -136,7 +137,7 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: Applies the transforms to the input image. Args: - image: + image: The input image to apply the transforms to. Returns: diff --git a/lightly/transforms/vicreg_transform.py b/lightly/transforms/vicreg_transform.py index aed4b3b97..3a686b664 100644 --- a/lightly/transforms/vicreg_transform.py +++ b/lightly/transforms/vicreg_transform.py @@ -1,12 +1,14 @@ +from typing import Optional, Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image from torch import Tensor + +from lightly.transforms.gaussian_blur import GaussianBlur from lightly.transforms.multi_view_transform import MultiViewTransform -from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.transforms.rotation import random_rotation_transform -from lightly.transforms.gaussian_blur import GaussianBlur from lightly.transforms.solarize import RandomSolarization -from typing import Optional, Tuple, Union -from PIL.Image import Image -import torchvision.transforms as T +from lightly.transforms.utils import IMAGENET_NORMALIZE class VICRegTransform(MultiViewTransform): @@ -77,7 +79,6 @@ def __init__( rr_degrees: Union[None, float, Tuple[float, float]] = None, normalize: Union[None, dict] = IMAGENET_NORMALIZE, ): - view_transform = VICRegViewTransform( input_size=input_size, cj_prob=cj_prob, @@ -143,7 +144,7 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: Applies the transforms to the input image. Args: - image: + image: The input image to apply the transforms to. 
Returns: diff --git a/lightly/transforms/vicregl_transform.py b/lightly/transforms/vicregl_transform.py index c9f01e2f9..c85dc07d5 100644 --- a/lightly/transforms/vicregl_transform.py +++ b/lightly/transforms/vicregl_transform.py @@ -1,12 +1,14 @@ +from typing import Optional, Tuple, Union + +import torchvision.transforms as T +from PIL.Image import Image from torch import Tensor + +from lightly.transforms.gaussian_blur import GaussianBlur from lightly.transforms.image_grid_transform import ImageGridTransform -from lightly.transforms.utils import IMAGENET_NORMALIZE from lightly.transforms.random_crop_and_flip_with_grid import RandomResizedCropAndFlip from lightly.transforms.solarize import RandomSolarization -from lightly.transforms.gaussian_blur import GaussianBlur -from typing import Optional, Tuple, Union -from PIL.Image import Image -import torchvision.transforms as T +from lightly.transforms.utils import IMAGENET_NORMALIZE class VICRegLTransform(ImageGridTransform): @@ -31,21 +33,21 @@ class VICRegLTransform(ImageGridTransform): Probability of Gaussian blur for the local crop category. global_gaussian_blur_kernel_size: Will be deprecated in favor of `global_gaussian_blur_sigmas` argument. - If set, the old behavior applies and `global_gaussian_blur_sigmas` - is ignored. Used to calculate sigma of gaussian blur with + If set, the old behavior applies and `global_gaussian_blur_sigmas` + is ignored. Used to calculate sigma of gaussian blur with global_gaussian_blur_kernel_size * input_size. Applied to global crop category. local_gaussian_blur_kernel_size: - Will be deprecated in favor of `local_gaussian_blur_sigmas` argument. - If set, the old behavior applies and `local_gaussian_blur_sigmas` - is ignored. Used to calculate sigma of gaussian blur with + Will be deprecated in favor of `local_gaussian_blur_sigmas` argument. + If set, the old behavior applies and `local_gaussian_blur_sigmas` + is ignored. Used to calculate sigma of gaussian blur with local_gaussian_blur_kernel_size * input_size. Applied to local crop category. global_gaussian_blur_sigmas: Tuple of min and max value from which the std of the gaussian kernel - is sampled. It is ignored if `global_gaussian_blur_kernel_size` is set. + is sampled. It is ignored if `global_gaussian_blur_kernel_size` is set. Applied to global crop category. local_gaussian_blur_sigmas: Tuple of min and max value from which the std of the gaussian kernel - is sampled. It is ignored if `local_gaussian_blur_kernel_size` is set. + is sampled. It is ignored if `local_gaussian_blur_kernel_size` is set. Applied to local crop category. global_solarize_prob: Probability of solarization for the global crop category. @@ -171,7 +173,7 @@ def __call__(self, image: Union[Tensor, Image]) -> Tensor: Applies the transforms to the input image. Args: - image: + image: The input image to apply the transforms to. 
Returns: diff --git a/lightly/utils/benchmarking.py b/lightly/utils/benchmarking.py index 8796ed36f..0e447b731 100644 --- a/lightly/utils/benchmarking.py +++ b/lightly/utils/benchmarking.py @@ -14,12 +14,14 @@ # https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb -def knn_predict(feature: torch.Tensor, - feature_bank: torch.Tensor, - feature_labels: torch.Tensor, - num_classes: int, - knn_k: int=200, - knn_t: float=0.1) -> torch.Tensor: +def knn_predict( + feature: torch.Tensor, + feature_bank: torch.Tensor, + feature_labels: torch.Tensor, + num_classes: int, + knn_k: int = 200, + knn_t: float = 0.1, +) -> torch.Tensor: """Run kNN predictions on features based on a feature bank This method is commonly used to monitor performance of self-supervised @@ -29,17 +31,17 @@ def knn_predict(feature: torch.Tensor, used in https://arxiv.org/pdf/1805.01978v1.pdf. Args: - feature: + feature: Tensor of shape [N, D] for which you want predictions - feature_bank: + feature_bank: Tensor of a database of features used for kNN - feature_labels: + feature_labels: Labels for the features in our feature_bank - num_classes: + num_classes: Number of classes (e.g. `10` for CIFAR-10) - knn_k: + knn_k: Number of k neighbors used for kNN - knn_t: + knn_t: Temperature parameter to reweights similarities for kNN Returns: @@ -63,19 +65,25 @@ def knn_predict(feature: torch.Tensor, # [B, K] sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1) # [B, K] - sim_labels = torch.gather(feature_labels.expand( - feature.size(0), -1), dim=-1, index=sim_indices) + sim_labels = torch.gather( + feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices + ) # we do a reweighting of the similarities sim_weight = (sim_weight / knn_t).exp() # counts for each class - one_hot_label = torch.zeros(feature.size( - 0) * knn_k, num_classes, device=sim_labels.device) + one_hot_label = torch.zeros( + feature.size(0) * knn_k, num_classes, device=sim_labels.device + ) # [B*K, C] one_hot_label = one_hot_label.scatter( - dim=-1, index=sim_labels.view(-1, 1), value=1.0) + dim=-1, index=sim_labels.view(-1, 1), value=1.0 + ) # weighted score ---> [B, C] - pred_scores = torch.sum(one_hot_label.view(feature.size( - 0), -1, num_classes) * sim_weight.unsqueeze(dim=-1), dim=1) + pred_scores = torch.sum( + one_hot_label.view(feature.size(0), -1, num_classes) + * sim_weight.unsqueeze(dim=-1), + dim=1, + ) pred_labels = pred_scores.argsort(dim=-1, descending=True) return pred_labels @@ -90,7 +98,7 @@ class BenchmarkModule(LightningModule): the predictions on a kNN classifier on the validation data using the feature_bank features from the train data. - We can access the highest test accuracy during a kNN prediction + We can access the highest test accuracy during a kNN prediction using the `max_accuracy` attribute. 
Attributes: @@ -118,7 +126,7 @@ class BenchmarkModule(LightningModule): >>> *list(resnet.children())[:-1], >>> nn.AdaptiveAvgPool2d(1), >>> ) - >>> self.resnet_simsiam = + >>> self.resnet_simsiam = >>> lightly.models.SimSiam(self.backbone, num_ftrs=512) >>> self.criterion = lightly.loss.SymNegCosineSimilarityLoss() >>> @@ -148,11 +156,13 @@ class BenchmarkModule(LightningModule): """ - def __init__(self, - dataloader_kNN: DataLoader, - num_classes: int, - knn_k: int=200, - knn_t: float=0.1): + def __init__( + self, + dataloader_kNN: DataLoader, + num_classes: int, + knn_k: int = 200, + knn_t: float = 0.1, + ): super().__init__() self.backbone = nn.Module() self.max_accuracy = 0.0 @@ -178,15 +188,13 @@ def training_epoch_end(self, outputs): feature = F.normalize(feature, dim=1) self.feature_bank.append(feature) self.targets_bank.append(target) - self.feature_bank = torch.cat( - self.feature_bank, dim=0).t().contiguous() - self.targets_bank = torch.cat( - self.targets_bank, dim=0).t().contiguous() + self.feature_bank = torch.cat(self.feature_bank, dim=0).t().contiguous() + self.targets_bank = torch.cat(self.targets_bank, dim=0).t().contiguous() self.backbone.train() def validation_step(self, batch, batch_idx): # we can only do kNN predictions once we have a feature bank - if hasattr(self, 'feature_bank') and hasattr(self, 'targets_bank'): + if hasattr(self, "feature_bank") and hasattr(self, "targets_bank"): images, targets, _ = batch feature = self.backbone(images).squeeze() feature = F.normalize(feature, dim=1) @@ -196,7 +204,7 @@ def validation_step(self, batch, batch_idx): self.targets_bank, self.num_classes, self.knn_k, - self.knn_t + self.knn_t, ) num = images.size() top1 = (pred_labels[:, 0] == targets).float().sum() @@ -206,11 +214,11 @@ def validation_epoch_end(self, outputs): device = self.dummy_param.device if outputs: total_num = torch.Tensor([0]).to(device) - total_top1 = torch.Tensor([0.]).to(device) - for (num, top1) in outputs: + total_top1 = torch.Tensor([0.0]).to(device) + for num, top1 in outputs: total_num += num[0] total_top1 += top1 - + if dist.is_initialized() and dist.get_world_size() > 1: dist.all_reduce(total_num) dist.all_reduce(total_top1) @@ -218,4 +226,4 @@ def validation_epoch_end(self, outputs): acc = float(total_top1.item() / total_num.item()) if acc > self.max_accuracy: self.max_accuracy = acc - self.log('kNN_accuracy', acc * 100.0, prog_bar=True) + self.log("kNN_accuracy", acc * 100.0, prog_bar=True) diff --git a/lightly/utils/cropping/crop_image_by_bounding_boxes.py b/lightly/utils/cropping/crop_image_by_bounding_boxes.py index 5d2b1a5eb..cab332a65 100644 --- a/lightly/utils/cropping/crop_image_by_bounding_boxes.py +++ b/lightly/utils/cropping/crop_image_by_bounding_boxes.py @@ -10,12 +10,13 @@ from lightly.data import LightlyDataset -def crop_dataset_by_bounding_boxes_and_save(dataset: LightlyDataset, - output_dir: str, - bounding_boxes_list_list: List[List[BoundingBox]], - class_indices_list_list: List[List[int]], - class_names: List[str] = None - ) -> List[List[str]]: +def crop_dataset_by_bounding_boxes_and_save( + dataset: LightlyDataset, + output_dir: str, + bounding_boxes_list_list: List[List[BoundingBox]], + class_indices_list_list: List[List[int]], + class_names: List[str] = None, +) -> List[List[str]]: """Crops all images in a dataset by the bounding boxes and saves them in the output dir Args: @@ -39,45 +40,55 @@ def crop_dataset_by_bounding_boxes_and_save(dataset: LightlyDataset, """ filenames_images = dataset.get_filenames() - if 
len(filenames_images) != len(bounding_boxes_list_list) or len(filenames_images) != len(class_indices_list_list): - raise ValueError("There must be one bounding box and class index list for each image in the datasets," - "but the lengths dont align.") + if len(filenames_images) != len(bounding_boxes_list_list) or len( + filenames_images + ) != len(class_indices_list_list): + raise ValueError( + "There must be one bounding box and class index list for each image in the dataset, " + "but the lengths don't align." + ) cropped_image_filepath_list_list: List[List[str]] = [] - print(f"Cropping objects out of {len(filenames_images)} images...") - for filename_image, class_indices, bounding_boxes in \ tqdm(zip(filenames_images, class_indices_list_list, bounding_boxes_list_list)): - + for filename_image, class_indices, bounding_boxes in tqdm( + zip(filenames_images, class_indices_list_list, bounding_boxes_list_list) + ): if not len(class_indices) == len(bounding_boxes): - warnings.warn(UserWarning(f"Length of class indices ({len(class_indices)} does not equal length of bounding boxes" - f"({len(bounding_boxes)}. This is an error in the input arguments. " - f"Skipping this image {filename_image}.")) + warnings.warn( + UserWarning( + f"Length of class indices ({len(class_indices)}) does not equal length of bounding boxes " + f"({len(bounding_boxes)}). This is an error in the input arguments. " + f"Skipping this image {filename_image}." + ) + ) continue filepath_image = dataset.get_filepath_from_filename(filename_image) filepath_image_base, image_extension = os.path.splitext(filepath_image) - filepath_out_dir = os.path.join(output_dir, filename_image)\ .replace(image_extension, '') + filepath_out_dir = os.path.join(output_dir, filename_image).replace( + image_extension, "" + ) Path(filepath_out_dir).mkdir(parents=True, exist_ok=True) image = Image.open(filepath_image) - + cropped_images_filepaths = [] # For every image, crop out multiple cropped images, one for each # bounding box - for index, (class_index, bbox) in \ enumerate((zip(class_indices, bounding_boxes))): - + for index, (class_index, bbox) in enumerate( + (zip(class_indices, bounding_boxes)) + ): # determine the filename and filepath of the cropped image if class_names: class_name = class_names[class_index] else: class_name = f"class{class_index}" - cropped_image_last_filename = f'{index}_{class_name}{image_extension}' - cropped_image_filepath = os.path.join(filepath_out_dir, cropped_image_last_filename) + cropped_image_last_filename = f"{index}_{class_name}{image_extension}" + cropped_image_filepath = os.path.join( + filepath_out_dir, cropped_image_last_filename + ) # crop out the image and save it w, h = image.size @@ -88,8 +99,7 @@ def crop_dataset_by_bounding_boxes_and_save(dataset: LightlyDataset, # add the filename of the cropped image to the corresponding list cropped_image_filename: str = os.path.join( - filename_image.replace(image_extension, ''), - cropped_image_last_filename + filename_image.replace(image_extension, ""), cropped_image_last_filename ) cropped_images_filepaths.append(cropped_image_filename) diff --git a/lightly/utils/cropping/read_yolo_label_file.py b/lightly/utils/cropping/read_yolo_label_file.py index 2ff0e9a46..c7a97cd3f 100644 --- a/lightly/utils/cropping/read_yolo_label_file.py +++ b/lightly/utils/cropping/read_yolo_label_file.py @@ -3,7 +3,9 @@ from lightly.active_learning.utils import BoundingBox -def read_yolo_label_file(filepath: str, padding: float, separator: str = ' ') -> Tuple[List[int],
List[BoundingBox]]: +def read_yolo_label_file( + filepath: str, padding: float, separator: str = " " +) -> Tuple[List[int], List[BoundingBox]]: """Reads a file in the yolo file format Args: @@ -20,7 +22,7 @@ def read_yolo_label_file(filepath: str, padding: float, separator: str = ' ') -> The bounding boxes. """ - with open(filepath, 'r') as f: + with open(filepath, "r") as f: lines = f.readlines() class_indices = [] @@ -31,8 +33,8 @@ def read_yolo_label_file(filepath: str, padding: float, separator: str = ' ') -> class_id = int(class_id) class_indices.append(class_id) - w_norm *= 1+padding - h_norm *= 1+padding + w_norm *= 1 + padding + h_norm *= 1 + padding bbox = BoundingBox.from_yolo_label(x_norm, y_norm, w_norm, h_norm) bounding_boxes.append(bbox) return class_indices, bounding_boxes diff --git a/lightly/utils/debug.py b/lightly/utils/debug.py index 2a74db952..db96dea93 100644 --- a/lightly/utils/debug.py +++ b/lightly/utils/debug.py @@ -1,8 +1,8 @@ from typing import List, Union -from PIL import Image import torch import torchvision +from PIL import Image from lightly.data.collate import BaseCollateFunction, MultiViewCollateFunction @@ -14,6 +14,7 @@ "functionalities. See https://matplotlib.org/ for installation instructions." ) + def _check_matplotlib_available() -> None: if isinstance(plt, Exception): raise plt @@ -38,12 +39,12 @@ def std_of_l2_normalized(z: torch.Tensor) -> torch.Tensor: Returns: The mean of the standard deviation of the l2 normalized tensor z along each dimension. - + """ if len(z.shape) != 2: raise ValueError( - f'Input tensor must have two dimensions but has {len(z.shape)}!' + f"Input tensor must have two dimensions but has {len(z.shape)}!" ) z_norm = torch.nn.functional.normalize(z, dim=1) @@ -54,9 +55,7 @@ def apply_transform_without_normalize( image: Image.Image, transform, ): - """Applies the transform to the image but skips ToTensor and Normalize. - - """ + """Applies the transform to the image but skips ToTensor and Normalize.""" skippable_transforms = ( torchvision.transforms.ToTensor, torchvision.transforms.Normalize, @@ -91,21 +90,25 @@ def generate_grid_of_augmented_images( grid = [] if isinstance(collate_function, BaseCollateFunction): for _ in range(2): - grid.append([ - apply_transform_without_normalize(image, collate_function.transform) - for image in input_images - ]) + grid.append( + [ + apply_transform_without_normalize(image, collate_function.transform) + for image in input_images + ] + ) elif isinstance(collate_function, MultiViewCollateFunction): for transform in collate_function.transforms: - grid.append([ - apply_transform_without_normalize(image, transform) - for image in input_images - ]) + grid.append( + [ + apply_transform_without_normalize(image, transform) + for image in input_images + ] + ) else: raise ValueError( - 'Collate function must be one of ' - '(BaseCollateFunction, MultiViewCollateFunction) ' - f'but is {type(collate_function)}.' + "Collate function must be one of " + "(BaseCollateFunction, MultiViewCollateFunction) " + f"but is {type(collate_function)}." 
) return grid @@ -136,7 +139,7 @@ def plot_augmented_images( _check_matplotlib_available() if len(input_images) == 0: - raise ValueError('There must be at least one input image.') + raise ValueError("There must be at least one input image.") grid = generate_grid_of_augmented_images(input_images, collate_function) grid.insert(0, input_images) @@ -153,10 +156,10 @@ def plot_augmented_images( ax.set_axis_off() ax_top_left = axs[0, 0] if len(input_images) > 1 else axs[0] - ax_top_left.set(title='Original images') + ax_top_left.set(title="Original images") ax_top_left.title.set_size(8) ax_top_next = axs[0, 1] if len(input_images) > 1 else axs[1] - ax_top_next.set(title='Augmented images') + ax_top_next.set(title="Augmented images") ax_top_next.title.set_size(8) fig.tight_layout() diff --git a/lightly/utils/dist.py b/lightly/utils/dist.py index 14dd370ac..64f9ed351 100644 --- a/lightly/utils/dist.py +++ b/lightly/utils/dist.py @@ -1,14 +1,15 @@ -from typing import Tuple, Optional +from typing import Optional, Tuple import torch import torch.distributed as dist + class GatherLayer(torch.autograd.Function): """Gather tensors from all processes, supporting backward propagation. - + This code was taken and adapted from here: https://github.com/Spijkervet/SimCLR - + """ @staticmethod @@ -25,24 +26,27 @@ def backward(ctx, *grads: torch.Tensor) -> torch.Tensor: grad_out[:] = grads[dist.get_rank()] return grad_out + def rank() -> int: """Returns the rank of the current process.""" return dist.get_rank() if dist.is_initialized() else 0 + def world_size() -> int: """Returns the current world size (number of distributed processes).""" return dist.get_world_size() if dist.is_initialized() else 1 + def gather(input: torch.Tensor) -> Tuple[torch.Tensor]: """Gathers this tensor from all processes. Supports backprop.""" return GatherLayer.apply(input) -def eye_rank(n: int, device: Optional[torch.device]=None) -> torch.Tensor: +def eye_rank(n: int, device: Optional[torch.device] = None) -> torch.Tensor: """Returns an (n, n * world_size) zero matrix with the diagonal for the rank of this process set to 1. - Example output where n=3, the current process has rank 1, and there are + Example output where n=3, the current process has rank 1, and there are 4 processes in total: rank0 rank1 rank2 rank3 diff --git a/lightly/utils/embeddings_2d.py b/lightly/utils/embeddings_2d.py index 47c28c6a5..0b40a4bac 100644 --- a/lightly/utils/embeddings_2d.py +++ b/lightly/utils/embeddings_2d.py @@ -56,7 +56,7 @@ def transform(self, X: np.ndarray): """ X = X.astype(np.float32) X = X - self.mean + self.eps - return X.dot(self.w)[:, :self.n_components] + return X.dot(self.w)[:, : self.n_components] def fit_pca(embeddings: np.ndarray, n_components: int = 2, fraction: float = None): @@ -83,8 +83,8 @@ def fit_pca(embeddings: np.ndarray, n_components: int = 2, fraction: float = Non """ if fraction is not None: - if fraction < 0. or fraction > 1.: - msg = f'fraction must be in [0, 1] but was {fraction}.' + if fraction < 0.0 or fraction > 1.0: + msg = f"fraction must be in [0, 1] but was {fraction}." 
raise ValueError(msg) N = embeddings.shape[0] diff --git a/lightly/utils/hipify.py b/lightly/utils/hipify.py index 8f320712e..fd2d7097b 100644 --- a/lightly/utils/hipify.py +++ b/lightly/utils/hipify.py @@ -4,14 +4,14 @@ class bcolors: - HEADER = '\033[95m' - OKBLUE = '\033[94m' - OKGREEN = '\033[92m' - WARNING = '\033[93m' - FAIL = '\033[91m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" def _custom_formatwarning(msg, *args, **kwargs): @@ -19,11 +19,10 @@ def _custom_formatwarning(msg, *args, **kwargs): return f"{bcolors.WARNING}{msg}{bcolors.WARNING}\n" - def print_as_warning(message: str, warning_class: Type[Warning] = UserWarning): old_format = copy.copy(warnings.formatwarning) warnings.formatwarning = _custom_formatwarning warnings.warn(message, warning_class) - warnings.formatwarning = old_format \ No newline at end of file + warnings.formatwarning = old_format diff --git a/lightly/utils/io.py b/lightly/utils/io.py index 81078c3f4..00ca63a1c 100644 --- a/lightly/utils/io.py +++ b/lightly/utils/io.py @@ -3,11 +3,11 @@ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved -import json import csv -from typing import List, Tuple, Dict +import json import re from itertools import compress +from typing import Dict, List, Tuple import numpy as np diff --git a/lightly/utils/reordering.py b/lightly/utils/reordering.py index 6dddd8b86..22311d8b6 100644 --- a/lightly/utils/reordering.py +++ b/lightly/utils/reordering.py @@ -1,11 +1,7 @@ from typing import List, Sized -def sort_items_by_keys( - keys: List[any], - items: List[any], - sorted_keys: List[any] -): +def sort_items_by_keys(keys: List[any], items: List[any], sorted_keys: List[any]): """Sorts the items in the same order as the sorted keys. Args: @@ -33,10 +29,12 @@ def sort_items_by_keys( """ if len(keys) != len(items) or len(keys) != len(sorted_keys): - raise ValueError(f"All inputs (keys, items and sorted_keys) " - f"must have the same length, " - f"but their lengths are: ({len(keys)}," - f"{len(items)} and {len(sorted_keys)}).") + raise ValueError( + f"All inputs (keys, items and sorted_keys) " + f"must have the same length, " + f"but their lengths are: ({len(keys)}," + f"{len(items)} and {len(sorted_keys)})." + ) lookup = {key_: item_ for key_, item_ in zip(keys, items)} sorted_ = [lookup[key_] for key_ in sorted_keys] return sorted_ diff --git a/lightly/utils/scheduler.py b/lightly/utils/scheduler.py index 6ed77d616..56c7a2ec8 100644 --- a/lightly/utils/scheduler.py +++ b/lightly/utils/scheduler.py @@ -1,11 +1,10 @@ -import torch import numpy as np +import torch def cosine_schedule( step: int, max_steps: int, start_value: float, end_value: float ) -> float: - """ Use cosine decay to gradually modify start_value to reach target end_value during iterations. @@ -30,7 +29,7 @@ def cosine_schedule( if step > max_steps: # Note: we allow step == max_steps even though step starts at 0 and should end # at max_steps - 1. This is because Pytorch Lightning updates the LR scheduler - # always for the next epoch, even after the last training epoch. This results in + # always for the next epoch, even after the last training epoch. This results in # Pytorch Lightning calling the scheduler with step == max_steps. raise ValueError( f"The current step cannot be larger than max_steps but found step {step} and max_steps {max_steps}." 
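The scheduler hunk above also documents why cosine_schedule tolerates step == max_steps: PyTorch Lightning steps the scheduler once more after the final training epoch. A short usage sketch against the signature shown in the hunk; the 0.996 -> 1.0 momentum endpoints are illustrative BYOL-style values, not something this diff prescribes:

from lightly.utils.scheduler import cosine_schedule

max_epochs = 100
for epoch in range(max_epochs):
    # Cosine-ramp an EMA momentum from 0.996 towards 1.0 over training.
    momentum = cosine_schedule(
        step=epoch, max_steps=max_epochs, start_value=0.996, end_value=1.0
    )
    # ... use `momentum` to update a momentum encoder here ...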
diff --git a/lightly/utils/version_compare.py b/lightly/utils/version_compare.py index 1a584d6d8..898407cfb 100644 --- a/lightly/utils/version_compare.py +++ b/lightly/utils/version_compare.py @@ -6,7 +6,7 @@ def version_compare(v0: str, v1: str): """Returns 1 if version of v0 is larger than v1 and -1 otherwise - + Use this method to compare Python package versions and see which one is newer. @@ -16,12 +16,12 @@ def version_compare(v0: str, v1: str): >>> version_compare('1.2.0', '1.1.2') >>> 1 """ - v0 = [int(n) for n in v0.split('.')][::-1] - v1 = [int(n) for n in v1.split('.')][::-1] + v0 = [int(n) for n in v0.split(".")][::-1] + v1 = [int(n) for n in v1.split(".")][::-1] if len(v0) != 3 or len(v1) != 3: raise ValueError( - f'Length of version strings is not 3 (expected pattern `x.y.z`) but is ' - f'{v0} and {v1}.' + f"Length of version strings is not 3 (expected pattern `x.y.z`) but is " + f"{v0} and {v1}." ) pairs = list(zip(v0, v1))[::-1] for x, y in pairs: diff --git a/pylintrc b/pylintrc index b08fa1c8a..2de9082d0 100644 --- a/pylintrc +++ b/pylintrc @@ -265,7 +265,7 @@ generated-members= [FORMAT] # Maximum number of characters on a single line. -max-line-length=80 +max-line-length=88 # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt # lines made too long by directives to pytype. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..18cdd7197 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[tool.black] +extend-exclude = "lightly/openapi_generated/.*" + +[tool.isort] +profile = "black" +extend_skip = "lightly/openapi_generated" diff --git a/requirements/dev.txt b/requirements/dev.txt index efaf296a4..41381ce4c 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -18,4 +18,6 @@ opencv-python scikit-learn pandas torchmetrics -lightning-bolts # for LARS optimizer \ No newline at end of file +lightning-bolts # for LARS optimizer +black==23.1.0 # frozen version to avoid differences between CI and local dev machines +isort==5.11.5 # frozen version to avoid differences between CI and local dev machines diff --git a/setup.py b/setup.py index d4cb54c9f..7f3c310a4 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ -import setuptools -import sys import os +import sys + +import setuptools try: import builtins @@ -12,43 +13,41 @@ import lightly -def load_description(path_dir=PATH_ROOT, filename='DOCS.md'): - """Load long description from readme in the path_dir/ directory - """ +def load_description(path_dir=PATH_ROOT, filename="DOCS.md"): + """Load long description from readme in the path_dir/ directory""" with open(os.path.join(path_dir, filename)) as f: long_description = f.read() return long_description -def load_requirements(path_dir=PATH_ROOT, filename='base.txt', comment_char='#'): +def load_requirements(path_dir=PATH_ROOT, filename="base.txt", comment_char="#"): """From pytorch-lightning repo: https://github.com/PyTorchLightning/pytorch-lightning. - Load requirements from text file in the path_dir/requirements/ directory. + Load requirements from text file in the path_dir/requirements/ directory. 
""" - with open(os.path.join(path_dir, 'requirements', filename), 'r') as file: + with open(os.path.join(path_dir, "requirements", filename), "r") as file: lines = [ln.strip() for ln in file.readlines()] reqs = [] for ln in lines: # filer all comments if comment_char in ln: - ln = ln[:ln.index(comment_char)].strip() + ln = ln[: ln.index(comment_char)].strip() # skip directly installed dependencies - if ln.startswith('http'): + if ln.startswith("http"): continue if ln: # if requirement is not empty reqs.append(ln) return reqs -if __name__ == '__main__': - - name = 'lightly' +if __name__ == "__main__": + name = "lightly" version = lightly.__version__ description = lightly.__doc__ - author = 'Philipp Wirth & Igor Susmelj' - author_email = 'philipp@lightly.ai' + author = "Philipp Wirth & Igor Susmelj" + author_email = "philipp@lightly.ai" description = "A deep learning package for self-supervised learning" entry_points = { @@ -65,48 +64,48 @@ def load_requirements(path_dir=PATH_ROOT, filename='base.txt', comment_char='#') long_description = load_description() - python_requires = '>=3.6' + python_requires = ">=3.6" install_requires = load_requirements() - video_requires = load_requirements(filename='video.txt') - dev_requires = load_requirements(filename='dev.txt') + video_requires = load_requirements(filename="video.txt") + dev_requires = load_requirements(filename="dev.txt") all_requires = dev_requires + video_requires extras_require = { - 'video': video_requires, - 'dev': dev_requires, - 'all': all_requires, + "video": video_requires, + "dev": dev_requires, + "all": all_requires, } packages = [ - 'lightly', - 'lightly.api', - 'lightly.cli', - 'lightly.cli.config', - 'lightly.data', - 'lightly.embedding', - 'lightly.loss', - 'lightly.loss.regularizer', - 'lightly.models', - 'lightly.models.modules', - 'lightly.transforms', - 'lightly.utils', - 'lightly.utils.cropping', - 'lightly.active_learning', - 'lightly.active_learning.agents', - 'lightly.active_learning.config', - 'lightly.active_learning.scorers', - 'lightly.active_learning.utils', - 'lightly.openapi_generated', - 'lightly.openapi_generated.swagger_client', - 'lightly.openapi_generated.swagger_client.api', - 'lightly.openapi_generated.swagger_client.models' + "lightly", + "lightly.api", + "lightly.cli", + "lightly.cli.config", + "lightly.data", + "lightly.embedding", + "lightly.loss", + "lightly.loss.regularizer", + "lightly.models", + "lightly.models.modules", + "lightly.transforms", + "lightly.utils", + "lightly.utils.cropping", + "lightly.active_learning", + "lightly.active_learning.agents", + "lightly.active_learning.config", + "lightly.active_learning.scorers", + "lightly.active_learning.utils", + "lightly.openapi_generated", + "lightly.openapi_generated.swagger_client", + "lightly.openapi_generated.swagger_client.api", + "lightly.openapi_generated.swagger_client.models", ] project_urls = { - 'Homepage': 'https://www.lightly.ai', - 'Web-App': 'https://app.lightly.ai', - 'Documentation': 'https://docs.lightly.ai', - 'Github': 'https://github.com/lightly-ai/lightly', - 'Discord': 'https://discord.gg/xvNJW94', + "Homepage": "https://www.lightly.ai", + "Web-App": "https://app.lightly.ai", + "Documentation": "https://docs.lightly.ai", + "Github": "https://github.com/lightly-ai/lightly", + "Discord": "https://discord.gg/xvNJW94", } classifiers = [ @@ -125,7 +124,7 @@ def load_requirements(path_dir=PATH_ROOT, filename='base.txt', comment_char='#') "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", 
"Programming Language :: Python :: 3.9", - "License :: OSI Approved :: MIT License" + "License :: OSI Approved :: MIT License", ] setuptools.setup( @@ -135,9 +134,9 @@ def load_requirements(path_dir=PATH_ROOT, filename='base.txt', comment_char='#') author_email=author_email, description=description, entry_points=entry_points, - license='MIT', + license="MIT", long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", install_requires=install_requires, extras_require=extras_require, python_requires=python_requires, @@ -146,5 +145,3 @@ def load_requirements(path_dir=PATH_ROOT, filename='base.txt', comment_char='#') include_package_data=True, project_urls=project_urls, ) - - diff --git a/tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py b/tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py index b5adfaa9c..1d979705c 100644 --- a/tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py +++ b/tests/UNMOCKED_end2end_tests/create_custom_metadata_from_input_dir.py @@ -5,20 +5,19 @@ if __name__ == "__main__": if len(sys.argv) == 1 + 2: - input_dir, metadata_filename= \ - (sys.argv[1 + i] for i in range(2)) + input_dir, metadata_filename = (sys.argv[1 + i] for i in range(2)) else: - raise ValueError("ERROR in number of command line arguments, must be 2." - "Example: python create_custom_metadata_from_input_dir.py input_dir metadata_filename") + raise ValueError( + "ERROR in number of command line arguments, must be 2." + "Example: python create_custom_metadata_from_input_dir.py input_dir metadata_filename" + ) dataset = LightlyDataset(input_dir) # create a list of pairs of (filename, metadata) custom_metadata = [] for index, filename in enumerate(dataset.get_filenames()): - metadata = {'index': index} + metadata = {"index": index} custom_metadata.append((filename, metadata)) save_custom_metadata(metadata_filename, custom_metadata) - - diff --git a/tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py b/tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py index 78505e546..a671d7745 100644 --- a/tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py +++ b/tests/UNMOCKED_end2end_tests/delete_datasets_test_unmocked_cli.py @@ -4,16 +4,17 @@ if __name__ == "__main__": if len(sys.argv) == 1 + 3: - num_datasets, token, date_time = \ - (sys.argv[1 + i] for i in range(3)) + num_datasets, token, date_time = (sys.argv[1 + i] for i in range(3)) else: - raise ValueError("ERROR in number of command line arguments, must be 3." - "Example: python delete_datasets_test_unmocked_cli.py 6 LIGHTLY_TOKEN 2022-09-29-13-41-24") + raise ValueError( + "ERROR in number of command line arguments, must be 3." 
+ "Example: python delete_datasets_test_unmocked_cli.py 6 LIGHTLY_TOKEN 2022-09-29-13-41-24" + ) api_workflow_client = ApiWorkflowClient(token=token) num_datasets = int(num_datasets) - for i in range(1, num_datasets+1): + for i in range(1, num_datasets + 1): dataset_name = f"test_unmocked_cli_{i}_{date_time}" api_workflow_client.set_dataset_id_by_name(dataset_name) api_workflow_client.delete_dataset_by_id(api_workflow_client.dataset_id) diff --git a/tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py b/tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py index 6cec51f82..9d58ea06f 100644 --- a/tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py +++ b/tests/UNMOCKED_end2end_tests/scripts_for_reproducing_problems/test_api_latency.py @@ -5,7 +5,7 @@ from tqdm import tqdm from lightly.api import ApiWorkflowClient -from lightly.openapi_generated.swagger_client import Configuration, ApiClient, QuotaApi +from lightly.openapi_generated.swagger_client import ApiClient, Configuration, QuotaApi if __name__ == "__main__": token = os.getenv("LIGHTLY_TOKEN") @@ -18,13 +18,15 @@ for i in tqdm(range(n_iters)): start = time.time() quota_api.get_quota_maximum_dataset_size() - duration = time.time()-start + duration = time.time() - start latencies[i] = duration def format_latency(latency: float): return f"{latency*1000:.1f}ms" - values = [('min')] - print(f"Latencies: min: {format_latency(np.min(latencies))}, mean: {format_latency(np.mean(latencies))}, max: {format_latency(np.max(latencies))}") + values = [("min")] + print( + f"Latencies: min: {format_latency(np.min(latencies))}, mean: {format_latency(np.mean(latencies))}, max: {format_latency(np.max(latencies))}" + ) print(f"\nPINGING TO GOOGLE") - response = os.system("ping -c 1 " + "google.com") \ No newline at end of file + response = os.system("ping -c 1 " + "google.com") diff --git a/tests/UNMOCKED_end2end_tests/test_api.py b/tests/UNMOCKED_end2end_tests/test_api.py index 707738e34..146f2e585 100644 --- a/tests/UNMOCKED_end2end_tests/test_api.py +++ b/tests/UNMOCKED_end2end_tests/test_api.py @@ -4,39 +4,37 @@ from typing import List, Tuple import numpy as np -from hydra.experimental import initialize, compose - -from lightly.cli import upload_cli -from lightly.data.dataset import LightlyDataset - -from lightly.active_learning.scorers.classification import ScorerClassification - -from lightly.active_learning.config.selection_config import SelectionConfig - -from lightly.api.bitmask import BitMask - -from lightly.openapi_generated.swagger_client.models.tag_create_request import TagCreateRequest +from hydra.experimental import compose, initialize from lightly.active_learning.agents.agent import ActiveLearningAgent -from lightly.openapi_generated.swagger_client.models.sampling_method import SamplingMethod - +from lightly.active_learning.config.selection_config import SelectionConfig +from lightly.active_learning.scorers.classification import ScorerClassification from lightly.api.api_workflow_client import ApiWorkflowClient +from lightly.api.bitmask import BitMask +from lightly.cli import upload_cli +from lightly.data.dataset import LightlyDataset +from lightly.openapi_generated.swagger_client.models.sampling_method import ( + SamplingMethod, +) +from lightly.openapi_generated.swagger_client.models.tag_create_request import ( + TagCreateRequest, +) from lightly.utils.io import save_embeddings class CSVEmbeddingDataset: def __init__(self, path_to_embeddings_csv: str): - 
with open(path_to_embeddings_csv, 'r') as f: + with open(path_to_embeddings_csv, "r") as f: data = csv.reader(f) rows = list(data) header_row = rows[0] rows_without_header = rows[1:] - index_filenames = header_row.index('filenames') + index_filenames = header_row.index("filenames") filenames = [row[index_filenames] for row in rows_without_header] - index_labels = header_row.index('labels') + index_labels = header_row.index("labels") labels = [row[index_labels] for row in rows_without_header] embeddings = rows_without_header @@ -45,8 +43,12 @@ def __init__(self, path_to_embeddings_csv: str): for index_to_delete in indexes_to_delete: del embedding_row[index_to_delete] - self.dataset = dict([(filename, (np.array(embedding_row, dtype=float), int(label))) - for filename, embedding_row, label in zip(filenames, embeddings, labels)]) + self.dataset = dict( + [ + (filename, (np.array(embedding_row, dtype=float), int(label))) + for filename, embedding_row, label in zip(filenames, embeddings, labels) + ] + ) def get_features(self, filenames: List[str]) -> np.ndarray: features_array = np.array([self.dataset[filename][0] for filename in filenames]) @@ -63,21 +65,26 @@ def all_features_labels(self) -> Tuple[np.ndarray, np.ndarray]: return features, labels -def create_new_dataset_with_embeddings(path_to_dataset: str, - token: str, - dataset_name: str) -> ApiWorkflowClient: +def create_new_dataset_with_embeddings( + path_to_dataset: str, token: str, dataset_name: str +) -> ApiWorkflowClient: api_workflow_client = ApiWorkflowClient(token=token) # create the dataset - api_workflow_client.create_new_dataset_with_unique_name(dataset_basename=dataset_name) + api_workflow_client.create_new_dataset_with_unique_name( + dataset_basename=dataset_name + ) # upload to the dataset initialize(config_path="../../lightly/cli/config", job_name="test_app") - cfg = compose(config_name="config", overrides=[ - f"input_dir='{path_to_dataset}'", - f"token='{token}'", - f"dataset_id={api_workflow_client.dataset_id}" - ]) + cfg = compose( + config_name="config", + overrides=[ + f"input_dir='{path_to_dataset}'", + f"token='{token}'", + f"dataset_id={api_workflow_client.dataset_id}", + ], + ) upload_cli(cfg) # calculate and save the embeddings @@ -86,42 +93,54 @@ def create_new_dataset_with_embeddings(path_to_dataset: str, dataset = LightlyDataset(input_dir=path_to_dataset) embeddings = np.random.normal(size=(len(dataset.dataset.samples), 32)) filepaths, labels = zip(*dataset.dataset.samples) - filenames = [filepath[len(path_to_dataset):].lstrip('/') for filepath in filepaths] + filenames = [ + filepath[len(path_to_dataset) :].lstrip("/") for filepath in filepaths + ] print("Starting save of embeddings") save_embeddings(path_to_embeddings_csv, embeddings, labels, filenames) print("Finished save of embeddings") # upload the embeddings print("Starting upload of embeddings.") - api_workflow_client.upload_embeddings(path_to_embeddings_csv=path_to_embeddings_csv, name="embedding_1") + api_workflow_client.upload_embeddings( + path_to_embeddings_csv=path_to_embeddings_csv, name="embedding_1" + ) print("Finished upload of embeddings.") return api_workflow_client -def t_est_active_learning(api_workflow_client: ApiWorkflowClient, - method: SamplingMethod = SamplingMethod.CORAL, - query_tag_name: str = 'initial-tag', - preselected_tag_name: str = None, - n_samples_additional: List[int] = [2, 5]): +def t_est_active_learning( + api_workflow_client: ApiWorkflowClient, + method: SamplingMethod = SamplingMethod.CORAL, + query_tag_name: str = 
"initial-tag", + preselected_tag_name: str = None, + n_samples_additional: List[int] = [2, 5], +): # create the tags with 100 respectively 10 samples if not yet existant if query_tag_name is not None: - selection_config = SelectionConfig(method=SamplingMethod.RANDOM, n_samples=100, name=query_tag_name) + selection_config = SelectionConfig( + method=SamplingMethod.RANDOM, n_samples=100, name=query_tag_name + ) try: api_workflow_client.selection(selection_config=selection_config) except RuntimeError: pass if preselected_tag_name is not None: - selection_config = SelectionConfig(method=SamplingMethod.RANDOM, n_samples=10, name=preselected_tag_name) + selection_config = SelectionConfig( + method=SamplingMethod.RANDOM, n_samples=10, name=preselected_tag_name + ) try: api_workflow_client.selection(selection_config=selection_config) except RuntimeError: pass # define the active learning agent - agent = ActiveLearningAgent(api_workflow_client, - query_tag_name=query_tag_name, - preselected_tag_name=preselected_tag_name) + agent = ActiveLearningAgent( + api_workflow_client, + query_tag_name=query_tag_name, + preselected_tag_name=preselected_tag_name, + ) total_no_samples = len(agent.unlabeled_set) + len(agent.labeled_set) @@ -129,10 +148,16 @@ def t_est_active_learning(api_workflow_client: ApiWorkflowClient, for iteration, n_samples_additional in enumerate(n_samples_additional): n_samples = len(agent.labeled_set) + n_samples_additional - print(f"Beginning with iteration {iteration} to have {n_samples} labeled samples.") + print( + f"Beginning with iteration {iteration} to have {n_samples} labeled samples." + ) # Perform a selection - method_here = SamplingMethod.CORESET if iteration == 0 and method == SamplingMethod.CORAL else method + method_here = ( + SamplingMethod.CORESET + if iteration == 0 and method == SamplingMethod.CORAL + else method + ) selection_config = SelectionConfig(method=method_here, n_samples=n_samples) if al_scorer is None: agent.query(selection_config=selection_config) @@ -146,42 +171,50 @@ def t_est_active_learning(api_workflow_client: ApiWorkflowClient, n_samples = len(agent.query_set) n_classes = 10 predictions = np.random.rand(n_samples, n_classes) - predictions_normalized = predictions / np.sum(predictions, axis=1)[:, np.newaxis] + predictions_normalized = ( + predictions / np.sum(predictions, axis=1)[:, np.newaxis] + ) model_output = predictions_normalized al_scorer = ScorerClassification(model_output=predictions) print("Success!") -def t_est_api_with_matrix(path_to_dataset: str, - token: str, dataset_name: str = "test_api_from_pip"): - +def t_est_api_with_matrix( + path_to_dataset: str, token: str, dataset_name: str = "test_api_from_pip" +): no_samples = len(LightlyDataset(input_dir=path_to_dataset).dataset.samples) assert no_samples >= 100, "Test needs at least 100 samples in the dataset!" 
api_workflow_client = create_new_dataset_with_embeddings( - path_to_dataset=path_to_dataset, token=token, - dataset_name=dataset_name + path_to_dataset=path_to_dataset, token=token, dataset_name=dataset_name ) for method in [SamplingMethod.CORAL, SamplingMethod.CORESET, SamplingMethod.RANDOM]: - for query_tag_name in ['initial-tag', "query_tag_name_xyz"]: + for query_tag_name in ["initial-tag", "query_tag_name_xyz"]: for preselected_tag_name in [None, "preselected_tag_name_xyz"]: - print(f"Starting AL run with method '{method}', query_tag '{query_tag_name}' " - f"and preselected_tag '{preselected_tag_name}'.") - t_est_active_learning(api_workflow_client, method, query_tag_name, preselected_tag_name) + print( + f"Starting AL run with method '{method}', query_tag '{query_tag_name}' " + f"and preselected_tag '{preselected_tag_name}'." + ) + t_est_active_learning( + api_workflow_client, method, query_tag_name, preselected_tag_name + ) api_workflow_client.delete_dataset_by_id(api_workflow_client.dataset_id) - print("Success of the complete test suite! The dataset on the server was deleted again.") + print( + "Success of the complete test suite! The dataset on the server was deleted again." + ) if __name__ == "__main__": if len(sys.argv) == 1 + 2: - path_to_dataset, token = \ - (sys.argv[1 + i] for i in range(2)) + path_to_dataset, token = (sys.argv[1 + i] for i in range(2)) else: - raise ValueError("ERROR in number of command line arguments, must be 2." - "Example: python test_api path/to/dataset LIGHTLY_TOKEN") + raise ValueError( + "ERROR in number of command line arguments, must be 2." + "Example: python test_api path/to/dataset LIGHTLY_TOKEN" + ) t_est_api_with_matrix(path_to_dataset=path_to_dataset, token=token) diff --git a/tests/UNMOCKED_end2end_tests/test_api_append.py b/tests/UNMOCKED_end2end_tests/test_api_append.py index f11387e04..49eb68358 100644 --- a/tests/UNMOCKED_end2end_tests/test_api_append.py +++ b/tests/UNMOCKED_end2end_tests/test_api_append.py @@ -5,48 +5,50 @@ from lightly.data import LightlyDataset from lightly.utils.io import format_custom_metadata -from tests.UNMOCKED_end2end_tests.test_api import \ - create_new_dataset_with_embeddings +from tests.UNMOCKED_end2end_tests.test_api import create_new_dataset_with_embeddings -def t_est_api_append(path_to_dataset: str, token: str, - dataset_name: str = "test_api_from_pip_append"): +def t_est_api_append( + path_to_dataset: str, token: str, dataset_name: str = "test_api_from_pip_append" +): files_to_delete = [] try: print("Save custom metadata") dataset = LightlyDataset(path_to_dataset) path_custom_metadata = f"{path_to_dataset}/custom_metadata.csv" - custom_metadata = [(filename, {"metadata": f"{filename}_meta"}) for - filename in dataset.get_filenames()] + custom_metadata = [ + (filename, {"metadata": f"{filename}_meta"}) + for filename in dataset.get_filenames() + ] print("Upload to the dataset") api_workflow_client = create_new_dataset_with_embeddings( - path_to_dataset=path_to_dataset, token=token, - dataset_name=dataset_name) + path_to_dataset=path_to_dataset, token=token, dataset_name=dataset_name + ) api_workflow_client.upload_custom_metadata( - format_custom_metadata(custom_metadata)) + format_custom_metadata(custom_metadata) + ) print("save additional images and embeddings and custom metadata") n_data = 5 - dataset = torchvision.datasets.FakeData(size=n_data, - image_size=(3, 32, 32)) - sample_names = [f'img_{i}.jpg' for i in range(n_data)] + dataset = torchvision.datasets.FakeData(size=n_data, image_size=(3, 32, 
32)) + sample_names = [f"img_{i}.jpg" for i in range(n_data)] for sample_idx in range(n_data): data = dataset[sample_idx] path = os.path.join(path_to_dataset, sample_names[sample_idx]) files_to_delete.append(path) data[0].save(path) - custom_metadata += [(filename, {"metadata": f"{filename}_meta"}) for - filename in sample_names] + custom_metadata += [ + (filename, {"metadata": f"{filename}_meta"}) for filename in sample_names + ] print("Upload to the dataset") api_workflow_client.upload_dataset(path_to_dataset) print("Upload custom metadata") api_workflow_client.upload_custom_metadata( - format_custom_metadata(custom_metadata)) - - + format_custom_metadata(custom_metadata) + ) finally: for filename in files_to_delete: @@ -62,6 +64,7 @@ def t_est_api_append(path_to_dataset: str, token: str, else: raise ValueError( "ERROR in number of command line arguments, must be 2." - "Example: python test_api path/to/dataset LIGHTLY_TOKEN") + "Example: python test_api path/to/dataset LIGHTLY_TOKEN" + ) - t_est_api_append(path_to_dataset=path_to_dataset, token=token) \ No newline at end of file + t_est_api_append(path_to_dataset=path_to_dataset, token=token) diff --git a/tests/UNMOCKED_end2end_tests/test_download_large_files.py b/tests/UNMOCKED_end2end_tests/test_download_large_files.py index 5c60dcc50..5eb815288 100644 --- a/tests/UNMOCKED_end2end_tests/test_download_large_files.py +++ b/tests/UNMOCKED_end2end_tests/test_download_large_files.py @@ -1,6 +1,7 @@ import time import lightly + lightly.api.utils.RETRY_MAX_RETRIES = 1 from lightly.api.download import download_image @@ -13,4 +14,4 @@ img = download_image(url_5MB, request_kwargs={"verify": False}) print(f"Took {time.time()-start:5.2f}s to download the image.") -img.show() \ No newline at end of file +img.show() diff --git a/tests/active_learning/test_BoundingBox.py b/tests/active_learning/test_BoundingBox.py index df362053a..d8f947403 100644 --- a/tests/active_learning/test_BoundingBox.py +++ b/tests/active_learning/test_BoundingBox.py @@ -4,13 +4,12 @@ class TestBoundingBox(unittest.TestCase): - def test_bounding_box(self): bbox = BoundingBox(0.2, 0.3, 0.5, 0.6) self.assertEqual(bbox.x0, 0.2) self.assertEqual(bbox.y0, 0.3) self.assertEqual(bbox.x1, 0.5) - self.assertEqual(bbox.y1, 0.6) + self.assertEqual(bbox.y1, 0.6) def test_bounding_box_2(self): bbox = BoundingBox.from_x_y_w_h(0.2, 0.3, 0.3, 0.3) @@ -29,7 +28,6 @@ def test_bounding_box_illogical_argument_2(self): # let y1 < y0 bbox = BoundingBox(0.2, 0.6, 0.5, 0.3, clip_values=False) - def test_bounding_box_oob_arguments(self): with self.assertRaises(ValueError): bbox = BoundingBox(20, 30, 100, 200) diff --git a/tests/active_learning/test_ObjectDetectionOutput.py b/tests/active_learning/test_ObjectDetectionOutput.py index 772585171..260ab963d 100644 --- a/tests/active_learning/test_ObjectDetectionOutput.py +++ b/tests/active_learning/test_ObjectDetectionOutput.py @@ -7,86 +7,80 @@ class TestObjectDetectionOutput(unittest.TestCase): - def setUp(self): self.dummy_data = [ { - 'boxes': [ - [14, 16, 52, 85], - [58, 23, 124, 49] - ], - 'object_probabilities': [ - 0.57573, - 0.988 - ], - 'class_probabilities': [ - [0.7, 0.2, 0.1], - [0.4, 0.5, 0.1] - ], - 'labels': [ + "boxes": [[14, 16, 52, 85], [58, 23, 124, 49]], + "object_probabilities": [0.57573, 0.988], + "class_probabilities": [[0.7, 0.2, 0.1], [0.4, 0.5, 0.1]], + "labels": [ 0, 1, - ] + ], }, { - 'boxes': [ + "boxes": [ [14, 16, 52, 85], ], - 'object_probabilities': [ + "object_probabilities": [ 0.1024, ], - 'class_probabilities': [ + 
"class_probabilities": [ [0.5, 0.41, 0.09], ], - 'labels': [0] + "labels": [0], }, { - 'boxes': [ + "boxes": [ [14, 16, 52, 85], ], - 'object_probabilities': [ + "object_probabilities": [ 1.0, ], - 'class_probabilities': [ + "class_probabilities": [ [0.0, 1.0, 0.0], ], - 'labels': [4] + "labels": [4], }, { - 'boxes': [], - 'object_probabilities': [], - 'class_probabilities': [], - 'labels': [] - } + "boxes": [], + "object_probabilities": [], + "class_probabilities": [], + "labels": [], + }, ] # convert bounding boxes W, H = 128, 128 for data in self.dummy_data: - for i, box in enumerate(data['boxes']): + for i, box in enumerate(data["boxes"]): x0 = box[0] / W y0 = box[1] / H x1 = box[2] / W y1 = box[3] / H - data['boxes'][i] = BoundingBox(x0, y0, x1, y1) + data["boxes"][i] = BoundingBox(x0, y0, x1, y1) def test_object_detection_output(self): - outputs_1 = [] outputs_2 = [] for i, data in enumerate(self.dummy_data): output = ObjectDetectionOutput( - data['boxes'], - data['object_probabilities'], - data['class_probabilities'], + data["boxes"], + data["object_probabilities"], + data["class_probabilities"], ) outputs_1.append(output) - scores = [o * max(c) for o, c in zip(data['object_probabilities'], data['class_probabilities'])] + scores = [ + o * max(c) + for o, c in zip( + data["object_probabilities"], data["class_probabilities"] + ) + ] output_2 = ObjectDetectionOutput.from_scores( - data['boxes'], + data["boxes"], scores, - data['labels'], + data["labels"], ) for output_1, output_2 in zip(outputs_1, outputs_2): @@ -95,54 +89,39 @@ def test_object_detection_output(self): for x, y in zip(output_1.scores, output_2.scores): self.assertEqual(x, y) - def test_object_detection_output_from_scores(self): outputs = [] for i, data in enumerate(self.dummy_data): output = ObjectDetectionOutput.from_scores( - data['boxes'], - data['object_probabilities'], - data['labels'], + data["boxes"], + data["object_probabilities"], + data["labels"], ) outputs.append(output) - + for output in outputs: for class_probs in output.class_probabilities: self.assertEqual(np.sum(class_probs), 1.0) - def test_object_detection_output_illegal_args(self): - with self.assertRaises(ValueError): # score > 1 - ObjectDetectionOutput.from_scores( - [BoundingBox(0, 0, 1, 1)], - [1.1], - [0] - ) + ObjectDetectionOutput.from_scores([BoundingBox(0, 0, 1, 1)], [1.1], [0]) with self.assertRaises(ValueError): # score < 0 - ObjectDetectionOutput.from_scores( - [BoundingBox(0, 0, 1, 1)], - [-1.], - [1] - ) + ObjectDetectionOutput.from_scores([BoundingBox(0, 0, 1, 1)], [-1.0], [1]) with self.assertRaises(ValueError): # different length - ObjectDetectionOutput( - [BoundingBox(0, 0, 1, 1)], - [0.5, 0.2], - [1, 2] - ) + ObjectDetectionOutput([BoundingBox(0, 0, 1, 1)], [0.5, 0.2], [1, 2]) with self.assertRaises(ValueError): # string labels ObjectDetectionOutput.from_scores( [BoundingBox(0, 0, 1, 1)], [1.1], - ['hello'], + ["hello"], ) with self.assertRaises(ValueError): @@ -154,6 +133,7 @@ def test_object_detection_output_illegal_args(self): scores=[1.0], ) + def test_ObjectDetectionOutput__with_scores(): output = ObjectDetectionOutput( boxes=[BoundingBox(0, 0, 1, 1)], @@ -163,6 +143,7 @@ def test_ObjectDetectionOutput__with_scores(): ) assert output.scores == [0.3] + def test_ObjectDetectionOutput__without_scores(): output = ObjectDetectionOutput( boxes=[BoundingBox(0, 0, 1, 1)], diff --git a/tests/active_learning/test_ScorerClassification.py b/tests/active_learning/test_ScorerClassification.py index f332b95b3..70d6e7801 100644 --- 
a/tests/active_learning/test_ScorerClassification.py +++ b/tests/active_learning/test_ScorerClassification.py @@ -1,17 +1,22 @@ import unittest + import numpy as np -from lightly.active_learning.scorers.classification import ScorerClassification, _entropy +from lightly.active_learning.scorers.classification import ( + ScorerClassification, + _entropy, +) class TestScorerClassification(unittest.TestCase): - def test_score_calculation_random(self): n_samples = 10000 n_classes = 10 np.random.seed(42) predictions = np.random.rand(n_samples, n_classes) - predictions_normalized = predictions / np.sum(predictions, axis=1)[:, np.newaxis] + predictions_normalized = ( + predictions / np.sum(predictions, axis=1)[:, np.newaxis] + ) model_output = predictions_normalized scorer = ScorerClassification(model_output) scores = scorer.calculate_scores() @@ -25,34 +30,40 @@ def test_score_calculation_random(self): self.assertEqual(type(score), np.ndarray) def test_score_calculation_specific(self): - model_output = [ - [0.7, 0.2, 0.1], - [0.4, 0.5, 0.1] - ] + model_output = [[0.7, 0.2, 0.1], [0.4, 0.5, 0.1]] model_output = np.array(model_output) scorer = ScorerClassification(model_output) scores = scorer.calculate_scores() - self.assertListEqual(list(scores["uncertainty_least_confidence"]), - [(1 - 0.7) / (1 - 1. / 3.), (1 - 0.5) / (1 - 1. / 3.)]) - self.assertListEqual(list(scores["uncertainty_margin"]), [1 - (0.7 - 0.2), 1 - (0.5 - 0.4)]) - for val1, val2 in zip(scores["uncertainty_entropy"], _entropy(model_output) / np.log2(3)): + self.assertListEqual( + list(scores["uncertainty_least_confidence"]), + [(1 - 0.7) / (1 - 1.0 / 3.0), (1 - 0.5) / (1 - 1.0 / 3.0)], + ) + self.assertListEqual( + list(scores["uncertainty_margin"]), [1 - (0.7 - 0.2), 1 - (0.5 - 0.4)] + ) + for val1, val2 in zip( + scores["uncertainty_entropy"], _entropy(model_output) / np.log2(3) + ): self.assertAlmostEqual(val1, val2, places=8) def test_score_calculation_binary(self): - model_output = [ - [0.7], - [0.4] - ] + model_output = [[0.7], [0.4]] model_output = np.array(model_output) scorer = ScorerClassification(model_output) scores = scorer.calculate_scores() - self.assertListEqual(list(scores["uncertainty_least_confidence"]), - [(1 - 0.7) / (1 - 1. / 2.), (1 - 0.6) / (1 - 1. / 2.)]) - self.assertListEqual(list(scores["uncertainty_margin"]), [1 - (0.7 - 0.3), 1 - (0.6 - 0.4)]) - model_output = np.concatenate([model_output, 1-model_output], axis=1) - for val1, val2 in zip(scores["uncertainty_entropy"], _entropy(model_output) / np.log2(2)): + self.assertListEqual( + list(scores["uncertainty_least_confidence"]), + [(1 - 0.7) / (1 - 1.0 / 2.0), (1 - 0.6) / (1 - 1.0 / 2.0)], + ) + self.assertListEqual( + list(scores["uncertainty_margin"]), [1 - (0.7 - 0.3), 1 - (0.6 - 0.4)] + ) + model_output = np.concatenate([model_output, 1 - model_output], axis=1) + for val1, val2 in zip( + scores["uncertainty_entropy"], _entropy(model_output) / np.log2(2) + ): self.assertAlmostEqual(val1, val2, places=8) def test_scorer_classification_empty_model_output(self): @@ -61,13 +72,13 @@ def test_scorer_classification_empty_model_output(self): self.assertEqual(set(scores.keys()), set(ScorerClassification.score_names())) def test_scorer_classification_variable_model_output_dimension(self): - for num_samples in range(5): for num_classes in range(5): - - with self.subTest(msg=f"model_output.shape = ({num_samples},{num_classes})"): + with self.subTest( + msg=f"model_output.shape = ({num_samples},{num_classes})" + ): if num_samples > 0: - preds = [1. 
/ num_samples] * num_classes + preds = [1.0 / num_samples] * num_classes else: preds = [] model_output = [preds] * num_samples @@ -78,13 +89,14 @@ def test_scorer_classification_variable_model_output_dimension(self): else: scorer = ScorerClassification(model_output=model_output) scores = scorer.calculate_scores() - self.assertEqual(set(scores.keys()), set(ScorerClassification.score_names())) + self.assertEqual( + set(scores.keys()), set(ScorerClassification.score_names()) + ) for score_values in scores.values(): self.assertEqual(len(score_values), len(model_output)) self.assertEqual(type(score_values), np.ndarray) def test_scorer_classification_variable_model_output_tensor_order(self): - for tensor_order in range(1, 5): model_output = np.ndarray((3,) * tensor_order) with self.subTest(msg=f"model_output.shape = {model_output.shape}"): diff --git a/tests/active_learning/test_ScorerKeypointDetection.py b/tests/active_learning/test_ScorerKeypointDetection.py index ecd93b6b6..1108389d6 100644 --- a/tests/active_learning/test_ScorerKeypointDetection.py +++ b/tests/active_learning/test_ScorerKeypointDetection.py @@ -2,29 +2,22 @@ import numpy as np -from lightly.active_learning.scorers.keypoint_detection import \ - ScorerKeypointDetection - -from lightly.active_learning.utils.keypoint_predictions import \ - KeypointInstancePrediction, KeypointPrediction +from lightly.active_learning.scorers.keypoint_detection import ScorerKeypointDetection +from lightly.active_learning.utils.keypoint_predictions import ( + KeypointInstancePrediction, + KeypointPrediction, +) class TestScorerKeypointDetection(unittest.TestCase): - def setUp(self) -> None: predictions_over_images = [ [ - { - "keypoints": [123., 456., 0.1, 565., 32., 0.2] - }, { - "keypoints": [342., 432., 0.3, 43., 2., 0.4] - } - ], [ - { - "keypoints": [23., 43., 0.5, 43., 2., 0.6] - } - ], [ - ] + {"keypoints": [123.0, 456.0, 0.1, 565.0, 32.0, 0.2]}, + {"keypoints": [342.0, 432.0, 0.3, 43.0, 2.0, 0.4]}, + ], + [{"keypoints": [23.0, 43.0, 0.5, 43.0, 2.0, 0.6]}], + [], ] model_output = [] for predictions_one_image in predictions_over_images: @@ -40,12 +33,13 @@ def setUp(self) -> None: self.expected_scores_mean_uncertainty = np.asarray([0.75, 0.45, 0]) def test_scorer_calculate_scores(self): - scorer = ScorerKeypointDetection(self.model_output) scores = scorer.calculate_scores() scores_mean_uncertainty = scores["mean_uncertainty"] - np.testing.assert_allclose(scores_mean_uncertainty, self.expected_scores_mean_uncertainty) + np.testing.assert_allclose( + scores_mean_uncertainty, self.expected_scores_mean_uncertainty + ) def test_scorer_get_score_names(self): scorer_1 = ScorerKeypointDetection(self.model_output) @@ -55,35 +49,35 @@ def test_scorer_get_score_names(self): def test_keypoint_instance_prediction_creation(self): with self.subTest("create correct"): - KeypointInstancePrediction([456., 32., 0.3]) + KeypointInstancePrediction([456.0, 32.0, 0.3]) with self.subTest("create correct with object_id"): - KeypointInstancePrediction([456., 32., 0.3], 3) + KeypointInstancePrediction([456.0, 32.0, 0.3], 3) with self.subTest("create correct with object_id and score"): - KeypointInstancePrediction([456., 32., 0.3], 3, 0.3) + KeypointInstancePrediction([456.0, 32.0, 0.3], 3, 0.3) with self.subTest("create correct with score"): - KeypointInstancePrediction([456., 32., 0.3], score = 0.3) + KeypointInstancePrediction([456.0, 32.0, 0.3], score=0.3) with self.subTest("create wrong keypoints format"): with self.assertRaises(ValueError): - 
KeypointInstancePrediction([456., 32., 0.3, 1], 3) + KeypointInstancePrediction([456.0, 32.0, 0.3, 1], 3) with self.subTest("create confidence < 0"): with self.assertRaises(ValueError): - KeypointInstancePrediction([456., 32., -0.1], 3) + KeypointInstancePrediction([456.0, 32.0, -0.1], 3) with self.subTest("create confidence > 1"): with self.assertRaises(ValueError): - KeypointInstancePrediction([456., 32., 1.5], 3) + KeypointInstancePrediction([456.0, 32.0, 1.5], 3) with self.subTest("create from dict"): dict_ = { "category_id": 3, "keypoints": [423, 432, 0.4, 231, 655, 0.3], - "score": -1.9 + "score": -1.9, } KeypointInstancePrediction.from_dict(dict_) def test_keypoint_prediction_creation(self): with self.subTest("create from KeypointInstancePrediction"): keypoints = [ - KeypointInstancePrediction([456., 32., 0.3]), - KeypointInstancePrediction([456., 32., 0.3], 3, 0.3) + KeypointInstancePrediction([456.0, 32.0, 0.3]), + KeypointInstancePrediction([456.0, 32.0, 0.3], 3, 0.3), ] KeypointPrediction(keypoints) with self.subTest("create from dicts"): @@ -91,7 +85,7 @@ def test_keypoint_prediction_creation(self): { "category_id": 3, "keypoints": [423, 432, 0.4, 231, 655, 0.3], - "score": -1.9 + "score": -1.9, } ] KeypointPrediction.from_dicts(dicts) @@ -105,7 +99,3 @@ def test_keypoint_prediction_creation(self): } ]""" KeypointPrediction.from_json_string(json_str) - - - - diff --git a/tests/active_learning/test_ScorerObjectDetection.py b/tests/active_learning/test_ScorerObjectDetection.py index 575807caf..3791d2f7f 100644 --- a/tests/active_learning/test_ScorerObjectDetection.py +++ b/tests/active_learning/test_ScorerObjectDetection.py @@ -3,72 +3,59 @@ import numpy as np from lightly.active_learning.scorers.classification import _entropy +from lightly.active_learning.scorers.detection import ScorerObjectDetection from lightly.active_learning.utils.bounding_box import BoundingBox from lightly.active_learning.utils.object_detection_output import ObjectDetectionOutput -from lightly.active_learning.scorers.detection import ScorerObjectDetection - class TestScorerObjectDetection(unittest.TestCase): - def setUp(self): self.dummy_data = [ { - 'boxes': [ - [14, 16, 52, 85], - [58, 23, 124, 49] - ], - 'object_probabilities': [ - 0.57573, - 0.988 - ], - 'class_probabilities': [ - [0.7, 0.2, 0.1], - [0.4, 0.5, 0.1] - ], - 'labels': [ + "boxes": [[14, 16, 52, 85], [58, 23, 124, 49]], + "object_probabilities": [0.57573, 0.988], + "class_probabilities": [[0.7, 0.2, 0.1], [0.4, 0.5, 0.1]], + "labels": [ 0, 1, - ] + ], }, { - 'boxes': [ + "boxes": [ [14, 16, 52, 85], ], - 'object_probabilities': [ + "object_probabilities": [ 0.1024, ], - 'class_probabilities': [ + "class_probabilities": [ [0.5, 0.41, 0.09], ], - 'labels': [0] + "labels": [0], }, { - 'boxes': [], - 'object_probabilities': [], - 'class_probabilities': [], - 'labels': [] - } + "boxes": [], + "object_probabilities": [], + "class_probabilities": [], + "labels": [], + }, ] - def test_object_detection_scorer(self): - # convert bounding boxes W, H = 128, 128 for data in self.dummy_data: - for i, box in enumerate(data['boxes']): + for i, box in enumerate(data["boxes"]): x0 = box[0] / W y0 = box[1] / H x1 = box[2] / W y1 = box[3] / H - data['boxes'][i] = BoundingBox(x0, y0, x1, y1) + data["boxes"][i] = BoundingBox(x0, y0, x1, y1) for i, data in enumerate(self.dummy_data): self.dummy_data[i] = ObjectDetectionOutput( - data['boxes'], - data['object_probabilities'], - data['class_probabilities'], + data["boxes"], + data["object_probabilities"], + 
data["class_probabilities"], ) scorer = ScorerObjectDetection(self.dummy_data) @@ -84,79 +71,75 @@ def test_object_detection_scorer(self): for key, val in scores.items(): self.assertEqual(type(scores[key]), type(np.array([]))) - res = scores['object_frequency'] + res = scores["object_frequency"] self.assertEqual(len(res), len(self.dummy_data)) self.assertListEqual(res.tolist(), [1.0, 0.95, 0.9]) - res = scores['objectness_least_confidence'] + res = scores["objectness_least_confidence"] self.assertEqual(len(res), len(self.dummy_data)) - self.assertListEqual(res.tolist(), [0.5514945, 0.9488, 0.]) + self.assertListEqual(res.tolist(), [0.5514945, 0.9488, 0.0]) for score_name, score in scores.items(): if "classification" in score_name: self.assertEqual(len(res), len(self.dummy_data)) if score_name == "classification_uncertainty_least_confidence": - self.assertListEqual(list(score), [max(1 - 0.7, 1 - 0.5)/(1 - 1/3), (1 - 0.5)/(1 - 1/3), 0]) + self.assertListEqual( + list(score), + [max(1 - 0.7, 1 - 0.5) / (1 - 1 / 3), (1 - 0.5) / (1 - 1 / 3), 0], + ) elif score_name == "classification_uncertainty_margin": - self.assertListEqual(list(score), [max(1 - (0.7 - 0.2), 1 - (0.5 - 0.4)), 1 - (0.5 - 0.41), 0]) + self.assertListEqual( + list(score), + [max(1 - (0.7 - 0.2), 1 - (0.5 - 0.4)), 1 - (0.5 - 0.41), 0], + ) elif score_name == "classification_uncertainty_entropy": - entropies_0 = _entropy(np.array(self.dummy_data[0].class_probabilities))/np.log2(3) - entropies_1 = _entropy(np.array(self.dummy_data[1].class_probabilities))/np.log2(3) + entropies_0 = _entropy( + np.array(self.dummy_data[0].class_probabilities) + ) / np.log2(3) + entropies_1 = _entropy( + np.array(self.dummy_data[1].class_probabilities) + ) / np.log2(3) score_target = [float(max(entropies_0)), float(max(entropies_1)), 0] for val1, val2 in zip(score, score_target): self.assertAlmostEqual(val1, val2, places=8) - def test_object_detection_scorer_config(self): - # convert bounding boxes W, H = 128, 128 for data in self.dummy_data: - for i, box in enumerate(data['boxes']): + for i, box in enumerate(data["boxes"]): x0 = box[0] / W y0 = box[1] / H x1 = box[2] / W y1 = box[3] / H - data['boxes'][i] = BoundingBox(x0, y0, x1, y1) + data["boxes"][i] = BoundingBox(x0, y0, x1, y1) for i, data in enumerate(self.dummy_data): self.dummy_data[i] = ObjectDetectionOutput( - data['boxes'], - data['object_probabilities'], - data['class_probabilities'], + data["boxes"], + data["object_probabilities"], + data["class_probabilities"], ) # check for default config scorer = ScorerObjectDetection(self.dummy_data) scores = scorer.calculate_scores() - expected_default_config = { - 'frequency_penalty': 0.25, - 'min_score': 0.9 - } + expected_default_config = {"frequency_penalty": 0.25, "min_score": 0.9} self.assertDictEqual(scorer.config, expected_default_config) # check for config override - new_config = { - 'frequency_penalty': 0.55, - 'min_score': 0.6 - } + new_config = {"frequency_penalty": 0.55, "min_score": 0.6} scorer = ScorerObjectDetection(self.dummy_data, config=new_config) scores = scorer.calculate_scores() self.assertDictEqual(scorer.config, new_config) # check for invalid key passed - new_config = { - 'frequenci_penalty': 0.55, - 'minimum_score': 0.6 - } + new_config = {"frequenci_penalty": 0.55, "minimum_score": 0.6} with self.assertRaises(KeyError): scorer = ScorerObjectDetection(self.dummy_data, config=new_config) # check for wrong value passed - new_config = { - 'frequency_penalty': 'test', - 'min_score': 1.6 - } + new_config = 
{"frequency_penalty": "test", "min_score": 1.6} with self.assertRaises(ValueError): scorer = ScorerObjectDetection(self.dummy_data, config=new_config) @@ -164,18 +147,16 @@ def test_object_detection_from_class_labels(self): # convert bounding boxes W, H = 128, 128 for data in self.dummy_data: - for i, box in enumerate(data['boxes']): + for i, box in enumerate(data["boxes"]): x0 = box[0] / W y0 = box[1] / H x1 = box[2] / W y1 = box[3] / H - data['boxes'][i] = BoundingBox(x0, y0, x1, y1) + data["boxes"][i] = BoundingBox(x0, y0, x1, y1) for i, data in enumerate(self.dummy_data): self.dummy_data[i] = ObjectDetectionOutput.from_scores( - data['boxes'], - data['object_probabilities'], - data['labels'] + data["boxes"], data["object_probabilities"], data["labels"] ) # check for default config @@ -186,9 +167,7 @@ def test_object_detection_from_class_labels(self): # make sure the max entry of a score is not 0.0 for key, val in scores.items(): self.assertNotEqual(max(val), 0.0) - + # make sure all scores are numpy arrays for key, val in scores.items(): self.assertEqual(type(scores[key]), type(np.array([]))) - - diff --git a/tests/active_learning/test_ScorerSemanticSegmentation.py b/tests/active_learning/test_ScorerSemanticSegmentation.py index 529d22fb4..e0814a4c9 100644 --- a/tests/active_learning/test_ScorerSemanticSegmentation.py +++ b/tests/active_learning/test_ScorerSemanticSegmentation.py @@ -3,13 +3,14 @@ import numpy as np import lightly -from lightly.active_learning.scorers import ScorerSemanticSegmentation -from lightly.active_learning.scorers import ScorerClassification +from lightly.active_learning.scorers import ( + ScorerClassification, + ScorerSemanticSegmentation, +) -class TestScorerSemanticSegmentation(unittest.TestCase): +class TestScorerSemanticSegmentation(unittest.TestCase): def setUp(self): - self.N = 100 self.W, self.H, self.C = 32, 32, 10 @@ -20,18 +21,28 @@ def setUp(self): self.dummy_data_width_1 = np.random.randn(self.N * self.H, self.C) self.dummy_data_width_1 /= np.sum(self.dummy_data_width_1, axis=-1)[:, None] - self.dummy_data_width_1 = self.dummy_data_width_1.reshape(self.N, 1, self.H, self.C) + self.dummy_data_width_1 = self.dummy_data_width_1.reshape( + self.N, 1, self.H, self.C + ) self.dummy_data_height_1 = np.random.randn(self.N * self.W, self.C) self.dummy_data_height_1 /= np.sum(self.dummy_data_height_1, axis=-1)[:, None] - self.dummy_data_height_1 = self.dummy_data_height_1.reshape(self.N, self.W, 1, self.C) + self.dummy_data_height_1 = self.dummy_data_height_1.reshape( + self.N, self.W, 1, self.C + ) self.dummy_data_width_height_1 = np.random.randn(self.N, self.C) - self.dummy_data_width_height_1 /= np.sum(self.dummy_data_width_height_1, axis=-1)[:, None] - self.dummy_data_width_height_1 = self.dummy_data_width_height_1.reshape(self.N, 1, 1, self.C) + self.dummy_data_width_height_1 /= np.sum( + self.dummy_data_width_height_1, axis=-1 + )[:, None] + self.dummy_data_width_height_1 = self.dummy_data_width_height_1.reshape( + self.N, 1, 1, self.C + ) self.dummy_data_classes_1 = np.random.randn(self.N * self.W * self.H, 1) - self.dummy_data_classes_1 = self.dummy_data_classes_1.reshape(self.N, self.W, self.H, 1) + self.dummy_data_classes_1 = self.dummy_data_classes_1.reshape( + self.N, self.W, self.H, 1 + ) # the following data should always fail self.dummy_data_valerr = np.random.randn(self.N, self.C) @@ -43,69 +54,61 @@ def dummy_data_generator(self): yield prediction def test_scorer_default_case(self): - scorer = ScorerSemanticSegmentation(self.dummy_data) 
scores = scorer.calculate_scores() for score_name, score_array in scores.items(): self.assertTrue(isinstance(score_array, np.ndarray)) - self.assertEqual(score_array.shape, (self.N, )) + self.assertEqual(score_array.shape, (self.N,)) def test_scorer_width_1_case(self): - scorer = ScorerSemanticSegmentation(self.dummy_data_width_1) scores = scorer.calculate_scores() for score_name, score_array in scores.items(): self.assertTrue(isinstance(score_array, np.ndarray)) - self.assertEqual(score_array.shape, (self.N, )) + self.assertEqual(score_array.shape, (self.N,)) def test_scorer_height_1_case(self): - scorer = ScorerSemanticSegmentation(self.dummy_data_height_1) scores = scorer.calculate_scores() for score_name, score_array in scores.items(): self.assertTrue(isinstance(score_array, np.ndarray)) - self.assertEqual(score_array.shape, (self.N, )) + self.assertEqual(score_array.shape, (self.N,)) def test_scorer_width_height_1_case(self): - scorer = ScorerSemanticSegmentation(self.dummy_data_width_height_1) scores = scorer.calculate_scores() for score_name, score_array in scores.items(): self.assertTrue(isinstance(score_array, np.ndarray)) - self.assertEqual(score_array.shape, (self.N, )) + self.assertEqual(score_array.shape, (self.N,)) def test_scorer_classes_1_case(self): - scorer = ScorerSemanticSegmentation(self.dummy_data_classes_1) scores = scorer.calculate_scores() for score_name, score_array in scores.items(): self.assertTrue(isinstance(score_array, np.ndarray)) - self.assertEqual(score_array.shape, (self.N, )) + self.assertEqual(score_array.shape, (self.N,)) def test_scorer_generator_case(self): - scorer = ScorerSemanticSegmentation(self.dummy_data_generator()) scores = scorer.calculate_scores() for score_name, score_array in scores.items(): self.assertTrue(isinstance(score_array, np.ndarray)) - self.assertEqual(score_array.shape, (self.N, )) + self.assertEqual(score_array.shape, (self.N,)) def test_wrong_input_shape(self): - scorer = ScorerSemanticSegmentation(self.dummy_data_valerr) with self.assertRaises(ValueError): scorer.calculate_scores() def test_scorer_semseg__score_names(self): - - model_output = [np.empty(shape=(1,1,1))] + model_output = [np.empty(shape=(1, 1, 1))] scorer = ScorerSemanticSegmentation(model_output=model_output) scores = scorer.calculate_scores() assert sorted(scores.keys()) == sorted(ScorerSemanticSegmentation.score_names()) diff --git a/tests/active_learning/test_active_learning_agent.py b/tests/active_learning/test_active_learning_agent.py index 0358b3566..67c54427d 100644 --- a/tests/active_learning/test_active_learning_agent.py +++ b/tests/active_learning/test_active_learning_agent.py @@ -13,12 +13,23 @@ def test_agent(self): self.api_workflow_client.embedding_id = "embedding_id_xyz" agent_0 = ActiveLearningAgent(self.api_workflow_client) - agent_1 = ActiveLearningAgent(self.api_workflow_client, query_tag_name="query_tag_name_xyz") - agent_2 = ActiveLearningAgent(self.api_workflow_client, query_tag_name="query_tag_name_xyz", - preselected_tag_name="preselected_tag_name_xyz") - agent_3 = ActiveLearningAgent(self.api_workflow_client, preselected_tag_name="preselected_tag_name_xyz") + agent_1 = ActiveLearningAgent( + self.api_workflow_client, query_tag_name="query_tag_name_xyz" + ) + agent_2 = ActiveLearningAgent( + self.api_workflow_client, + query_tag_name="query_tag_name_xyz", + preselected_tag_name="preselected_tag_name_xyz", + ) + agent_3 = ActiveLearningAgent( + self.api_workflow_client, preselected_tag_name="preselected_tag_name_xyz" + ) - for method in 
[SamplingMethod.CORAL, SamplingMethod.CORESET, SamplingMethod.RANDOM]: + for method in [ + SamplingMethod.CORAL, + SamplingMethod.CORESET, + SamplingMethod.RANDOM, + ]: for agent in [agent_0, agent_1, agent_2, agent_3]: for batch_size in [2, 6]: n_old_labeled = len(agent.labeled_set) @@ -26,36 +37,54 @@ def test_agent(self): n_samples = len(agent.labeled_set) + batch_size if method == SamplingMethod.CORAL and len(agent.labeled_set) == 0: - selection_config = SelectionConfig(n_samples=n_samples, method=SamplingMethod.CORESET) + selection_config = SelectionConfig( + n_samples=n_samples, method=SamplingMethod.CORESET + ) else: - selection_config = SelectionConfig(n_samples=n_samples, method=method) + selection_config = SelectionConfig( + n_samples=n_samples, method=method + ) if selection_config.method == SamplingMethod.CORAL: - predictions = np.random.rand(len(agent.query_set), 10).astype(np.float32) - predictions_normalized = predictions / np.sum(predictions, axis=1)[:, np.newaxis] + predictions = np.random.rand(len(agent.query_set), 10).astype( + np.float32 + ) + predictions_normalized = ( + predictions / np.sum(predictions, axis=1)[:, np.newaxis] + ) al_scorer = ScorerClassification(predictions_normalized) - agent.query(selection_config=selection_config, al_scorer=al_scorer) + agent.query( + selection_config=selection_config, al_scorer=al_scorer + ) else: selection_config = SelectionConfig(n_samples=n_samples) agent.query(selection_config=selection_config) - + labeled_set, added_set = agent.labeled_set, agent.added_set self.assertEqual(n_old_labeled + len(added_set), len(labeled_set)) self.assertTrue(set(added_set).issubset(labeled_set)) - self.assertEqual(len(list(set(agent.labeled_set) & set(agent.unlabeled_set))), 0) - self.assertEqual(n_old_unlabeled - len(added_set), len(agent.unlabeled_set)) + self.assertEqual( + len(list(set(agent.labeled_set) & set(agent.unlabeled_set))), 0 + ) + self.assertEqual( + n_old_unlabeled - len(added_set), len(agent.unlabeled_set) + ) def test_agent_wrong_number_of_scores(self): self.api_workflow_client.embedding_id = "embedding_id_xyz" - agent = ActiveLearningAgent(self.api_workflow_client, preselected_tag_name="preselected_tag_name_xyz") + agent = ActiveLearningAgent( + self.api_workflow_client, preselected_tag_name="preselected_tag_name_xyz" + ) method = SamplingMethod.CORAL n_samples = len(agent.labeled_set) + 2 n_predictions = len(agent.query_set) - 3 # the -3 should cause an error predictions = np.random.rand(n_predictions, 10).astype(np.float32) - predictions_normalized = predictions / np.sum(predictions, axis=1)[:, np.newaxis] + predictions_normalized = ( + predictions / np.sum(predictions, axis=1)[:, np.newaxis] + ) al_scorer = ScorerClassification(predictions_normalized) selection_config = SelectionConfig(n_samples=n_samples, method=method) @@ -68,14 +97,22 @@ def test_agent_with_generator(self): height = 32 no_classes = 13 - agent = ActiveLearningAgent(self.api_workflow_client, preselected_tag_name="preselected_tag_name_xyz") + agent = ActiveLearningAgent( + self.api_workflow_client, preselected_tag_name="preselected_tag_name_xyz" + ) method = SamplingMethod.CORAL n_samples = len(agent.labeled_set) + 2 n_predictions = len(agent.query_set) - predictions = np.random.rand(n_predictions, no_classes, width, height).astype(np.float32) - predictions_normalized = predictions / np.sum(predictions, axis=1)[:, np.newaxis] - predictions_generator = (predictions_normalized[i] for i in range(n_predictions)) + predictions = np.random.rand(n_predictions, 
no_classes, width, height).astype( + np.float32 + ) + predictions_normalized = ( + predictions / np.sum(predictions, axis=1)[:, np.newaxis] + ) + predictions_generator = ( + predictions_normalized[i] for i in range(n_predictions) + ) al_scorer = ScorerSemanticSegmentation(predictions_generator) selection_config = SelectionConfig(n_samples=n_samples, method=method) @@ -86,11 +123,9 @@ def test_agent_with_generator(self): agent.upload_scores(al_scorer) def test_agent_added_set_before_query(self): - self.api_workflow_client.embedding_id = "embedding_id_xyz" agent = ActiveLearningAgent( - self.api_workflow_client, - preselected_tag_name="preselected_tag_name_xyz" + self.api_workflow_client, preselected_tag_name="preselected_tag_name_xyz" ) agent.query_set @@ -100,7 +135,6 @@ def test_agent_added_set_before_query(self): agent.added_set def test_agent_query_too_few(self): - self.api_workflow_client.embedding_id = "embedding_id_xyz" agent = ActiveLearningAgent( self.api_workflow_client, @@ -108,10 +142,7 @@ def test_agent_query_too_few(self): ) # sample 0 samples - selection_config = SelectionConfig( - n_samples=0, - method=SamplingMethod.RANDOM - ) + selection_config = SelectionConfig(n_samples=0, method=SamplingMethod.RANDOM) agent.query(selection_config) @@ -124,26 +155,26 @@ def test_agent_only_upload_scores(self): n_predictions = len(agent.query_set) predictions = np.random.rand(n_predictions, 10).astype(np.float32) - predictions_normalized = predictions / np.sum(predictions, axis=1)[:, np.newaxis] + predictions_normalized = ( + predictions / np.sum(predictions, axis=1)[:, np.newaxis] + ) al_scorer = ScorerClassification(predictions_normalized) agent.upload_scores(al_scorer) def test_agent_without_embedding_id(self): agent = ActiveLearningAgent( - self.api_workflow_client, - preselected_tag_name="preselected_tag_name_xyz" + self.api_workflow_client, preselected_tag_name="preselected_tag_name_xyz" ) method = SamplingMethod.CORAL n_samples = len(agent.labeled_set) + 2 n_predictions = len(agent.query_set) predictions = np.random.rand(n_predictions, 10).astype(np.float32) - predictions_normalized = predictions / np.sum(predictions, axis=1)[:, np.newaxis] + predictions_normalized = ( + predictions / np.sum(predictions, axis=1)[:, np.newaxis] + ) al_scorer = ScorerClassification(predictions_normalized) selection_config = SelectionConfig(n_samples=n_samples, method=method) agent.query(selection_config=selection_config, al_scorer=al_scorer) - - - diff --git a/tests/api/benchmark_video_download.py b/tests/api/benchmark_video_download.py index 0927e0fe4..107b51eb9 100644 --- a/tests/api/benchmark_video_download.py +++ b/tests/api/benchmark_video_download.py @@ -5,8 +5,12 @@ import numpy as np from tqdm import tqdm -from lightly.api.download import download_video_frames_at_timestamps, \ - download_all_video_frames, download_video_frame +from lightly.api.download import ( + download_all_video_frames, + download_video_frame, + download_video_frames_at_timestamps, +) + @unittest.skip("Only used for benchmarks") class BenchmarkDownloadVideoFrames(unittest.TestCase): @@ -17,7 +21,7 @@ class BenchmarkDownloadVideoFrames(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.video_url_12min_100mb = "https://mediandr-a.akamaihd.net/progressive/2018/0912/TV-20180912-1628-0000.ln.mp4" - with av.open(cls.video_url_12min_100mb ) as container: + with av.open(cls.video_url_12min_100mb) as container: stream = container.streams.video[0] duration = stream.duration # This video has its timestamps 0-based @@ 
-40,7 +44,9 @@ def test_download_at_timestamps_for_loop(self): frame = download_video_frame(self.video_url_12min_100mb, timestamp) def test_download_at_timestamps(self): - frames = download_video_frames_at_timestamps(self.video_url_12min_100mb, self.timestamps) + frames = download_video_frames_at_timestamps( + self.video_url_12min_100mb, self.timestamps + ) frames = list(tqdm(frames, total=len(self.timestamps))) # Takes long as it downloads the whole video first @@ -51,10 +57,11 @@ def test_download_at_indices_decord(self): See https://github.com/dmlc/decord/issues/199 """ import decord + vr = decord.VideoReader(self.video_url_12min_100mb) - decord.bridge.set_bridge('torch') + decord.bridge.set_bridge("torch") print(f"Took {time.time() - self.start_time}s for creating the video reader.") frames = vr.get_batch(list(range(0, 18000, 18))) def tearDown(self) -> None: - print(f"Took {time.time()-self.start_time}s") \ No newline at end of file + print(f"Took {time.time()-self.start_time}s") diff --git a/tests/api/test_BitMask.py b/tests/api/test_BitMask.py index d19c68236..1d477232c 100644 --- a/tests/api/test_BitMask.py +++ b/tests/api/test_BitMask.py @@ -1,6 +1,6 @@ import unittest from copy import deepcopy -from random import random, seed, randint +from random import randint, random, seed from lightly.api.bitmask import BitMask @@ -8,12 +8,10 @@ class TestBitMask(unittest.TestCase): - - def setup(self, psuccess=1.): + def setup(self, psuccess=1.0): pass def test_get_and_set(self): - mask = BitMask.from_bin("0b11110000") self.assertFalse(mask.get_kth_bit(2)) @@ -36,7 +34,6 @@ def test_bitmask_from_length(self): self.assertEqual(mask.to_bin(), "0b1111") def test_get_and_set_outside_of_range(self): - mask = BitMask.from_bin("0b11110000") self.assertFalse(mask.get_kth_bit(100)) @@ -60,7 +57,6 @@ def test_inverse(self): self.assertEqual(mask.x, y) def test_store_and_retrieve(self): - x = int("0b01010100100100100100100010010100100100101001001010101010", 2) mask = BitMask(x) mask.set_kth_bit(11) @@ -106,7 +102,7 @@ def test_differences(self): self.assert_difference("0b10111", "0b01100", "0b10011") def random_bitstring(self, length: int): - bitsting = '0b' + bitsting = "0b" for i in range(length): bitsting += str(randint(0, 1)) return bitsting @@ -117,12 +113,12 @@ def test_difference_random(self): for string_length in range(1, 100, 10): bitstring_1 = self.random_bitstring(string_length) bitstring_2 = self.random_bitstring(string_length) - target = '0b' + target = "0b" for bit_1, bit_2 in zip(bitstring_1[2:], bitstring_2[2:]): - if bit_1 == '1' and bit_2 == '0': - target += '1' + if bit_1 == "1" and bit_2 == "0": + target += "1" else: - target += '0' + target += "0" self.assert_difference(bitstring_1, bitstring_2, target) def test_operator_minus(self): @@ -132,7 +128,9 @@ def test_operator_minus(self): mask_target = BitMask.from_bin("0b10011") diff = mask_a - mask_b self.assertEqual(diff, mask_target) - self.assertEqual(mask_a_old, mask_a) # make sure the original mask is unchanged. + self.assertEqual( + mask_a_old, mask_a + ) # make sure the original mask is unchanged. 
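# Editor's note: the BitMask semantics exercised in the tests above can be
# mimicked with plain Python ints; a minimal sketch under that assumption
# (this is not the library's implementation):
mask = int("0b11110000", 2)
get_kth_bit = lambda m, k: bool((m >> k) & 1)
set_kth_bit = lambda m, k: m | (1 << k)
assert not get_kth_bit(mask, 2)  # bit 2 of 0b11110000 is unset
assert get_kth_bit(set_kth_bit(mask, 2), 2)  # ...and set after set_kth_bit
difference = lambda a, b: a & ~b  # bits in a that are not in b
assert bin(difference(0b10111, 0b01100)) == "0b10011"  # matches assert_difference above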
def test_equal(self): mask_a = BitMask.from_bin("0b101") @@ -157,20 +155,18 @@ def test_masked_select_from_list(self): self.assertTrue(all([item_ > 0 for item_ in all_ones])) self.assertTrue(all([item_ == 0 for item_ in all_zeros])) - def test_masked_select_from_list_example(self): list_ = [1, 2, 3, 4, 5, 6] - mask = BitMask.from_bin('0b001101') # expected result is [1, 3, 4] + mask = BitMask.from_bin("0b001101") # expected result is [1, 3, 4] selected = mask.masked_select_from_list(list_) self.assertListEqual(selected, [1, 3, 4]) - def test_invert(self): # get random bitstring length = 10 bitstring = self.random_bitstring(10) - - #get inverse + + # get inverse mask = BitMask.from_bin(bitstring) mask.invert(length) inverted = mask.to_bin() @@ -179,14 +175,12 @@ def test_invert(self): inverted = inverted[2:] bitstring = bitstring[2:] for i in range(min(len(bitstring), len(inverted))): - if bitstring[-i - 1] == '0': - self.assertEqual(inverted[-i - 1], '1') + if bitstring[-i - 1] == "0": + self.assertEqual(inverted[-i - 1], "1") else: - self.assertEqual(inverted[-i - 1], '0') - + self.assertEqual(inverted[-i - 1], "0") def test_nonzero_bits(self): - mask = BitMask.from_bin("0b0") indices = [100, 1000, 10_000, 100_000] diff --git a/tests/api/test_download.py b/tests/api/test_download.py index 943a13a00..8921d461f 100644 --- a/tests/api/test_download.py +++ b/tests/api/test_download.py @@ -1,6 +1,6 @@ +import json import os import sys -import json import tempfile import unittest import warnings @@ -8,20 +8,21 @@ from unittest import mock import numpy as np -from PIL import Image import tqdm +from PIL import Image try: import av + AV_AVAILABLE = True except ImportError: AV_AVAILABLE = False -# mock requests module so that files are read from +# mock requests module so that files are read from # disk instead of loading them from a remote url -class MockedRequestsModule: +class MockedRequestsModule: def get(self, url, stream=None, *args, **kwargs): return MockedResponse(url) @@ -29,8 +30,8 @@ class Session: def get(self, url, stream=None, *args, **kwargs): return MockedResponse(url) -class MockedRequestsModulePartialResponse: +class MockedRequestsModulePartialResponse: def get(self, url, stream=None, *args, **kwargs): return MockedResponsePartialStream(url) @@ -40,17 +41,17 @@ def raise_for_status(self): class Session: def get(self, url, stream=None, *args, **kwargs): return MockedResponsePartialStream(url) - -class MockedResponse: + +class MockedResponse: def __init__(self, raw): self._raw = raw @property def raw(self): - # instead of returning the byte stream from the url + # instead of returning the byte stream from the url # we just give back an openend filehandle - return open(self._raw, 'rb') + return open(self._raw, "rb") @property def status_code(self): @@ -60,9 +61,9 @@ def raise_for_status(self): return def json(self): - # instead of returning the byte stream from the url + # instead of returning the byte stream from the url # we just load the json and return the dictionary - with open(self._raw, 'r') as f: + with open(self._raw, "r") as f: return json.load(f) def __enter__(self): @@ -71,15 +72,15 @@ def __enter__(self): def __exit__(self, *args): pass -class MockedResponsePartialStream(MockedResponse): +class MockedResponsePartialStream(MockedResponse): return_partial_stream = True @property def raw(self): # instead of returning the byte stream from the url # we just give back an openend filehandle - stream = open(self._raw, 'rb') + stream = open(self._raw, "rb") if 
self.return_partial_stream: bytes = stream.read() stream_first_part = BytesIO(bytes[:1024]) @@ -92,9 +93,8 @@ def raw(self): import lightly -@mock.patch('lightly.api.download.requests', MockedRequestsModulePartialResponse()) +@mock.patch("lightly.api.download.requests", MockedRequestsModulePartialResponse()) class TestDownloadPartialRespons(unittest.TestCase): - def setUp(self): self._max_retries = lightly.api.utils.RETRY_MAX_RETRIES self._max_backoff = lightly.api.utils.RETRY_MAX_BACKOFF @@ -111,7 +111,7 @@ def test_download_image_half_broken_retry_once(self): lightly.api.utils.RETRY_MAX_RETRIES = 1 original = _pil_image() - with tempfile.NamedTemporaryFile(suffix='.png') as file: + with tempfile.NamedTemporaryFile(suffix=".png") as file: original.save(file.name) # assert that the retry fails with self.assertRaises(RuntimeError) as error: @@ -124,16 +124,14 @@ def test_download_image_half_broken_retry_twice(self): lightly.api.utils.RETRY_MAX_RETRIES = 2 MockedResponse.return_partial_stream = True original = _pil_image() - with tempfile.NamedTemporaryFile(suffix='.png') as file: + with tempfile.NamedTemporaryFile(suffix=".png") as file: original.save(file.name) image = lightly.api.download.download_image(file.name) assert _images_equal(image, original) - -@mock.patch('lightly.api.download.requests', MockedRequestsModule()) +@mock.patch("lightly.api.download.requests", MockedRequestsModule()) class TestDownload(unittest.TestCase): - def setUp(self): self._max_retries = lightly.api.utils.RETRY_MAX_RETRIES self._max_backoff = lightly.api.utils.RETRY_MAX_BACKOFF @@ -148,22 +146,21 @@ def tearDown(self): def test_download_image(self): original = _pil_image() - with tempfile.NamedTemporaryFile(suffix='.png') as file: + with tempfile.NamedTemporaryFile(suffix=".png") as file: original.save(file.name) - for request_kwargs in [None, {'stream': False}]: + for request_kwargs in [None, {"stream": False}]: with self.subTest(request_kwargs=request_kwargs): image = lightly.api.download.download_image( - file.name, - request_kwargs=request_kwargs + file.name, request_kwargs=request_kwargs ) assert _images_equal(image, original) def test_download_prediction(self): original = _json_prediction() - with tempfile.NamedTemporaryFile(suffix='.json', mode="w+") as file: - with open(file.name, 'w') as f: + with tempfile.NamedTemporaryFile(suffix=".json", mode="w+") as file: + with open(file.name, "w") as f: json.dump(original, f) - for request_kwargs in [None, {'stream': False}]: + for request_kwargs in [None, {"stream": False}]: with self.subTest(request_kwargs=request_kwargs): response = lightly.api.download.download_prediction_file( file.name, @@ -174,14 +171,14 @@ def test_download_prediction(self): def test_download_image_with_session(self): session = MockedRequestsModule.Session() original = _pil_image() - with tempfile.NamedTemporaryFile(suffix='.png') as file: + with tempfile.NamedTemporaryFile(suffix=".png") as file: original.save(file.name) image = lightly.api.download.download_image(file.name, session=session) assert _images_equal(image, original) @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_all_video_frames(self): - with tempfile.NamedTemporaryFile(suffix='.avi') as file: + with tempfile.NamedTemporaryFile(suffix=".avi") as file: original = _generate_video(file.name) frames = list(lightly.api.download.download_all_video_frames(file.name)) for frame, orig in zip(frames, original): @@ -189,36 +186,47 @@ def test_download_all_video_frames(self): 
assert _images_equal(frame, orig)
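# Editor's note: the download tests here rely on a lossless round trip: write
# a video to disk, read its frames back, and compare pixel-exact. A minimal
# sketch of that round trip with PyAV, mirroring the _generate_video helper
# that appears later in this diff (the file path and sizes are arbitrary):
import av
import numpy as np

with av.open("/tmp/clip.avi", mode="w") as container:
    stream = container.add_stream("libx264rgb", rate=24)
    stream.width, stream.height = 100, 50
    stream.pix_fmt = "rgb24"
    stream.options["crf"] = "0"  # lossless, so pixel equality holds on re-read
    image = np.zeros((50, 100, 3), dtype=np.uint8)  # one black frame
    frame = av.VideoFrame.from_ndarray(image, format="rgb24")
    for packet in stream.encode(frame):
        container.mux(packet)
    for packet in stream.encode(None):  # flush the encoder
        container.mux(packet)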
@unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_all_video_frames_timeout(self): - with tempfile.NamedTemporaryFile(suffix='.avi') as file: + with tempfile.NamedTemporaryFile(suffix=".avi") as file: _generate_video(file.name) - with self.assertRaisesRegexp(RuntimeError, "Maximum retries exceeded.*av.error.ExitError.*Immediate exit requested.*"): - list(lightly.api.download.download_all_video_frames(file.name, timeout=0)) + with self.assertRaisesRegexp( + RuntimeError, + "Maximum retries exceeded.*av.error.ExitError.*Immediate exit requested.*", + ): + list( + lightly.api.download.download_all_video_frames(file.name, timeout=0) + ) @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_last_video_frame(self): - with tempfile.NamedTemporaryFile(suffix='.avi') as file: + with tempfile.NamedTemporaryFile(suffix=".avi") as file: n_frames = 5 original = _generate_video(file.name, n_frames=n_frames) - timestamps = list(range(1, n_frames+1)) + timestamps = list(range(1, n_frames + 1)) for timestamp in timestamps: with self.subTest(timestamp=timestamp): if timestamp > n_frames: with self.assertRaises(RuntimeError): - frame = lightly.api.download.download_video_frame(file.name, timestamp) + frame = lightly.api.download.download_video_frame( + file.name, timestamp + ) else: - frame = lightly.api.download.download_video_frame(file.name, timestamp) + frame = lightly.api.download.download_video_frame( + file.name, timestamp + ) @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_video_frames_at_timestamps(self): - with tempfile.NamedTemporaryFile(suffix='.avi') as file: + with tempfile.NamedTemporaryFile(suffix=".avi") as file: n_frames = 5 original = _generate_video(file.name, n_frames=n_frames) - original_timestamps = list(range(1, n_frames+1)) + original_timestamps = list(range(1, n_frames + 1)) frame_indices = list(range(2, len(original) - 1, 2)) timestamps = [original_timestamps[i] for i in frame_indices] - frames = list(lightly.api.download.download_video_frames_at_timestamps( - file.name, timestamps - )) + frames = list( + lightly.api.download.download_video_frames_at_timestamps( + file.name, timestamps + ) + ) self.assertEqual(len(frames), len(timestamps)) for frame, timestamp in zip(frames, frame_indices): orig = original[timestamp] @@ -226,54 +234,74 @@ def test_download_video_frames_at_timestamps(self): @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_video_frames_at_timestamps_timeout(self): - with tempfile.NamedTemporaryFile(suffix='.avi') as file: + with tempfile.NamedTemporaryFile(suffix=".avi") as file: n_frames = 5 _generate_video(file.name, n_frames) - with self.assertRaisesRegexp(RuntimeError, "Maximum retries exceeded.*av.error.ExitError.*Immediate exit requested.*"): - list(lightly.api.download.download_video_frames_at_timestamps( - file.name, - timestamps=list(range(1, n_frames + 1)), - timeout=0, - )) + with self.assertRaisesRegexp( + RuntimeError, + "Maximum retries exceeded.*av.error.ExitError.*Immediate exit requested.*", + ): + list( + lightly.api.download.download_video_frames_at_timestamps( + file.name, + timestamps=list(range(1, n_frames + 1)), + timeout=0, + ) + ) @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_video_frames_at_timestamps_wrong_order(self): - with tempfile.NamedTemporaryFile(suffix='.avi') as file: + with tempfile.NamedTemporaryFile(suffix=".avi") as file: original = _generate_video(file.name) timestamps = [2, 1] with 
self.assertRaises(ValueError): - frames = list(lightly.api.download.download_video_frames_at_timestamps( - file.name, timestamps - )) + frames = list( + lightly.api.download.download_video_frames_at_timestamps( + file.name, timestamps + ) + ) @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_video_frames_at_timestamps_emtpy(self): - with tempfile.NamedTemporaryFile(suffix='.avi') as file: - frames = list(lightly.api.download.download_video_frames_at_timestamps( + with tempfile.NamedTemporaryFile(suffix=".avi") as file: + frames = list( + lightly.api.download.download_video_frames_at_timestamps( file.name, timestamps=[] - )) + ) + ) self.assertEqual(len(frames), 0) @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_all_video_frames_restart_throws(self): - with tempfile.NamedTemporaryFile(suffix='.avi') as file: + with tempfile.NamedTemporaryFile(suffix=".avi") as file: original = _generate_video(file.name) with self.assertRaises(ValueError): # timestamp too small - frames = list(lightly.api.download.download_all_video_frames(file.name, timestamp=-1)) + frames = list( + lightly.api.download.download_all_video_frames( + file.name, timestamp=-1 + ) + ) # timestamp too large - frames = list(lightly.api.download.download_all_video_frames(file.name, timestamp=len(original) + 1)) + frames = list( + lightly.api.download.download_all_video_frames( + file.name, timestamp=len(original) + 1 + ) + ) self.assertEqual(len(frames), 0) - @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_all_video_frames_restart_at_0(self): # relevant for restarting if the frame iterator is empty # although it shouldn't be - with tempfile.NamedTemporaryFile(suffix='.avi') as file: + with tempfile.NamedTemporaryFile(suffix=".avi") as file: original = _generate_video(file.name) - frames = list(lightly.api.download.download_all_video_frames(file.name, timestamp=None)) + frames = list( + lightly.api.download.download_all_video_frames( + file.name, timestamp=None + ) + ) for frame, orig in zip(frames, original): assert _images_equal(frame, orig) @@ -282,18 +310,22 @@ def test_download_all_video_frames_restart(self): # relevant if decoding a frame goes wrong for some reason and we # want to try again restart_timestamp = 3 - with tempfile.NamedTemporaryFile(suffix='.avi') as file: + with tempfile.NamedTemporaryFile(suffix=".avi") as file: original = _generate_video(file.name) - frames = list(lightly.api.download.download_all_video_frames(file.name, restart_timestamp)) + frames = list( + lightly.api.download.download_all_video_frames( + file.name, restart_timestamp + ) + ) for frame, orig in zip(frames, original[2:]): assert _images_equal(frame, orig) @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_video_frame_fps(self): for fps in [24, 30, 60]: - with self.subTest(msg=f"fps={fps}"), \ - tempfile.NamedTemporaryFile(suffix='.avi') as file: - + with self.subTest(msg=f"fps={fps}"), tempfile.NamedTemporaryFile( + suffix=".avi" + ) as file: original = _generate_video(file.name, fps=fps) all_frames = lightly.api.download.download_all_video_frames( file.name, @@ -310,47 +342,53 @@ def test_download_video_frame_fps(self): @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_video_frame_timestamp_exception(self): for fps in [24, 30, 60]: - with self.subTest(msg=f"fps={fps}"), \ - tempfile.NamedTemporaryFile(suffix='.avi') as file: - + with self.subTest(msg=f"fps={fps}"), tempfile.NamedTemporaryFile( + suffix=".avi" + ) 
as file: original = _generate_video(file.name, fps=fps) # this should be the last frame and exist - frame = lightly.api.download.download_video_frame(file.name, len(original)) + frame = lightly.api.download.download_video_frame( + file.name, len(original) + ) assert _images_equal(frame, original[-1]) - # timestamp after last frame + # timestamp after last frame with self.assertRaises(RuntimeError): - lightly.api.download.download_video_frame(file.name, len(original) + 1) + lightly.api.download.download_video_frame( + file.name, len(original) + 1 + ) @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed") def test_download_video_frame_negative_timestamp_exception(self): for fps in [24, 30, 60]: - with self.subTest(msg=f"fps={fps}"), \ - tempfile.NamedTemporaryFile(suffix='.avi') as file: - + with self.subTest(msg=f"fps={fps}"), tempfile.NamedTemporaryFile( + suffix=".avi" + ) as file: _generate_video(file.name, fps=fps) with self.assertRaises(ValueError): lightly.api.download.download_video_frame(file.name, -1) def test_download_and_write_file(self): original = _pil_image() - with tempfile.NamedTemporaryFile(suffix='.png') as file1, \ - tempfile.NamedTemporaryFile(suffix='.png') as file2: - + with tempfile.NamedTemporaryFile( + suffix=".png" + ) as file1, tempfile.NamedTemporaryFile(suffix=".png") as file2: original.save(file1.name) lightly.api.download.download_and_write_file(file1.name, file2.name) image = Image.open(file2.name) assert _images_equal(original, image) - + def test_download_and_write_file_with_session(self): session = MockedRequestsModule.Session() original = _pil_image() - with tempfile.NamedTemporaryFile(suffix='.png') as file1, \ - tempfile.NamedTemporaryFile(suffix='.png') as file2: - + with tempfile.NamedTemporaryFile( + suffix=".png" + ) as file1, tempfile.NamedTemporaryFile(suffix=".png") as file2: original.save(file1.name) - lightly.api.download.download_and_write_file(file1.name, file2.name, session=session) + lightly.api.download.download_and_write_file( + file1.name, file2.name, session=session + ) image = Image.open(file2.name) assert _images_equal(original, image) @@ -358,23 +396,22 @@ def test_download_and_write_all_files(self): n_files = 3 max_workers = 2 originals = [_pil_image(seed=i) for i in range(n_files)] - filenames = [f'filename_{i}.png' for i in range(n_files)] - with tempfile.TemporaryDirectory() as tempdir1, \ - tempfile.TemporaryDirectory() as tempdir2: - - for request_kwargs in [None, {'stream': False}]: + filenames = [f"filename_{i}.png" for i in range(n_files)] + with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2: + for request_kwargs in [None, {"stream": False}]: with self.subTest(request_kwargs=request_kwargs): - # save images at "remote" location - urls = [os.path.join(tempdir1, f'url_{i}.png') for i in range(n_files)] + urls = [ + os.path.join(tempdir1, f"url_{i}.png") for i in range(n_files) + ] for image, url in zip(originals, urls): image.save(url) # download images from remote to local file_infos = list(zip(filenames, urls)) lightly.api.download.download_and_write_all_files( - file_infos, - output_dir=tempdir2, + file_infos, + output_dir=tempdir2, max_workers=max_workers, request_kwargs=request_kwargs, ) @@ -387,33 +424,35 @@ def test_download_and_write_all_files(self): def test_download_video_frame_count(self): fps = 24 for true_n_frames in [24, 30, 60]: - for suffix in ['.avi', '.mpeg']: - with tempfile.NamedTemporaryFile(suffix=suffix) as file, \ - self.subTest(msg=f'n_frames={true_n_frames}, 
extension={suffix}'):
-
+        for suffix in [".avi", ".mpeg"]:
+            with tempfile.NamedTemporaryFile(suffix=suffix) as file, self.subTest(
+                msg=f"n_frames={true_n_frames}, extension={suffix}"
+            ):
                 _generate_video(file.name, n_frames=true_n_frames, fps=fps)
                 n_frames = lightly.api.download.video_frame_count(file.name)
                 assert n_frames == true_n_frames
 
     @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed")
     def test_download_video_Frame_count_timeout(self):
-        with tempfile.NamedTemporaryFile(suffix='.avi') as file:
+        with tempfile.NamedTemporaryFile(suffix=".avi") as file:
             _generate_video(file.name)
-            with self.assertRaisesRegexp(RuntimeError, "Maximum retries exceeded.*av.error.ExitError.*Immediate exit requested.*"):
+            with self.assertRaisesRegex(
+                RuntimeError,
+                "Maximum retries exceeded.*av.error.ExitError.*Immediate exit requested.*",
+            ):
                 lightly.api.download.video_frame_count(file.name, timeout=0)
 
     @unittest.skipUnless(AV_AVAILABLE, "Pyav not installed")
     def test_download_video_frame_count_no_metadata(self):
         fps = 24
         for true_n_frames in [24, 30, 60]:
-            for suffix in ['.avi', '.mpeg']:
-                with tempfile.NamedTemporaryFile(suffix=suffix) as file, \
-                        self.subTest(msg=f'n_frames={true_n_frames}, extension={suffix}'):
-
+            for suffix in [".avi", ".mpeg"]:
+                with tempfile.NamedTemporaryFile(suffix=suffix) as file, self.subTest(
+                    msg=f"n_frames={true_n_frames}, extension={suffix}"
+                ):
                     _generate_video(file.name, n_frames=true_n_frames, fps=fps)
                     n_frames = lightly.api.download.video_frame_count(
-                        file.name,
-                        ignore_metadata=True
+                        file.name, ignore_metadata=True
                     )
                     assert n_frames == true_n_frames
@@ -421,11 +460,14 @@ def test_download_video_frame_count_no_metadata(self):
     def test_download_all_video_frame_counts(self):
         true_n_frames = [3, 5]
         fps = 24
-        for suffix in ['.avi', '.mpeg']:
-            with tempfile.NamedTemporaryFile(suffix=suffix) as file1, \
-                    tempfile.NamedTemporaryFile(suffix=suffix) as file2, \
-                    self.subTest(msg=f'extension={suffix}'):
-
+        for suffix in [".avi", ".mpeg"]:
+            with tempfile.NamedTemporaryFile(
+                suffix=suffix
+            ) as file1, tempfile.NamedTemporaryFile(
+                suffix=suffix
+            ) as file2, self.subTest(
+                msg=f"extension={suffix}"
+            ):
                 _generate_video(file1.name, n_frames=true_n_frames[0], fps=fps)
                 _generate_video(file2.name, n_frames=true_n_frames[1], fps=fps)
                 frame_counts = lightly.api.download.all_video_frame_counts(
@@ -438,12 +480,12 @@ def test_download_all_video_frame_counts(self):
     def test_download_all_video_frame_counts_broken(self):
         fps = 24
         n_frames = 5
-        with tempfile.NamedTemporaryFile(suffix='.mpeg') as file1, \
-                tempfile.NamedTemporaryFile(suffix='.mpeg') as file2:
-
+        with tempfile.NamedTemporaryFile(
+            suffix=".mpeg"
+        ) as file1, tempfile.NamedTemporaryFile(suffix=".mpeg") as file2:
             _generate_video(file1.name, fps=fps, n_frames=n_frames)
             _generate_video(file2.name, fps=fps, broken=True)
-
+
             urls = [file1.name, file2.name]
             result = lightly.api.download.all_video_frame_counts(urls)
             assert result == [n_frames, None]
@@ -452,12 +494,12 @@ def test_download_all_video_frame_counts_broken(self):
     def test_download_all_video_frame_counts_broken_ignore_exceptions(self):
         fps = 24
         n_frames = 5
-        with tempfile.NamedTemporaryFile(suffix='.mpeg') as file1, \
-                tempfile.NamedTemporaryFile(suffix='.mpeg') as file2:
-
+        with tempfile.NamedTemporaryFile(
+            suffix=".mpeg"
+        ) as file1, tempfile.NamedTemporaryFile(suffix=".mpeg") as file2:
             _generate_video(file1.name, fps=fps, n_frames=n_frames)
             _generate_video(file2.name, fps=fps, broken=True)
-
+
             urls = [file1.name, file2.name]
             with self.assertRaises(RuntimeError):
                 result = lightly.api.download.all_video_frame_counts(
@@ -469,10 +511,10 @@ def test_download_all_video_frame_counts_broken_ignore_exceptions(self):
     def test_download_all_video_frame_counts_progress_bar(self):
         true_n_frames = [3, 5]
         fps = 24
-        pbar = mock.Mock(wraps=tqdm.tqdm(unit='videos'))
-        with tempfile.NamedTemporaryFile(suffix='.avi') as file1, \
-                tempfile.NamedTemporaryFile(suffix='.avi') as file2:
-
+        pbar = mock.Mock(wraps=tqdm.tqdm(unit="videos"))
+        with tempfile.NamedTemporaryFile(
+            suffix=".avi"
+        ) as file1, tempfile.NamedTemporaryFile(suffix=".avi") as file2:
             _generate_video(file1.name, n_frames=true_n_frames[0], fps=fps)
             _generate_video(file2.name, n_frames=true_n_frames[1], fps=fps)
             frame_counts = lightly.api.download.all_video_frame_counts(
@@ -483,30 +525,34 @@ def test_download_all_video_frame_counts_progress_bar(self):
         assert frame_counts == true_n_frames
         assert pbar.update.call_count == len(true_n_frames)
 
+
 def _images_equal(image1, image2):
     # note that images saved and loaded from disk must
     # use a lossless format, otherwise this equality will not hold
     return np.all(np.array(image1) == np.array(image2))
 
+
 def _pil_image(width=100, height=50, seed=0):
     np.random.seed(seed)
     image = (np.random.randn(width, height, 3) * 255).astype(np.uint8)
-    image = Image.fromarray(image, mode='RGB')
+    image = Image.fromarray(image, mode="RGB")
     return image
 
+
 def _json_prediction():
     return {
-        'string': 'Hello World',
-        'int': 1,
-        'float': 0.5,
+        "string": "Hello World",
+        "int": 1,
+        "float": 0.5,
     }
 
+
 def _generate_video(
-    out_file,
-    n_frames=5,
-    width=100,
-    height=50,
-    seed=0,
+    out_file,
+    n_frames=5,
+    width=100,
+    height=50,
+    seed=0,
     fps=24,
     broken=False,
 ):
@@ -518,19 +564,19 @@ def _generate_video(
         fps = 24.
     """
-    is_mpeg = out_file.endswith('.mpeg')
-    video_format = 'libx264rgb'
-    pixel_format = 'rgb24'
+    is_mpeg = out_file.endswith(".mpeg")
+    video_format = "libx264rgb"
+    pixel_format = "rgb24"
 
     if is_mpeg:
-        video_format = 'mpeg1video'
-        pixel_format = 'yuv420p'
+        video_format = "mpeg1video"
+        pixel_format = "yuv420p"
 
     if broken:
         n_frames = 0
 
     np.random.seed(seed)
-    container = av.open(out_file, mode='w')
+    container = av.open(out_file, mode="w")
     stream = container.add_stream(video_format, rate=fps)
     stream.width = width
     stream.height = height
@@ -542,18 +588,20 @@ def _generate_video(
     # save lossless video
     stream.options["crf"] = "0"
     images = (np.random.randn(n_frames, height, width, 3) * 255).astype(np.uint8)
-    frames = [av.VideoFrame.from_ndarray(image, format=pixel_format) for image in images]
-
+    frames = [
+        av.VideoFrame.from_ndarray(image, format=pixel_format) for image in images
+    ]
+
     for frame in frames:
         for packet in stream.encode(frame):
             container.mux(packet)
 
     if not broken:
         # flush the stream
-        # video cannot be loaded if this is omitted
+        # video cannot be loaded if this is omitted
         packet = stream.encode(None)
         container.mux(packet)
-
+
     container.close()
 
     pil_images = [frame.to_image() for frame in frames]
diff --git a/tests/api/test_rest_parser.py b/tests/api/test_rest_parser.py
index ba0ce081a..860f4c0ac 100644
--- a/tests/api/test_rest_parser.py
+++ b/tests/api/test_rest_parser.py
@@ -2,31 +2,37 @@
 
 import numpy as np
 
-from lightly.openapi_generated.swagger_client import ApiClient, ScoresApi, ActiveLearningScoreCreateRequest, \
-    SamplingMethod
+from lightly.openapi_generated.swagger_client import (
+    ActiveLearningScoreCreateRequest,
+    ApiClient,
+    SamplingMethod,
+    ScoresApi,
+)
 from lightly.openapi_generated.swagger_client.rest import ApiException
 
 
 class TestRestParser(unittest.TestCase):
-
     @unittest.skip("This test only shows the error, it does not ensure it is solved.")
     def test_parse_active_learning_scores(self):
         score_value_tuple = (
             np.random.normal(0, 1, size=(999,)).astype(np.float32),
             np.random.normal(0, 1, size=(999,)).astype(np.float64),
-            [12.0] * 999
+            [12.0] * 999,
         )
         api_client = ApiClient()
         self._scores_api = ScoresApi(api_client)
         for i, score_values in enumerate(score_value_tuple):
             with self.subTest(i=i, msg=str(type(score_values))):
-                body = ActiveLearningScoreCreateRequest(score_type=SamplingMethod.CORESET, scores=list(score_values))
+                body = ActiveLearningScoreCreateRequest(
+                    score_type=SamplingMethod.CORESET, scores=list(score_values)
+                )
 
                 if isinstance(score_values[0], float):
                     with self.assertRaises(ApiException):
                         self._scores_api.create_or_update_active_learning_score_by_tag_id(
-                            body, dataset_id="dataset_id_xyz", tag_id="tag_id_xyz")
+                            body, dataset_id="dataset_id_xyz", tag_id="tag_id_xyz"
+                        )
                 else:
                     with self.assertRaises(AttributeError):
                         self._scores_api.create_or_update_active_learning_score_by_tag_id(
-                            body, dataset_id="dataset_id_xyz", tag_id="tag_id_xyz")
-
+                            body, dataset_id="dataset_id_xyz", tag_id="tag_id_xyz"
+                        )
diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py
index c01891312..ce4caa796 100644
--- a/tests/api/test_utils.py
+++ b/tests/api/test_utils.py
@@ -1,27 +1,26 @@
-import unittest
-
 import os
-from PIL import Image
+import unittest
 
 from PIL import Image
 
 import lightly
-from lightly.api.utils import DatasourceType, get_signed_url_destination, retry
-from lightly.api.utils import getenv
-from lightly.api.utils import PIL_to_bytes
+from lightly.api.utils import (
+    DatasourceType,
+    PIL_to_bytes,
+    get_signed_url_destination,
+    getenv,
+    retry,
+)
 
 
 class TestUtils(unittest.TestCase):
-
     def test_retry_success(self):
-
         def my_func(arg, kwarg=5):
             return arg + kwarg
 
         self.assertEqual(retry(my_func, 5, kwarg=5), 10)
 
     def test_retry_fail(self):
-
         def my_func():
             raise RuntimeError()
 
@@ -29,16 +28,16 @@ def my_func():
             retry(my_func)
 
     def test_getenv(self):
-        os.environ['TEST_ENV_VARIABLE'] = 'hello world'
-        env = getenv('TEST_ENV_VARIABLE', 'default')
-        self.assertEqual(env, 'hello world')
+        os.environ["TEST_ENV_VARIABLE"] = "hello world"
+        env = getenv("TEST_ENV_VARIABLE", "default")
+        self.assertEqual(env, "hello world")
 
     def test_getenv_fail(self):
-        env = getenv('TEST_ENV_VARIABLE_WHICH_DOES_NOT_EXIST', 'hello world')
-        self.assertEqual(env, 'hello world')
+        env = getenv("TEST_ENV_VARIABLE_WHICH_DOES_NOT_EXIST", "hello world")
+        self.assertEqual(env, "hello world")
 
     def test_PIL_to_bytes(self):
-        image = Image.new('RGB', (128, 128))
+        image = Image.new("RGB", (128, 128))
 
         # test with quality=None
         PIL_to_bytes(image)
@@ -47,36 +46,41 @@ def test_PIL_to_bytes(self):
         PIL_to_bytes(image, quality=90)
 
         # test with quality=90 and ext=jpg
-        PIL_to_bytes(image, ext='JPEG', quality=90)
+        PIL_to_bytes(image, ext="JPEG", quality=90)
 
     def test_get_signed_url_destination(self):
-
         # S3
         self.assertEqual(
-            get_signed_url_destination('https://lightly.s3.eu-central-1.amazonaws.com/lightly/somewhere/image.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=0123456789%2F20220811%2Feu-central-1%2Fs3%2Faws4_request&X-Amz-Date=20220811T065010Z&X-Amz-Expires=601200&X-Amz-Signature=0123456789&X-Amz-SignedHeaders=host&x-id=GetObject'),
-            DatasourceType.S3
+            get_signed_url_destination(
+                "https://lightly.s3.eu-central-1.amazonaws.com/lightly/somewhere/image.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=0123456789%2F20220811%2Feu-central-1%2Fs3%2Faws4_request&X-Amz-Date=20220811T065010Z&X-Amz-Expires=601200&X-Amz-Signature=0123456789&X-Amz-SignedHeaders=host&x-id=GetObject"
+            ),
+            DatasourceType.S3,
         )
         self.assertNotEqual(
-            get_signed_url_destination('http://someething.with.s3.in.it'),
-            DatasourceType.S3
+            get_signed_url_destination("http://someething.with.s3.in.it"),
+            DatasourceType.S3,
         )
 
         # GCS
         self.assertEqual(
-            get_signed_url_destination('https://storage.googleapis.com/lightly/somewhere/image.jpg?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=lightly%40appspot.gserviceaccount.com%2F20220811%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20220811T065325Z&X-Goog-Expires=601201&X-Goog-SignedHeaders=host&X-Goog-Signature=01234567890'),
-            DatasourceType.GCS
+            get_signed_url_destination(
+                "https://storage.googleapis.com/lightly/somewhere/image.jpg?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=lightly%40appspot.gserviceaccount.com%2F20220811%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20220811T065325Z&X-Goog-Expires=601201&X-Goog-SignedHeaders=host&X-Goog-Signature=01234567890"
+            ),
+            DatasourceType.GCS,
         )
         self.assertNotEqual(
-            get_signed_url_destination('http://someething.with.google.in.it'),
-            DatasourceType.GCS
+            get_signed_url_destination("http://someething.with.google.in.it"),
+            DatasourceType.GCS,
        )
 
         # AZURE
         self.assertEqual(
-            get_signed_url_destination('https://lightly.blob.core.windows.net/lightly/somewhere/image.jpg?sv=2020-08-04&ss=bfqt&srt=sco&sp=0123456789&se=2022-04-13T20:20:02Z&st=2022-04-13T12:20:02Z&spr=https&sig=0123456789'),
-            DatasourceType.AZURE
+            get_signed_url_destination(
+                "https://lightly.blob.core.windows.net/lightly/somewhere/image.jpg?sv=2020-08-04&ss=bfqt&srt=sco&sp=0123456789&se=2022-04-13T20:20:02Z&st=2022-04-13T12:20:02Z&spr=https&sig=0123456789"
+            ),
+            DatasourceType.AZURE,
         )
         self.assertNotEqual(
-            get_signed_url_destination('http://someething.with.windows.in.it'),
-            DatasourceType.AZURE
+            get_signed_url_destination("http://someething.with.windows.in.it"),
+            DatasourceType.AZURE,
         )
diff --git a/tests/api/test_version_checking.py b/tests/api/test_version_checking.py
index 2145af1db..df885c785 100644
--- a/tests/api/test_version_checking.py
+++ b/tests/api/test_version_checking.py
@@ -3,15 +3,18 @@
 import unittest
 
 import lightly
-from lightly.api.version_checking import get_latest_version, \
-    get_minimum_compatible_version, pretty_print_latest_version, \
-    LightlyAPITimeoutException, is_latest_version, is_compatible_version
-
+from lightly.api.version_checking import (
+    LightlyAPITimeoutException,
+    get_latest_version,
+    get_minimum_compatible_version,
+    is_compatible_version,
+    is_latest_version,
+    pretty_print_latest_version,
+)
 from tests.api_workflow.mocked_api_workflow_client import MockedVersioningApi
 
 
 class TestVersionChecking(unittest.TestCase):
-
     def setUp(self) -> None:
         lightly.api.version_checking.VersioningApi = MockedVersioningApi
 
@@ -38,14 +41,14 @@ def test_pretty_print(self):
 
     def test_version_check_timout_mocked(self):
         """
-        We cannot check for other errors as we don't know whether the
-        current LIGHTLY_SERVER_URL is
-        - unreachable (error in < 1 second)
-        - causing a timeout and thus raising a LightlyAPITimeoutException
-        - reachable (success in < 1 second
-
-        Thus this only checks that the actual lightly.do_version_check()
-        with needing >1s internally causes a LightlyAPITimeoutException
+        We cannot check for other errors as we don't know whether the
+        current LIGHTLY_SERVER_URL is
+        - unreachable (error in < 1 second)
+        - causing a timeout and thus raising a LightlyAPITimeoutException
+        - reachable (success in < 1 second)
+
+        Thus this only checks that lightly.do_version_check() raises a
+        LightlyAPITimeoutException when it takes more than 1 second internally.
         """
         try:
             old_get_versioning_api = lightly.api.version_checking.get_versioning_api
@@ -55,7 +58,9 @@ def mocked_get_versioning_api_timeout():
                 print("This line should never be reached, calling sys.exit()")
                 sys.exit()
 
-            lightly.api.version_checking.get_versioning_api = mocked_get_versioning_api_timeout
+            lightly.api.version_checking.get_versioning_api = (
+                mocked_get_versioning_api_timeout
+            )
 
             start_time = time.time()
diff --git a/tests/api_workflow/mocked_api_workflow_client.py b/tests/api_workflow/mocked_api_workflow_client.py
index e7772a384..97472c92c 100644
--- a/tests/api_workflow/mocked_api_workflow_client.py
+++ b/tests/api_workflow/mocked_api_workflow_client.py
@@ -1,21 +1,82 @@
 import csv
 import io
+import json
 import tempfile
 import unittest
-from io import IOBase
 from collections import defaultdict
-import json
+from io import IOBase
+from typing import *
 
 import numpy as np
 import requests
 from requests import Response
 
+import lightly
+from lightly.api.api_workflow_client import ApiWorkflowClient
+from lightly.openapi_generated.swagger_client import (
+    ApiClient,
+    CreateEntityResponse,
+    DatasourceRawSamplesMetadataData,
+    InitialTagCreateRequest,
+    QuotaApi,
+    SampleCreateRequest,
+    SampleData,
+    SampleDataModes,
+    SampleMetaData,
+    SamplesApi,
+    SampleUpdateRequest,
+    SampleWriteUrls,
+    ScoresApi,
+    TagArithmeticsRequest,
+    TagBitMaskResponse,
+    Trigger2dEmbeddingJobRequest,
+    VersioningApi,
+)
+from lightly.openapi_generated.swagger_client.api.collaboration_api import (
+    CollaborationApi,
+)
+from lightly.openapi_generated.swagger_client.api.datasets_api import DatasetsApi
+from lightly.openapi_generated.swagger_client.api.datasources_api import DatasourcesApi
 from lightly.openapi_generated.swagger_client.api.docker_api import DockerApi
+from lightly.openapi_generated.swagger_client.api.embeddings_api import EmbeddingsApi
+from lightly.openapi_generated.swagger_client.api.jobs_api import JobsApi
+from lightly.openapi_generated.swagger_client.api.mappings_api import MappingsApi
+from lightly.openapi_generated.swagger_client.api.samplings_api import SamplingsApi
+from lightly.openapi_generated.swagger_client.api.tags_api import TagsApi
+from lightly.openapi_generated.swagger_client.models.async_task_data import (
+    AsyncTaskData,
+)
 from lightly.openapi_generated.swagger_client.models.create_docker_worker_registry_entry_request import (
     CreateDockerWorkerRegistryEntryRequest,
 )
+from lightly.openapi_generated.swagger_client.models.dataset_create_request import (
+    DatasetCreateRequest,
+)
+from lightly.openapi_generated.swagger_client.models.dataset_data import DatasetData
+from lightly.openapi_generated.swagger_client.models.dataset_embedding_data import (
+    DatasetEmbeddingData,
+)
+from lightly.openapi_generated.swagger_client.models.datasource_config import (
+    DatasourceConfig,
+)
+from lightly.openapi_generated.swagger_client.models.datasource_config_base import (
+    DatasourceConfigBase,
+)
+from lightly.openapi_generated.swagger_client.models.datasource_processed_until_timestamp_request import (
+    DatasourceProcessedUntilTimestampRequest,
+)
 from lightly.openapi_generated.swagger_client.models.datasource_processed_until_timestamp_response import (
     DatasourceProcessedUntilTimestampResponse,
 )
+from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data import (
+    DatasourceRawSamplesData,
+)
+from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data_row import (
+    DatasourceRawSamplesDataRow,
+)
+from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_predictions_data import (
+    DatasourceRawSamplesPredictionsData,
+)
 from lightly.openapi_generated.swagger_client.models.docker_run_data import (
     DockerRunData,
 )
@@ -52,69 +113,6 @@
 from lightly.openapi_generated.swagger_client.models.filename_and_read_url import (
     FilenameAndReadUrl,
 )
-from lightly.openapi_generated.swagger_client.models.label_box_data_row import (
-    LabelBoxDataRow,
-)
-from lightly.openapi_generated.swagger_client.models.label_studio_task import (
-    LabelStudioTask,
-)
-from lightly.openapi_generated.swagger_client.models.label_studio_task_data import (
-    LabelStudioTaskData,
-)
-
-
-from lightly.openapi_generated.swagger_client.models.tag_creator import TagCreator
-from lightly.openapi_generated.swagger_client.models.dataset_create_request import (
-    DatasetCreateRequest,
-)
-from lightly.openapi_generated.swagger_client.models.dataset_data import DatasetData
-from lightly.openapi_generated.swagger_client.models.sample_partial_mode import (
-    SamplePartialMode,
-)
-from lightly.openapi_generated.swagger_client.api.datasets_api import DatasetsApi
-from lightly.openapi_generated.swagger_client.api.datasources_api import DatasourcesApi
-from lightly.openapi_generated.swagger_client.models.timestamp import Timestamp
-from lightly.openapi_generated.swagger_client.rest import ApiException
-
-import lightly
-
-from lightly.api.api_workflow_client import ApiWorkflowClient
-
-from typing import *
-
-from lightly.openapi_generated.swagger_client import (
-    ScoresApi,
-    CreateEntityResponse,
-    SamplesApi,
-    SampleCreateRequest,
-    InitialTagCreateRequest,
-    ApiClient,
-    VersioningApi,
-    QuotaApi,
-    TagArithmeticsRequest,
-    TagBitMaskResponse,
-    SampleWriteUrls,
-    SampleData,
-    SampleMetaData,
-    SampleDataModes,
-    DatasourceRawSamplesMetadataData,
-    Trigger2dEmbeddingJobRequest,
-    SampleUpdateRequest,
-)
-from lightly.openapi_generated.swagger_client.api.embeddings_api import EmbeddingsApi
-from lightly.openapi_generated.swagger_client.api.collaboration_api import (
-    CollaborationApi,
-)
-from lightly.openapi_generated.swagger_client.api.jobs_api import JobsApi
-from lightly.openapi_generated.swagger_client.api.mappings_api import MappingsApi
-from lightly.openapi_generated.swagger_client.api.samplings_api import SamplingsApi
-from lightly.openapi_generated.swagger_client.api.tags_api import TagsApi
-from lightly.openapi_generated.swagger_client.models.async_task_data import (
-    AsyncTaskData,
-)
-from lightly.openapi_generated.swagger_client.models.dataset_embedding_data import (
-    DatasetEmbeddingData,
-)
 from lightly.openapi_generated.swagger_client.models.job_result_type import (
     JobResultType,
 )
@@ -125,30 +123,20 @@
 from lightly.openapi_generated.swagger_client.models.job_status_data_result import (
     JobStatusDataResult,
 )
-from lightly.openapi_generated.swagger_client.models.sampling_create_request import (
-    SamplingCreateRequest,
-)
-from lightly.openapi_generated.swagger_client.models.tag_data import TagData
-from lightly.openapi_generated.swagger_client.models.write_csv_url_data import (
-    WriteCSVUrlData,
-)
-from lightly.openapi_generated.swagger_client.models.datasource_config import (
-    DatasourceConfig,
-)
-from lightly.openapi_generated.swagger_client.models.datasource_config_base import (
-    DatasourceConfigBase,
+from lightly.openapi_generated.swagger_client.models.label_box_data_row import (
+    LabelBoxDataRow,
 )
-from lightly.openapi_generated.swagger_client.models.datasource_processed_until_timestamp_request import (
-    DatasourceProcessedUntilTimestampRequest,
+from lightly.openapi_generated.swagger_client.models.label_studio_task import (
+    LabelStudioTask,
 )
-from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data import (
-    DatasourceRawSamplesData,
+from lightly.openapi_generated.swagger_client.models.label_studio_task_data import (
+    LabelStudioTaskData,
 )
-from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data_row import (
-    DatasourceRawSamplesDataRow,
+from lightly.openapi_generated.swagger_client.models.sample_partial_mode import (
+    SamplePartialMode,
 )
-from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_predictions_data import (
-    DatasourceRawSamplesPredictionsData,
+from lightly.openapi_generated.swagger_client.models.sampling_create_request import (
+    SamplingCreateRequest,
 )
 from lightly.openapi_generated.swagger_client.models.shared_access_config_create_request import (
     SharedAccessConfigCreateRequest,
 )
@@ -159,6 +147,13 @@
 from lightly.openapi_generated.swagger_client.models.shared_access_type import (
     SharedAccessType,
 )
+from lightly.openapi_generated.swagger_client.models.tag_creator import TagCreator
+from lightly.openapi_generated.swagger_client.models.tag_data import TagData
+from lightly.openapi_generated.swagger_client.models.timestamp import Timestamp
+from lightly.openapi_generated.swagger_client.models.write_csv_url_data import (
+    WriteCSVUrlData,
+)
+from lightly.openapi_generated.swagger_client.rest import ApiException
 
 
 def _check_dataset_id(dataset_id: str):
@@ -356,7 +351,7 @@ def perform_tag_arithmetics_bitmask(
     def upsize_tags_by_dataset_id(self, body, dataset_id, **kwargs):
         _check_dataset_id(dataset_id)
         assert body.upsize_tag_creator in (
-            TagCreator.USER_PIP,
+            TagCreator.USER_PIP,
             TagCreator.USER_PIP_LIGHTLY_MAGIC,
         )
 
@@ -389,48 +384,48 @@ def export_tag_to_label_studio_tasks(
             return []
         return [
             LabelStudioTask(
-                id = 0,
-                data = LabelStudioTaskData(
-                    image = "https://api.lightly.ai/v1/datasets/62383ab8f9cb290cd83ab5f9/samples/62383cb7e6a0f29e3f31e213/readurlRedirect?type=full&CENSORED",
-                    lightly_file_name = "2008_006249_jpg.rf.fdd64460945ca901aa3c7e48ffceea83.jpg",
-                    lightly_meta_info = SampleData(
-                        id = "sample_id_0",
-                        type = "IMAGE",
-                        dataset_id = dataset_id,
-                        file_name = "2008_006249_jpg.rf.fdd64460945ca901aa3c7e48ffceea83.jpg",
-                        exif = {},
-                        index = 0,
-                        created_at = 1647852727873,
-                        last_modified_at = 1647852727873,
-                        meta_data = SampleMetaData(
-                            sharpness = 27.31265790443818,
-                            size_in_bytes = 48224,
-                            snr = 2.1969673926211217,
-                            mean = [
+                id=0,
+                data=LabelStudioTaskData(
+                    image="https://api.lightly.ai/v1/datasets/62383ab8f9cb290cd83ab5f9/samples/62383cb7e6a0f29e3f31e213/readurlRedirect?type=full&CENSORED",
+                    lightly_file_name="2008_006249_jpg.rf.fdd64460945ca901aa3c7e48ffceea83.jpg",
+                    lightly_meta_info=SampleData(
+                        id="sample_id_0",
+                        type="IMAGE",
+                        dataset_id=dataset_id,
+                        file_name="2008_006249_jpg.rf.fdd64460945ca901aa3c7e48ffceea83.jpg",
+                        exif={},
+                        index=0,
+                        created_at=1647852727873,
+                        last_modified_at=1647852727873,
+                        meta_data=SampleMetaData(
+                            sharpness=27.31265790443818,
+                            size_in_bytes=48224,
+                            snr=2.1969673926211217,
+                            mean=[
                                 0.24441662557257224,
                                 0.4460417517905863,
                                 0.6960984853824035,
                             ],
-                            shape = [167, 500, 3],
-                            std = [
+                            shape=[167, 500, 3],
+                            std=[
                                 0.12448681278605961,
                                 0.09509570033043004,
                                 0.0763725998175394,
                             ],
-                            sum_of_squares = [
+                            sum_of_squares=[
                                 6282.243860049413,
                                 17367.702452895475,
                                 40947.22059208768,
                             ],
-                            sum_of_values = [
+                            sum_of_values=[
                                 20408.78823530978,
                                 37244.486274513954,
                                 58124.22352943069,
                             ],
                         ),
                     ),
-                )
-            ).to_dict() # temporary until we have a proper openapi generator
+                ),
+            ).to_dict()  # temporary until we have a proper openapi generator
         ]
 
     def export_tag_to_label_box_data_rows(
@@ -440,9 +435,9 @@ def export_tag_to_label_box_data_rows(
             return []
         return [
             LabelBoxDataRow(
-                external_id = "2008_007291_jpg.rf.2fca436925b52ea33cf897125a34a2fb.jpg",
-                image_url = "https://api.lightly.ai/v1/datasets/62383ab8f9cb290cd83ab5f9/samples/62383cb7e6a0f29e3f31e233/readurlRedirect?type=CENSORED",
-            ).to_dict() # temporary until we have a proper openapi generator
+                external_id="2008_007291_jpg.rf.2fca436925b52ea33cf897125a34a2fb.jpg",
+                image_url="https://api.lightly.ai/v1/datasets/62383ab8f9cb290cd83ab5f9/samples/62383cb7e6a0f29e3f31e233/readurlRedirect?type=CENSORED",
+            ).to_dict()  # temporary until we have a proper openapi generator
         ]
 
     def export_tag_to_basic_filenames_and_read_urls(
@@ -452,12 +447,14 @@ def export_tag_to_basic_filenames_and_read_urls(
             return []
         return [
             FilenameAndReadUrl(
-                file_name = "export-basic-test-sample-0.png",
-                read_url = "https://storage.googleapis.com/somwhere/export-basic-test-sample-0.png?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=CENSORED",
-            ).to_dict() # temporary until we have a proper openapi generator
+                file_name="export-basic-test-sample-0.png",
+                read_url="https://storage.googleapis.com/somwhere/export-basic-test-sample-0.png?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=CENSORED",
+            ).to_dict()  # temporary until we have a proper openapi generator
         ]
 
-    def export_tag_to_basic_filenames(self, dataset_id: str, tag_id: str, **kwargs) -> str:
+    def export_tag_to_basic_filenames(
+        self, dataset_id: str, tag_id: str, **kwargs
+    ) -> str:
         return """
 IMG_2276_jpeg_jpg.rf.7411b1902c81bad8cdefd2cc4eb3a97b.jpg
 IMG_2285_jpeg_jpg.rf.4a93d99b9f0b6cccfb27bf2f4a13b99e.jpg
@@ -604,15 +601,14 @@ def __init__(self, api_client):
     def _all_datasets(self) -> List[DatasetData]:
         return [*self.datasets, *self.shared_datasets]
 
-
     def reset(self):
         self.datasets = self._default_datasets
         self.shared_datasets = self._shared_datasets
 
     def get_datasets(
-            self,
-            shared: bool,
-            page_size: Optional[int] = None,
+        self,
+        shared: bool,
+        page_size: Optional[int] = None,
         page_offset: Optional[int] = None,
     ):
         start, end = _start_and_end_offset(page_size=page_size, page_offset=page_offset)
@@ -668,12 +664,18 @@ def get_children_of_dataset_id(self, dataset_id, **kwargs):
     def get_datasets_enriched(self, **kwargs):
         raise NotImplementedError()
 
-    def get_datasets_query_by_name(self, dataset_name: str, shared: bool, exact: bool) -> List[DatasetData]:
+    def get_datasets_query_by_name(
+        self, dataset_name: str, shared: bool, exact: bool
+    ) -> List[DatasetData]:
         datasets = self.get_datasets(shared=shared)
         if exact:
             return [dataset for dataset in datasets if dataset.name == dataset_name]
         else:
-            return [dataset for dataset in datasets if dataset.name is not None and dataset.name.startswith(dataset_name)]
+            return [
+                dataset
+                for dataset in datasets
+                if dataset.name is not None and dataset.name.startswith(dataset_name)
+            ]
 
     def update_dataset_by_id(self, body, dataset_id, **kwargs):
         raise NotImplementedError()
@@ -689,7 +691,6 @@ def __init__(self, api_client=None):
         self.reset()
 
     def reset(self):
-
         local_datasource = DatasourceConfigBase(
             type="LOCAL", full_path="", purpose="INPUT_OUTPUT"
         ).to_dict()
@@ -878,7 +879,7 @@ def __init__(self, api_client=None):
                 created_at=Timestamp(0),
                 last_modified_at=Timestamp(100),
                 owner="user-id-1",
-                runs_on=[]
+                runs_on=[],
             )
         ]
         self._registered_workers = [
@@ -906,7 +907,7 @@ def get_docker_worker_registry_entries(self, **kwargs):
     def create_docker_worker_config(self, body, **kwargs):
         assert isinstance(body, DockerWorkerConfigCreateRequest)
         return CreateEntityResponse(id="worker-config-id-123")
-
+
     def create_docker_worker_config_v2(self, body, **kwargs):
         assert isinstance(body, DockerWorkerConfigV2CreateRequest)
         return CreateEntityResponse(id="worker-configv2-id-123")
@@ -916,22 +917,33 @@ def create_docker_run_scheduled_by_dataset_id(self, body, dataset_id, **kwargs):
         _check_dataset_id(dataset_id)
         return CreateEntityResponse(id=f"scheduled-run-id-123-dataset-{dataset_id}")
 
-    def get_docker_runs(self, page_size: Optional[int] = None, page_offset: Optional[int] = None, **kwargs):
+    def get_docker_runs(
+        self,
+        page_size: Optional[int] = None,
+        page_offset: Optional[int] = None,
+        **kwargs,
+    ):
         start, end = _start_and_end_offset(page_size=page_size, page_offset=page_offset)
         return self._compute_worker_runs[start:end]
 
     def get_docker_runs_count(self, **kwargs):
         return len(self._compute_worker_runs)
 
-    def get_docker_runs_scheduled_by_dataset_id(self, dataset_id, state: Optional[str] = None, **kwargs):
+    def get_docker_runs_scheduled_by_dataset_id(
+        self, dataset_id, state: Optional[str] = None, **kwargs
+    ):
         runs = self._scheduled_compute_worker_runs
         runs = [run for run in runs if run.dataset_id == dataset_id]
         return runs
 
-    def cancel_scheduled_docker_run_state_by_id(self, dataset_id: str, scheduled_id: str, **kwargs):
+    def cancel_scheduled_docker_run_state_by_id(
+        self, dataset_id: str, scheduled_id: str, **kwargs
+    ):
         raise NotImplementedError()
 
-    def confirm_docker_run_artifact_creation(self, run_id: str, artifact_id: str, **kwargs):
+    def confirm_docker_run_artifact_creation(
+        self, run_id: str, artifact_id: str, **kwargs
+    ):
         raise NotImplementedError()
 
     def create_docker_run(self, body, **kwargs):
@@ -969,10 +981,10 @@ def get_docker_runs_scheduled_by_worker_id(self, worker_id, **kwargs):
 
     def get_docker_worker_config_by_id(self, config_id, **kwargs):
         raise NotImplementedError()
-
+
     def get_docker_worker_configs(self, **kwargs):
         raise NotImplementedError()
-
+
     def get_docker_worker_registry_entry_by_id(self, worker_id, **kwargs):
         raise NotImplementedError()
 
@@ -994,9 +1006,12 @@ def update_docker_worker_config_by_id(self, body, config_id, **kwargs):
     def update_docker_worker_registry_entry_by_id(self, body, worker_id, **kwargs):
         raise NotImplementedError()
 
-    def update_scheduled_docker_run_state_by_id(self, body, dataset_id, worker_id, scheduled_id, **kwargs):
+    def update_scheduled_docker_run_state_by_id(
+        self, body, dataset_id, worker_id, scheduled_id, **kwargs
+    ):
         raise NotImplementedError()
 
+
 class MockedVersioningApi(VersioningApi):
     def get_latest_pip_version(self, **kwargs):
         return "1.2.8"
@@ -1041,7 +1056,6 @@ def get_shared_access_configs_by_dataset_id(self, dataset_id, **kwargs):
 
 
 class MockedApiWorkflowClient(ApiWorkflowClient):
-
     embeddings_filename_base = "img"
     n_embedding_rows_on_server = N_FILES_ON_SERVER
 
@@ -1114,8 +1128,9 @@ def setUp(self, token="token_xyz", dataset_id="dataset_id_xyz") -> None:
             token=token, dataset_id=dataset_id
         )
 
+
 def _start_and_end_offset(
-        page_size: Optional[int],
+    page_size: Optional[int],
     page_offset: Optional[int],
 ) -> Union[Tuple[int, int], Tuple[None, None]]:
     if page_size is None and page_offset is None:
diff --git a/tests/api_workflow/test_api_workflow.py b/tests/api_workflow/test_api_workflow.py
index a288bbf4b..d7465a327 100644
--- a/tests/api_workflow/test_api_workflow.py
+++ b/tests/api_workflow/test_api_workflow.py
@@ -1,28 +1,35 @@
 import os
 from unittest import mock
-from urllib3 import Timeout
 
 import numpy as np
+from urllib3 import Timeout
 
 import lightly
 from lightly.api.api_workflow_client import set_api_client_request_timeout
 from lightly.openapi_generated.swagger_client.api_client import ApiClient
-from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowClient, MockedApiWorkflowSetup
+from tests.api_workflow.mocked_api_workflow_client import (
+    MockedApiWorkflowClient,
+    MockedApiWorkflowSetup,
+)
 
-class TestApiWorkflow(MockedApiWorkflowSetup):
 
+class TestApiWorkflow(MockedApiWorkflowSetup):
     def setUp(self) -> None:
         lightly.api.api_workflow_client.__version__ = lightly.__version__
         self.api_workflow_client = MockedApiWorkflowClient(token="token_xyz")
 
-    @mock.patch.dict(os.environ, {'LIGHTLY_TOKEN': 'token_xyz'})
+    @mock.patch.dict(os.environ, {"LIGHTLY_TOKEN": "token_xyz"})
     def test_init_with_env_token(self):
         MockedApiWorkflowClient()
 
     def test_error_if_init_without_token(self):
-        # copy environment variables but remove LIGHTLY_TOKEN if it exists
-        env_without_token = {k: v for k, v in os.environ.items() if k != 'LIGHTLY_TOKEN'}
-        with self.assertRaises(ValueError), mock.patch.dict(os.environ, env_without_token, clear=True):
+        # copy environment variables but remove LIGHTLY_TOKEN if it exists
+        env_without_token = {
+            k: v for k, v in os.environ.items() if k != "LIGHTLY_TOKEN"
+        }
+        with self.assertRaises(ValueError), mock.patch.dict(
+            os.environ, env_without_token, clear=True
+        ):
             MockedApiWorkflowClient()
 
     def test_error_if_version_is_incompatible(self):
@@ -57,84 +64,99 @@ def test_reorder_random(self):
         numbers_all = list(np.random.choice(numbers_to_choose_from, 100))
         filenames_on_server = [f"img_{i}" for i in numbers_all]
-        api_workflow_client = MockedApiWorkflowClient(token="token_xyz", dataset_id="dataset_id_xyz")
+        api_workflow_client = MockedApiWorkflowClient(
+            token="token_xyz", dataset_id="dataset_id_xyz"
+        )
         api_workflow_client._mappings_api.sample_names = filenames_on_server
 
         numbers_in_tag = np.copy(numbers_all)
         np.random.shuffle(numbers_in_tag)
         filenames_for_list = [f"img_{i}" for i in numbers_in_tag]
 
-        list_ordered = api_workflow_client._order_list_by_filenames(filenames_for_list,
-                                                                    list_to_order=numbers_in_tag)
+        list_ordered = api_workflow_client._order_list_by_filenames(
+            filenames_for_list, list_to_order=numbers_in_tag
+        )
         list_desired_order = [i for i in numbers_all if i in numbers_in_tag]
         assert list_ordered == list_desired_order
 
     def test_reorder_manual(self):
-        filenames_on_server = ['a', 'b', 'c']
-        api_workflow_client = MockedApiWorkflowClient(token="token_xyz", dataset_id="dataset_id_xyz")
+        filenames_on_server = ["a", "b", "c"]
+        api_workflow_client = MockedApiWorkflowClient(
+            token="token_xyz", dataset_id="dataset_id_xyz"
+        )
         api_workflow_client._mappings_api.sample_names = filenames_on_server
-        filenames_for_list = ['c', 'a', 'b']
-        list_to_order = ['cccc', 'aaaa', 'bbbb']
-        list_ordered = api_workflow_client._order_list_by_filenames(filenames_for_list, list_to_order=list_to_order)
-        list_desired_order = ['aaaa', 'bbbb', 'cccc']
+        filenames_for_list = ["c", "a", "b"]
+        list_to_order = ["cccc", "aaaa", "bbbb"]
+        list_ordered = api_workflow_client._order_list_by_filenames(
+            filenames_for_list, list_to_order=list_to_order
+        )
+        list_desired_order = ["aaaa", "bbbb", "cccc"]
         assert list_ordered == list_desired_order
 
     def test_reorder_wrong_lengths(self):
-        filenames_on_server = ['a', 'b', 'c']
+        filenames_on_server = ["a", "b", "c"]
         api_workflow_client = MockedApiWorkflowClient(
             token="token_xyz", dataset_id="dataset_id_xyz"
         )
         api_workflow_client._mappings_api.sample_names = filenames_on_server
-        filenames_for_list = ['c', 'a', 'b']
-        list_to_order = ['cccc', 'aaaa', 'bbbb']
+        filenames_for_list = ["c", "a", "b"]
+        list_to_order = ["cccc", "aaaa", "bbbb"]
 
         with self.subTest("filenames_for_list wrong length"):
             with self.assertRaises(ValueError):
                 api_workflow_client._order_list_by_filenames(
-                    filenames_for_list[:-1], list_to_order)
+                    filenames_for_list[:-1], list_to_order
+                )
 
         with self.subTest("list_to_order wrong length"):
             with self.assertRaises(ValueError):
                 api_workflow_client._order_list_by_filenames(
-                    filenames_for_list, list_to_order[:-1])
+                    filenames_for_list, list_to_order[:-1]
+                )
 
         with self.subTest("filenames_for_list and list_to_order wrong length"):
             with self.assertRaises(ValueError):
                 api_workflow_client._order_list_by_filenames(
-                    filenames_for_list[:-1], list_to_order[:-1])
+                    filenames_for_list[:-1], list_to_order[:-1]
+                )
 
 
 def test_set_api_client_timeout_total():
     client = ApiClient()
     set_api_client_request_timeout(client, timeout=1)
 
-    # mock urllib3
+    # mock urllib3
     client.rest_client.pool_manager = mock.Mock(wraps=client.rest_client.pool_manager)
-    client.rest_client.pool_manager.request.return_value = mock.Mock(status=200, data="data".encode('utf8'))
+    client.rest_client.pool_manager.request.return_value = mock.Mock(
+        status=200, data="data".encode("utf8")
+    )
 
-    client.request('GET', 'some-url')
+    client.request("GET", "some-url")
 
-    # verify that urllib3 request was called with timeout
+    # verify that urllib3 request was called with timeout
     request_calls = client.rest_client.pool_manager.request.mock_calls
     assert len(request_calls) == 1
     _, _, kwargs = request_calls[0]
-    assert isinstance(kwargs['timeout'], Timeout)
-    assert kwargs['timeout'].total == 1
+    assert isinstance(kwargs["timeout"], Timeout)
+    assert kwargs["timeout"].total == 1
+
 
 def test_set_api_client_timeout_connect_read():
     client = ApiClient()
     set_api_client_request_timeout(client, timeout=(1, 2))
 
-    # mock urllib3
+    # mock urllib3
     client.rest_client.pool_manager = mock.Mock(wraps=client.rest_client.pool_manager)
-    client.rest_client.pool_manager.request.return_value = mock.Mock(status=200, data="data".encode('utf8'))
+    client.rest_client.pool_manager.request.return_value = mock.Mock(
+        status=200, data="data".encode("utf8")
+    )
 
-    client.request('GET', 'some-url')
+    client.request("GET", "some-url")
 
-    # verify that urllib3 request was called with timeout
+    # verify that urllib3 request was called with timeout
     request_calls = client.rest_client.pool_manager.request.mock_calls
     assert len(request_calls) == 1
     _, _, kwargs = request_calls[0]
-    assert isinstance(kwargs['timeout'], Timeout)
-    assert kwargs['timeout'].connect_timeout == 1
-    assert kwargs['timeout'].read_timeout == 2
+    assert isinstance(kwargs["timeout"], Timeout)
+    assert kwargs["timeout"].connect_timeout == 1
+    assert kwargs["timeout"].read_timeout == 2
diff --git a/tests/api_workflow/test_api_workflow_artifacts.py b/tests/api_workflow/test_api_workflow_artifacts.py
index b05bf41a5..55910d374 100644
--- a/tests/api_workflow/test_api_workflow_artifacts.py
+++ b/tests/api_workflow/test_api_workflow_artifacts.py
@@ -1,13 +1,13 @@
 import pytest
 from pytest_mock import MockerFixture
 
+from lightly.api import ApiWorkflowClient, ArtifactNotExist
 from lightly.openapi_generated.swagger_client import (
+    DockerApi,
     DockerRunArtifactData,
     DockerRunArtifactType,
     DockerRunData,
-    DockerApi,
 )
-from lightly.api import ApiWorkflowClient, ArtifactNotExist
 
 
 def test_download_compute_worker_run_artifacts(mocker: MockerFixture) -> None:
diff --git a/tests/api_workflow/test_api_workflow_client.py b/tests/api_workflow/test_api_workflow_client.py
index c38fd7230..0d973028f 100644
--- a/tests/api_workflow/test_api_workflow_client.py
+++ b/tests/api_workflow/test_api_workflow_client.py
@@ -1,22 +1,21 @@
+import os
 import platform
 import unittest
 from unittest import mock
 
-import lightly
 import requests
-import os
-
 from pytest_mock import MockerFixture
 
-from lightly.api.api_workflow_client import ApiWorkflowClient, LIGHTLY_S3_SSE_KMS_KEY
+import lightly
+from lightly.api.api_workflow_client import LIGHTLY_S3_SSE_KMS_KEY, ApiWorkflowClient
 
-class TestApiWorkflowClient(unittest.TestCase):
 
+class TestApiWorkflowClient(unittest.TestCase):
     def test_upload_file_with_signed_url(self):
-        with mock.patch('lightly.api.api_workflow_client.requests') as requests:
+        with mock.patch("lightly.api.api_workflow_client.requests") as requests:
             client = ApiWorkflowClient(token="")
             file = mock.Mock()
-            signed_write_url = ''
+            signed_write_url = ""
             client.upload_file_with_signed_url(
                 file=file,
                 signed_write_url=signed_write_url,
@@ -26,74 +25,81 @@ def test_upload_file_with_signed_url(self):
     def test_upload_file_with_signed_url_session(self):
         session = mock.Mock()
         file = mock.Mock()
-        signed_write_url = ''
+        signed_write_url = ""
         client = ApiWorkflowClient(token="")
         client.upload_file_with_signed_url(
-            file=file,
-            signed_write_url=signed_write_url,
-            session=session
+            file=file, signed_write_url=signed_write_url, session=session
         )
         session.put.assert_called_with(signed_write_url, data=file)
-
+
     def test_upload_file_with_signed_url_session_sse(self):
         session = mock.Mock()
         file = mock.Mock()
-        signed_write_url = 'http://somwhere.s3.amazonaws.com/someimage.png'
+        signed_write_url = "http://somwhere.s3.amazonaws.com/someimage.png"
         client = ApiWorkflowClient(token="")
-        # set the environment var to enable SSE
+        # set the environment var to enable SSE
-        os.environ[LIGHTLY_S3_SSE_KMS_KEY] = 'True'
+        os.environ[LIGHTLY_S3_SSE_KMS_KEY] = "True"
         client.upload_file_with_signed_url(
-            file=file,
-            signed_write_url=signed_write_url,
-            session=session
+            file=file, signed_write_url=signed_write_url, session=session
         )
-        session.put.assert_called_with(signed_write_url, data=file, headers={'x-amz-server-side-encryption': 'AES256'})
-
+        session.put.assert_called_with(
+            signed_write_url,
+            data=file,
+            headers={"x-amz-server-side-encryption": "AES256"},
+        )
+
     def test_upload_file_with_signed_url_session_sse_kms(self):
         session = mock.Mock()
         file = mock.Mock()
-        signed_write_url = 'http://somwhere.s3.amazonaws.com/someimage.png'
+        signed_write_url = "http://somwhere.s3.amazonaws.com/someimage.png"
         client = ApiWorkflowClient(token="")
-        # set the environment var to enable SSE with KMS
+        # set the environment var to enable SSE with KMS
         sseKMSKey = "arn:aws:kms:us-west-2:123456789000:key/1234abcd-12ab-34cd-56ef-1234567890ab"
         os.environ[LIGHTLY_S3_SSE_KMS_KEY] = sseKMSKey
         client.upload_file_with_signed_url(
-            file=file,
-            signed_write_url=signed_write_url,
-            session=session
+            file=file, signed_write_url=signed_write_url, session=session
         )
         session.put.assert_called_with(
             signed_write_url,
             data=file,
             headers={
-                'x-amz-server-side-encryption': 'aws:kms',
-                'x-amz-server-side-encryption-aws-kms-key-id': sseKMSKey,
-            }
+                "x-amz-server-side-encryption": "aws:kms",
+                "x-amz-server-side-encryption-aws-kms-key-id": sseKMSKey,
+            },
         )
 
     def test_upload_file_with_signed_url_raise_status(self):
         def raise_connection_error(*args, **kwargs):
             raise requests.exceptions.ConnectionError()
 
-        with mock.patch('lightly.api.api_workflow_client.requests.put', raise_connection_error):
+        with mock.patch(
+            "lightly.api.api_workflow_client.requests.put", raise_connection_error
+        ):
             client = ApiWorkflowClient(token="")
             with self.assertRaises(requests.exceptions.ConnectionError):
                 client.upload_file_with_signed_url(
                     file=mock.Mock(),
-                    signed_write_url='',
+                    signed_write_url="",
                 )
 
+
 def test_user_agent_header(mocker: MockerFixture) -> None:
     mocker.patch.object(lightly.api.api_workflow_client, "__version__", new="VERSION")
-    mocker.patch.object(lightly.api.api_workflow_client, "is_compatible_version", new=lambda _: True)
-    mocked_platform = mocker.patch.object(lightly.api.api_workflow_client, "platform", spec_set=platform)
+    mocker.patch.object(
+        lightly.api.api_workflow_client, "is_compatible_version", new=lambda _: True
+    )
+    mocked_platform = mocker.patch.object(
+        lightly.api.api_workflow_client, "platform", spec_set=platform
+    )
     mocked_platform.system.return_value = "SYSTEM"
     mocked_platform.release.return_value = "RELEASE"
     mocked_platform.platform.return_value = "PLATFORM"
     mocked_platform.processor.return_value = "PROCESSOR"
     mocked_platform.python_version.return_value = "PYTHON_VERSION"
-
     client = ApiWorkflowClient(token="")
-    assert client.api_client.user_agent == f"Lightly/VERSION (SYSTEM/RELEASE; PLATFORM; PROCESSOR;) python/PYTHON_VERSION"
+    assert (
+        client.api_client.user_agent
+        == "Lightly/VERSION (SYSTEM/RELEASE; PLATFORM; PROCESSOR;) python/PYTHON_VERSION"
+    )
diff --git a/tests/api_workflow/test_api_workflow_collaboration.py b/tests/api_workflow/test_api_workflow_collaboration.py
index 4fe447764..abcbd5c3c 100644
--- a/tests/api_workflow/test_api_workflow_collaboration.py
+++ b/tests/api_workflow/test_api_workflow_collaboration.py
@@ -1,17 +1,25 @@
-from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup, MockedApiWorkflowClient
+from tests.api_workflow.mocked_api_workflow_client import (
+    MockedApiWorkflowClient,
+    MockedApiWorkflowSetup,
+)
 
 
 class TestApiWorkflowDatasets(MockedApiWorkflowSetup):
-
     def setUp(self) -> None:
         self.api_workflow_client = MockedApiWorkflowClient(token="token_xyz")
 
     def test_share_empty_dataset(self):
-        self.api_workflow_client.share_dataset_only_with(dataset_id="some-dataset-id", user_emails=[])
+        self.api_workflow_client.share_dataset_only_with(
+            dataset_id="some-dataset-id", user_emails=[]
+        )
 
     def test_share_dataset(self):
-        self.api_workflow_client.share_dataset_only_with(dataset_id="some-dataset-id", user_emails=["someone@something.com"])
+        self.api_workflow_client.share_dataset_only_with(
+            dataset_id="some-dataset-id", user_emails=["someone@something.com"]
+        )
 
     def test_get_shared_users(self):
-        user_emails = self.api_workflow_client.get_shared_users(dataset_id="some-dataset-id")
+        user_emails = self.api_workflow_client.get_shared_users(
+            dataset_id="some-dataset-id"
+        )
         assert user_emails == ["user1@gmail.com", "user2@something.com"]
diff --git a/tests/api_workflow/test_api_workflow_compute_worker.py b/tests/api_workflow/test_api_workflow_compute_worker.py
index 603ab8fff..0a7bc2368 100644
--- a/tests/api_workflow/test_api_workflow_compute_worker.py
+++ b/tests/api_workflow/test_api_workflow_compute_worker.py
@@ -1,47 +1,47 @@
 import json
 import random
+from typing import Any, List
+from unittest import mock
 from unittest.mock import MagicMock
 
 import pytest
 from pytest_mock import MockerFixture
-from typing import Any, List
-from unittest import mock
 
+from lightly.api import ApiWorkflowClient, api_workflow_compute_worker
 from lightly.api.api_workflow_compute_worker import (
     STATE_SCHEDULED_ID_NOT_FOUND,
     ComputeWorkerRunInfo,
+    InvalidConfigurationError,
     _config_to_camel_case,
     _snake_to_camel_case,
     _validate_config,
-    InvalidConfigurationError,
 )
 from lightly.openapi_generated.swagger_client import (
-    SelectionConfig,
-    SelectionConfigEntry,
-    SelectionInputType,
-    SelectionStrategyType,
     ApiClient,
     DockerApi,
-    SelectionConfigEntryInput,
-    SelectionStrategyThresholdOperation,
-    SelectionInputPredictionsName,
-    SelectionConfigEntryStrategy,
-    DockerWorkerConfig,
-    DockerWorkerType,
     DockerRunData,
     DockerRunScheduledData,
     DockerRunScheduledPriority,
     DockerRunScheduledState,
     DockerRunState,
+    DockerWorkerConfig,
     DockerWorkerConfigV2Docker,
     DockerWorkerConfigV2DockerCorruptnessCheck,
     DockerWorkerConfigV2Lightly,
     DockerWorkerConfigV2LightlyLoader,
+    DockerWorkerType,
+    SelectionConfig,
+    SelectionConfigEntry,
+    SelectionConfigEntryInput,
+    SelectionConfigEntryStrategy,
+    SelectionInputPredictionsName,
+    SelectionInputType,
+    SelectionStrategyThresholdOperation,
+    SelectionStrategyType,
     TagData,
 )
 from lightly.openapi_generated.swagger_client.rest import ApiException
 from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
-from lightly.api import api_workflow_compute_worker, ApiWorkflowClient
 
 
 class TestApiWorkflowComputeWorker(MockedApiWorkflowSetup):
@@ -312,7 +312,6 @@ def test_selection_config_from_dict__typo() -> None:
 
 
 def test_get_scheduled_run_by_id() -> None:
-
     scheduled_runs = [
         DockerRunScheduledData(
             id=f"id_{i}",
@@ -341,7 +340,6 @@ def test_get_scheduled_run_by_id() -> None:
 
 
 def test_get_scheduled_run_by_id_not_found() -> None:
-
     scheduled_runs = [
         DockerRunScheduledData(
             id=f"id_{i}",
@@ -373,7 +371,6 @@ def test_get_scheduled_run_by_id_not_found() -> None:
 
 
 def test_get_compute_worker_state_and_message_OPEN() -> None:
-
     scheduled_run = DockerRunScheduledData(
         id=f"id_2",
         dataset_id="dataset_id",
@@ -449,7 +446,6 @@ def test_get_compute_worker_state_and_message_docker_state() -> None:
 
 
 def test_compute_worker_run_info_generator(mocker) -> None:
-
     states = [f"state_{i}" for i in range(7)]
     states[-1] = DockerRunState.COMPLETED
 
@@ -666,7 +662,6 @@ def test__config_to_camel_case() -> None:
 
 
 def test__snake_to_camel_case() -> None:
-
     assert _snake_to_camel_case("lorem") == "lorem"
     assert _snake_to_camel_case("lorem_ipsum") == "loremIpsum"
     assert _snake_to_camel_case("lorem_ipsum_dolor") == "loremIpsumDolor"
@@ -674,12 +669,11 @@ def test__snake_to_camel_case() -> None:
 
 
 def test__validate_config__docker(mocker: MockerFixture) -> None:
-
     obj = DockerWorkerConfigV2Docker(
         enable_training=False,
         corruptness_check=DockerWorkerConfigV2DockerCorruptnessCheck(
             corruption_threshold=0.1,
-        )
+        ),
     )
     _validate_config(
         cfg={
@@ -688,22 +682,21 @@ def test__validate_config__docker(mocker: MockerFixture) -> None:
                 "corruption_threshold": 0.1,
             },
         },
-        obj=obj
+        obj=obj,
     )
 
 
def test__validate_config__docker_typo(mocker: MockerFixture) -> None:
-
     obj = DockerWorkerConfigV2Docker(
         enable_training=False,
         corruptness_check=DockerWorkerConfigV2DockerCorruptnessCheck(
             corruption_threshold=0.1,
-        )
+        ),
     )
 
     with pytest.raises(
         InvalidConfigurationError,
-        match="Option 'enable_trainingx' does not exist! Did you mean 'enable_training'?"
+        match="Option 'enable_trainingx' does not exist! Did you mean 'enable_training'?",
     ):
         _validate_config(
             cfg={
@@ -712,21 +705,21 @@ def test__validate_config__docker_typo(mocker: MockerFixture) -> None:
                 "corruption_threshold": 0.1,
             },
         },
-            obj=obj
+            obj=obj,
         )
 
 
-def test__validate_config__docker_typo_nested(mocker: MockerFixture) -> None:
+def test__validate_config__docker_typo_nested(mocker: MockerFixture) -> None:
     obj = DockerWorkerConfigV2Docker(
         enable_training=False,
         corruptness_check=DockerWorkerConfigV2DockerCorruptnessCheck(
             corruption_threshold=0.1,
-        )
+        ),
     )
 
     with pytest.raises(
         InvalidConfigurationError,
-        match="Option 'corruption_thresholdx' does not exist! Did you mean 'corruption_threshold'?"
+        match="Option 'corruption_thresholdx' does not exist! Did you mean 'corruption_threshold'?",
     ):
         _validate_config(
             cfg={
@@ -735,12 +728,11 @@ def test__validate_config__docker_typo_nested(mocker: MockerFixture) -> None:
                 "corruption_thresholdx": 0.1,
             },
         },
-            obj=obj
+            obj=obj,
         )
 
 
 def test__validate_config__lightly(mocker: MockerFixture) -> None:
-
     obj = DockerWorkerConfigV2Lightly(
         loader=DockerWorkerConfigV2LightlyLoader(
             num_workers=-1,
@@ -756,12 +748,11 @@ def test__validate_config__lightly(mocker: MockerFixture) -> None:
                 "shuffle": True,
             },
         },
-        obj=obj
+        obj=obj,
     )
 
 
 def test__validate_config__lightly_typo(mocker: MockerFixture) -> None:
-
     obj = DockerWorkerConfigV2Lightly(
         loader=DockerWorkerConfigV2LightlyLoader(
             num_workers=-1,
@@ -771,7 +762,7 @@ def test__validate_config__lightly_typo(mocker: MockerFixture) -> None:
     )
     with pytest.raises(
         InvalidConfigurationError,
-        match="Option 'loaderx' does not exist! Did you mean 'loader'?"
+        match="Option 'loaderx' does not exist! Did you mean 'loader'?",
     ):
         _validate_config(
             cfg={
@@ -781,12 +772,11 @@ def test__validate_config__lightly_typo(mocker: MockerFixture) -> None:
                 "shuffle": True,
             },
         },
-            obj=obj
+            obj=obj,
         )
 
 
 def test__validate_config__lightly_typo_nested(mocker: MockerFixture) -> None:
-
     obj = DockerWorkerConfigV2Lightly(
         loader=DockerWorkerConfigV2LightlyLoader(
             num_workers=-1,
@@ -796,7 +786,7 @@ def test__validate_config__lightly_typo_nested(mocker: MockerFixture) -> None:
     )
     with pytest.raises(
         InvalidConfigurationError,
-        match="Option 'num_workersx' does not exist! Did you mean 'num_workers'?"
+        match="Option 'num_workersx' does not exist! Did you mean 'num_workers'?",
     ):
         _validate_config(
             cfg={
@@ -806,16 +796,12 @@ def test__validate_config__lightly_typo_nested(mocker: MockerFixture) -> None:
                 "shuffle": True,
             },
         },
-            obj=obj
+            obj=obj,
         )
 
 
 def test__validate_config__raises_type_error(mocker: MockerFixture) -> None:
     with pytest.raises(
-        TypeError,
-        match="of argument 'obj' has not attribute 'swagger_types'"
+        TypeError, match="of argument 'obj' has not attribute 'swagger_types'"
     ):
-        _validate_config(
-            cfg={},
-            obj=mocker.MagicMock()
-        )
+        _validate_config(cfg={}, obj=mocker.MagicMock())
diff --git a/tests/api_workflow/test_api_workflow_datasets.py b/tests/api_workflow/test_api_workflow_datasets.py
index 0881d35ba..5825eefd7 100644
--- a/tests/api_workflow/test_api_workflow_datasets.py
+++ b/tests/api_workflow/test_api_workflow_datasets.py
@@ -3,9 +3,9 @@
 from lightly.api import ApiWorkflowClient
 from lightly.openapi_generated.swagger_client import (
     Creator,
+    DatasetCreateRequest,
     DatasetsApi,
     DatasetType,
-    DatasetCreateRequest,
 )
 from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
diff --git a/tests/api_workflow/test_api_workflow_datasources.py b/tests/api_workflow/test_api_workflow_datasources.py
index 3595d3e3b..8d92681ac 100644
--- a/tests/api_workflow/test_api_workflow_datasources.py
+++ b/tests/api_workflow/test_api_workflow_datasources.py
@@ -1,13 +1,13 @@
+from collections import defaultdict
 from unittest import mock
 
-import tqdm
 import pytest
+import tqdm
 
-from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
 from lightly.openapi_generated.swagger_client.models.datasource_raw_samples_data_row import (
     DatasourceRawSamplesDataRow,
 )
-from collections import defaultdict
+from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
 
 
 class TestApiWorkflowDatasources(MockedApiWorkflowSetup):
diff --git a/tests/api_workflow/test_api_workflow_download_dataset.py b/tests/api_workflow/test_api_workflow_download_dataset.py
index 40fba4e68..7a144a585 100644
--- a/tests/api_workflow/test_api_workflow_download_dataset.py
+++ b/tests/api_workflow/test_api_workflow_download_dataset.py
@@ -5,37 +5,50 @@
 import PIL
 
 import lightly
-from lightly.api import api_workflow_download_dataset
-from lightly.api import download
+from lightly.api import api_workflow_download_dataset, download
 from lightly.openapi_generated.swagger_client import DatasetData, DatasetEmbeddingData
 from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
 
-
 class TestApiWorkflowDownloadDataset(MockedApiWorkflowSetup):
     def setUp(self) -> None:
-        MockedApiWorkflowSetup.setUp(self, dataset_id='dataset_0_id')
+        MockedApiWorkflowSetup.setUp(self, dataset_id="dataset_0_id")
         self.api_workflow_client._tags_api.no_tags = 3
 
     def test_download_non_existing_tag(self):
         with self.assertRaises(ValueError):
-            self.api_workflow_client.download_dataset('path/to/dir', tag_name='this_is_not_a_real_tag_name')
+            self.api_workflow_client.download_dataset(
+                "path/to/dir", tag_name="this_is_not_a_real_tag_name"
+            )
 
     def test_download_thumbnails(self):
         def get_thumbnail_dataset_by_id(*args):
-            return DatasetData(name=f'dataset', id='dataset_id', last_modified_at=0,
-                               type='thumbnails', size_in_bytes=-1, n_samples=-1, created_at=-1)
-        self.api_workflow_client._datasets_api.get_dataset_by_id = get_thumbnail_dataset_by_id
+            return DatasetData(
+                name="dataset",
+                id="dataset_id",
+                last_modified_at=0,
+                type="thumbnails",
+                size_in_bytes=-1,
+                n_samples=-1,
+                created_at=-1,
+            )
+
+        self.api_workflow_client._datasets_api.get_dataset_by_id = (
+            get_thumbnail_dataset_by_id
+        )
         with self.assertRaises(ValueError):
-            self.api_workflow_client.download_dataset('path/to/dir')
+            self.api_workflow_client.download_dataset("path/to/dir")
 
     def test_download_dataset(self):
         def my_func(read_url):
-            return PIL.Image.fromarray(np.zeros((32, 32))).convert('RGB')
-        #mock_get_image_from_readurl.return_value = PIL.Image.fromarray(np.zeros((32, 32)))
+            return PIL.Image.fromarray(np.zeros((32, 32))).convert("RGB")
+
+        # mock_get_image_from_readurl.return_value = PIL.Image.fromarray(np.zeros((32, 32)))
         lightly.api.api_workflow_download_dataset._get_image_from_read_url = my_func
-        self.api_workflow_client.download_dataset('path-to-dir-remove-me', tag_name='initial-tag')
-        shutil.rmtree('path-to-dir-remove-me')
+        self.api_workflow_client.download_dataset(
+            "path-to-dir-remove-me", tag_name="initial-tag"
+        )
+        shutil.rmtree("path-to-dir-remove-me")
 
     def test_get_embedding_data_by_name(self) -> None:
         embedding_0 = DatasetEmbeddingData(
@@ -53,9 +66,8 @@ def test_get_embedding_data_by_name(self) -> None:
         with mock.patch.object(
             self.api_workflow_client._embeddings_api,
             "get_embeddings_by_dataset_id",
-            return_value=[embedding_0, embedding_1]
+            return_value=[embedding_0, embedding_1],
         ) as mock_get_embeddings_by_dataset_id:
-
             embedding = self.api_workflow_client.get_embedding_data_by_name(
                 name="embedding_0"
             )
@@ -72,33 +84,26 @@ def test_get_embedding_data_by_name__no_embedding_with_name(self) -> None:
             is_processed=False,
         )
         with mock.patch.object(
-            self.api_workflow_client._embeddings_api,
-            "get_embeddings_by_dataset_id",
-            return_value=[embedding_0]
-        ) as mock_get_embeddings_by_dataset_id, \
-            self.assertRaisesRegex(
-                ValueError,
-                "There are no embeddings with name 'other_embedding' for dataset with id 'dataset_0_id'."
-            ):
-
-            self.api_workflow_client.get_embedding_data_by_name(
-                name="other_embedding"
-            )
+            self.api_workflow_client._embeddings_api,
+            "get_embeddings_by_dataset_id",
+            return_value=[embedding_0],
+        ) as mock_get_embeddings_by_dataset_id, self.assertRaisesRegex(
+            ValueError,
+            "There are no embeddings with name 'other_embedding' for dataset with id 'dataset_0_id'.",
+        ):
+            self.api_workflow_client.get_embedding_data_by_name(name="other_embedding")
 
         mock_get_embeddings_by_dataset_id.assert_called_once_with(
             dataset_id="dataset_0_id",
         )
 
     def test_download_embeddings_csv_by_id(self) -> None:
         with mock.patch.object(
-            self.api_workflow_client._embeddings_api,
-            "get_embeddings_csv_read_url_by_id",
-            return_value="read_url",
-        ) as mock_get_embeddings_csv_read_url_by_id, \
-            mock.patch.object(
-                download,
-                "download_and_write_file"
-            ) as mock_download:
-
+            self.api_workflow_client._embeddings_api,
+            "get_embeddings_csv_read_url_by_id",
+            return_value="read_url",
+        ) as mock_get_embeddings_csv_read_url_by_id, mock.patch.object(
+            download, "download_and_write_file"
+        ) as mock_download:
             self.api_workflow_client.download_embeddings_csv_by_id(
                 embedding_id="embedding_id",
                 output_path="embeddings.csv",
@@ -114,21 +119,23 @@ def test_download_embeddings_csv_by_id(self) -> None:
 
     def test_download_embeddings_csv(self) -> None:
         with mock.patch.object(
-            self.api_workflow_client,
-            "get_all_embedding_data",
-            return_value=[DatasetEmbeddingData(
+            self.api_workflow_client,
+            "get_all_embedding_data",
+            return_value=[
+                DatasetEmbeddingData(
                     id="0",
                     name="default_20221209_10h45m49s",
                     created_at=0,
                     is_processed=False,
-            )]
-        ) as mock_get_all_embedding_data, \
-            mock.patch.object(
-                self.api_workflow_client,
-                "download_embeddings_csv_by_id",
-            ) as mock_download_embeddings_csv_by_id:
-
-            self.api_workflow_client.download_embeddings_csv(output_path="embeddings.csv")
+                )
+            ],
+        ) as mock_get_all_embedding_data, mock.patch.object(
+            self.api_workflow_client,
+            "download_embeddings_csv_by_id",
+        ) as mock_download_embeddings_csv_by_id:
+            self.api_workflow_client.download_embeddings_csv(
+                output_path="embeddings.csv"
+            )
             mock_get_all_embedding_data.assert_called_once()
             mock_download_embeddings_csv_by_id.assert_called_once_with(
                 embedding_id="0",
@@ -137,36 +144,48 @@ def test_download_embeddings_csv(self) -> None:
 
     def test_download_embeddings_csv__no_default_embedding(self) -> None:
         with mock.patch.object(
-            self.api_workflow_client,
-            "get_all_embedding_data",
-            return_value=[],
-        ) as mock_get_all_embedding_data, \
-            self.assertRaisesRegex(
-                RuntimeError,
-                "Could not find embeddings for dataset with id 'dataset_0_id'."
-            ):
-
-            self.api_workflow_client.download_embeddings_csv(output_path="embeddings.csv")
+            self.api_workflow_client,
+            "get_all_embedding_data",
+            return_value=[],
+        ) as mock_get_all_embedding_data, self.assertRaisesRegex(
+            RuntimeError,
+            "Could not find embeddings for dataset with id 'dataset_0_id'.",
+        ):
+            self.api_workflow_client.download_embeddings_csv(
+                output_path="embeddings.csv"
+            )
 
         mock_get_all_embedding_data.assert_called_once()
 
-
     def test_export_label_box_data_rows_by_tag_name(self):
-        rows = self.api_workflow_client.export_label_box_data_rows_by_tag_name('initial-tag')
+        rows = self.api_workflow_client.export_label_box_data_rows_by_tag_name(
+            "initial-tag"
+        )
         self.assertIsNotNone(rows)
         self.assertTrue(all(isinstance(row, dict) for row in rows))
 
     def test_export_label_studio_tasks_by_tag_name(self):
-        tasks = self.api_workflow_client.export_label_studio_tasks_by_tag_name('initial-tag')
+        tasks = self.api_workflow_client.export_label_studio_tasks_by_tag_name(
+            "initial-tag"
+        )
         self.assertIsNotNone(tasks)
         self.assertTrue(all(isinstance(task, dict) for task in tasks))
 
     def test_export_tag_to_basic_filenames_and_read_urls(self):
-        filenames_and_read_urls = self.api_workflow_client.export_filenames_and_read_urls_by_tag_name('initial-tag')
+        filenames_and_read_urls = (
+            self.api_workflow_client.export_filenames_and_read_urls_by_tag_name(
+                "initial-tag"
+            )
+        )
         self.assertIsNotNone(filenames_and_read_urls)
-        self.assertTrue(all(isinstance(filenames_and_read_url, dict) for filenames_and_read_url in filenames_and_read_urls))
+        self.assertTrue(
+            all(
+                isinstance(filenames_and_read_url, dict)
+                for filenames_and_read_url in filenames_and_read_urls
+            )
+        )
 
     def test_export_filenames_by_tag_name(self):
-        filenames = self.api_workflow_client.export_filenames_by_tag_name('initial-tag')
+        filenames = self.api_workflow_client.export_filenames_by_tag_name("initial-tag")
         self.assertIsNotNone(filenames)
         self.assertTrue(isinstance(filenames, str))
@@ -196,6 +215,7 @@ def test__get_latest_default_embedding_data() -> None:
     )
     assert embedding == embedding_1
 
+
 def test__get_latest_default_embedding_data__no_default_embedding() -> None:
     custom_embedding = DatasetEmbeddingData(
         id="0",
diff --git a/tests/api_workflow/test_api_workflow_predictions.py b/tests/api_workflow/test_api_workflow_predictions.py
index 87dd93271..235e5d1f6 100644
--- a/tests/api_workflow/test_api_workflow_predictions.py
+++ b/tests/api_workflow/test_api_workflow_predictions.py
@@ -3,10 +3,10 @@
 from lightly.api import ApiWorkflowClient
 from lightly.api.prediction_singletons import PredictionSingletonClassificationRepr
 from lightly.openapi_generated.swagger_client import (
+    PredictionsApi,
     PredictionTaskSchema,
-    TaskType,
     PredictionTaskSchemaCategory,
-    PredictionsApi,
+    TaskType,
 )
diff --git a/tests/api_workflow/test_api_workflow_selection.py b/tests/api_workflow/test_api_workflow_selection.py
index 0d65e98e5..d5526d783 100644
--- a/tests/api_workflow/test_api_workflow_selection.py
+++ b/tests/api_workflow/test_api_workflow_selection.py
@@ -1,11 +1,12 @@
-from lightly.active_learning.config.selection_config import SelectionConfig, \
-    SamplingConfig
-from lightly.openapi_generated.swagger_client import TagData, SamplingMethod
+from lightly.active_learning.config.selection_config import (
+    SamplingConfig,
+    SelectionConfig,
+)
+from lightly.openapi_generated.swagger_client import SamplingMethod, TagData
 from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
 
 
TestApiWorkflowSelection(MockedApiWorkflowSetup): - def test_sampling_deprecated(self): self.api_workflow_client.embedding_id = "embedding_id_xyz" @@ -13,7 +14,9 @@ def test_sampling_deprecated(self): sampling_config = SamplingConfig(SamplingMethod.CORESET, n_samples=32) with self.assertWarns(PendingDeprecationWarning): - new_tag_data = self.api_workflow_client.sampling(selection_config=sampling_config) + new_tag_data = self.api_workflow_client.sampling( + selection_config=sampling_config + ) assert isinstance(new_tag_data, TagData) def test_selection(self): @@ -21,13 +24,17 @@ def test_selection(self): selection_config = SelectionConfig() - new_tag_data = self.api_workflow_client.selection(selection_config=selection_config) + new_tag_data = self.api_workflow_client.selection( + selection_config=selection_config + ) assert isinstance(new_tag_data, TagData) def test_runtime_error_on_existing_tag_name(self): self.api_workflow_client.embedding_id = "embedding_id_xyz" - selection_config = SelectionConfig(name='initial-tag') + selection_config = SelectionConfig(name="initial-tag") with self.assertRaises(RuntimeError): - new_tag_data = self.api_workflow_client.selection(selection_config=selection_config) + new_tag_data = self.api_workflow_client.selection( + selection_config=selection_config + ) diff --git a/tests/api_workflow/test_api_workflow_tags.py b/tests/api_workflow/test_api_workflow_tags.py index 722ea177a..98d709177 100644 --- a/tests/api_workflow/test_api_workflow_tags.py +++ b/tests/api_workflow/test_api_workflow_tags.py @@ -11,11 +11,13 @@ from lightly.active_learning.scorers.classification import ScorerClassification from lightly.openapi_generated.swagger_client import SamplingMethod from lightly.openapi_generated.swagger_client.models.tag_data import TagData -from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowClient, MockedApiWorkflowSetup +from tests.api_workflow.mocked_api_workflow_client import ( + MockedApiWorkflowClient, + MockedApiWorkflowSetup, +) class TestApiWorkflowTags(MockedApiWorkflowSetup): - def setUp(self) -> None: lightly.api.api_workflow_client.__version__ = lightly.__version__ warnings.filterwarnings("ignore", category=UserWarning) @@ -31,10 +33,10 @@ def tearDown(self) -> None: def test_get_all_tags(self): self.api_workflow_client.get_all_tags() - + def test_get_tag_name(self): self.api_workflow_client.get_tag_by_name(tag_name=self.valid_tag_name) - + def test_get_tag_name_nonexisting(self): with self.assertRaises(ValueError): self.api_workflow_client.get_tag_by_name(tag_name=self.invalid_tag_name) @@ -43,35 +45,48 @@ def test_get_tag_id(self): self.api_workflow_client.get_tag_by_id(tag_id=self.valid_tag_id) def test_get_filenames_in_tag(self): - tag_data = self.api_workflow_client.get_tag_by_name(tag_name=self.valid_tag_name) + tag_data = self.api_workflow_client.get_tag_by_name( + tag_name=self.valid_tag_name + ) self.api_workflow_client.get_filenames_in_tag(tag_data) def test_get_filenames_in_tag_with_filenames(self): - tag_data = self.api_workflow_client.get_tag_by_name(tag_name=self.valid_tag_name) + tag_data = self.api_workflow_client.get_tag_by_name( + tag_name=self.valid_tag_name + ) filenames = self.api_workflow_client.get_filenames() self.api_workflow_client.get_filenames_in_tag(tag_data, filenames) def test_get_filenames_in_tag_exclude_parent(self): - tag_data = self.api_workflow_client.get_tag_by_name(tag_name=self.valid_tag_name) + tag_data = self.api_workflow_client.get_tag_by_name( + tag_name=self.valid_tag_name + ) 
self.api_workflow_client.get_filenames_in_tag(tag_data, exclude_parent_tag=True) def test_get_filenames_in_tag_with_filenames_exclude_parent(self): - tag_data = self.api_workflow_client.get_tag_by_name(tag_name=self.valid_tag_name) + tag_data = self.api_workflow_client.get_tag_by_name( + tag_name=self.valid_tag_name + ) filenames = self.api_workflow_client.get_filenames() - self.api_workflow_client.get_filenames_in_tag(tag_data, filenames, exclude_parent_tag=True) + self.api_workflow_client.get_filenames_in_tag( + tag_data, filenames, exclude_parent_tag=True + ) def test_create_tag_from_filenames(self): filenames_server = self.api_workflow_client.get_filenames() filenames_new_tag = filenames_server[:10][::3] - self.api_workflow_client.create_tag_from_filenames(filenames_new_tag, new_tag_name="funny_new_tag") + self.api_workflow_client.create_tag_from_filenames( + filenames_new_tag, new_tag_name="funny_new_tag" + ) def test_create_tag_from_filenames(self): filenames_server = self.api_workflow_client.get_filenames() filenames_new_tag = filenames_server[:10][::3] - filenames_new_tag[0] = 'some-random-non-existing-filename.jpg' + filenames_new_tag[0] = "some-random-non-existing-filename.jpg" with self.assertRaises(RuntimeError): - self.api_workflow_client.create_tag_from_filenames(filenames_new_tag, new_tag_name="funny_new_tag") + self.api_workflow_client.create_tag_from_filenames( + filenames_new_tag, new_tag_name="funny_new_tag" + ) def test_delete_tag_by_id(self): self.api_workflow_client.delete_tag_by_id(self.valid_tag_id) - diff --git a/tests/api_workflow/test_api_workflow_upload_custom_metadata.py b/tests/api_workflow/test_api_workflow_upload_custom_metadata.py index 4c9b28958..96b7d8367 100644 --- a/tests/api_workflow/test_api_workflow_upload_custom_metadata.py +++ b/tests/api_workflow/test_api_workflow_upload_custom_metadata.py @@ -1,38 +1,39 @@ import copy import json import os +import pathlib import random import tempfile -import pathlib from typing import List +import cv2 import numpy as np -from lightly.openapi_generated.swagger_client.models.sample_data_modes import SampleDataModes import torchvision -from lightly.api.api_workflow_upload_metadata import \ - InvalidCustomMetadataWarning +from lightly.api.api_workflow_upload_metadata import InvalidCustomMetadataWarning from lightly.api.utils import MAXIMUM_FILENAME_LENGTH from lightly.data.dataset import LightlyDataset from lightly.openapi_generated.swagger_client import SampleData +from lightly.openapi_generated.swagger_client.models.sample_data_modes import ( + SampleDataModes, +) from lightly.utils.io import COCO_ANNOTATION_KEYS - -from tests.api_workflow.mocked_api_workflow_client import \ - MockedApiWorkflowSetup - -import cv2 +from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup class TestApiWorkflowUploadCustomMetadata(MockedApiWorkflowSetup): - def create_fake_dataset(self, n_data: int = 10, sample_names=None): - self.dataset = torchvision.datasets.FakeData(size=n_data, - image_size=(3, 32, 32)) + self.dataset = torchvision.datasets.FakeData( + size=n_data, image_size=(3, 32, 32) + ) self.folder_path = tempfile.mkdtemp() - image_extension = '.jpg' - sample_names = sample_names if sample_names is not None else [ - f'img_{i}{image_extension}' for i in range(n_data)] + image_extension = ".jpg" + sample_names = ( + sample_names + if sample_names is not None + else [f"img_{i}{image_extension}" for i in range(n_data)] + ) for sample_idx in range(n_data): data = self.dataset[sample_idx] sample_name = 
sample_names[sample_idx] @@ -40,10 +41,13 @@ def create_fake_dataset(self, n_data: int = 10, sample_names=None): data[0].save(path) coco_json = dict() - coco_json[COCO_ANNOTATION_KEYS.images] = [{'id': i, 'file_name': fname} for i, fname in - enumerate(sample_names)] - coco_json[COCO_ANNOTATION_KEYS.custom_metadata] = [{'id': i, 'image_id': i, 'custom_metadata': 0} - for i, _ in enumerate(sample_names)] + coco_json[COCO_ANNOTATION_KEYS.images] = [ + {"id": i, "file_name": fname} for i, fname in enumerate(sample_names) + ] + coco_json[COCO_ANNOTATION_KEYS.custom_metadata] = [ + {"id": i, "image_id": i, "custom_metadata": 0} + for i, _ in enumerate(sample_names) + ] self.custom_metadata_file = tempfile.NamedTemporaryFile(mode="w+") json.dump(coco_json, self.custom_metadata_file) @@ -51,27 +55,31 @@ def create_fake_dataset(self, n_data: int = 10, sample_names=None): def test_upload_custom_metadata_one_step(self): self.create_fake_dataset() - with open(self.custom_metadata_file.name, 'r') as f: + with open(self.custom_metadata_file.name, "r") as f: custom_metadata = json.load(f) - self.api_workflow_client.upload_dataset(input=self.folder_path, custom_metadata=custom_metadata) + self.api_workflow_client.upload_dataset( + input=self.folder_path, custom_metadata=custom_metadata + ) def test_upload_custom_metadata_two_steps_verbose(self): self.create_fake_dataset() self.api_workflow_client.upload_dataset(input=self.folder_path) - with open(self.custom_metadata_file.name, 'r') as f: + with open(self.custom_metadata_file.name, "r") as f: custom_metadata = json.load(f) - self.api_workflow_client.upload_custom_metadata(custom_metadata, verbose=True) + self.api_workflow_client.upload_custom_metadata( + custom_metadata, verbose=True + ) def test_upload_custom_metadata_two_steps(self): self.create_fake_dataset() self.api_workflow_client.upload_dataset(input=self.folder_path) - with open(self.custom_metadata_file.name, 'r') as f: + with open(self.custom_metadata_file.name, "r") as f: custom_metadata = json.load(f) self.api_workflow_client.upload_custom_metadata(custom_metadata) def test_upload_custom_metadata_before_uploading_samples(self): self.create_fake_dataset() - with open(self.custom_metadata_file.name, 'r') as f: + with open(self.custom_metadata_file.name, "r") as f: custom_metadata = json.load(f) with self.assertWarns(InvalidCustomMetadataWarning): self.api_workflow_client.upload_custom_metadata(custom_metadata) @@ -79,20 +87,18 @@ def test_upload_custom_metadata_before_uploading_samples(self): def test_upload_custom_metadata_with_append(self): self.create_fake_dataset() self.api_workflow_client.upload_dataset(input=self.folder_path) - with open(self.custom_metadata_file.name, 'r') as f: + with open(self.custom_metadata_file.name, "r") as f: custom_metadata = json.load(f) - custom_metadata['metadata'] = custom_metadata['metadata'][:3] + custom_metadata["metadata"] = custom_metadata["metadata"][:3] self.api_workflow_client.upload_custom_metadata(custom_metadata) - def subtest_upload_custom_metadata( - self, - image_ids_images: List[int], - image_ids_annotations: List[int], - filenames_server: List[str] + self, + image_ids_images: List[int], + image_ids_annotations: List[int], + filenames_server: List[str], ): - - def get_samples_partial_by_dataset_id(*args, **kwargs)-> List[SampleDataModes]: + def get_samples_partial_by_dataset_id(*args, **kwargs) -> List[SampleDataModes]: samples = [ SampleDataModes( id="dfd", @@ -101,52 +107,54 @@ def get_samples_partial_by_dataset_id(*args, **kwargs)-> 
List[SampleDataModes]: for filename in filenames_server ] return samples - self.api_workflow_client._samples_api.get_samples_partial_by_dataset_id = get_samples_partial_by_dataset_id - filenames_metadata = [f"img_{id}.jpg" for id in image_ids_annotations] - with self.subTest(image_ids_images=image_ids_images, - image_ids_annotations=image_ids_annotations, - filenames_server=filenames_server + self.api_workflow_client._samples_api.get_samples_partial_by_dataset_id = ( + get_samples_partial_by_dataset_id + ) + filenames_metadata = [f"img_{id}.jpg" for id in image_ids_annotations] + with self.subTest( + image_ids_images=image_ids_images, + image_ids_annotations=image_ids_annotations, + filenames_server=filenames_server, ): custom_metadata = { COCO_ANNOTATION_KEYS.images: [ - { + { COCO_ANNOTATION_KEYS.images_id: id, - COCO_ANNOTATION_KEYS.images_filename: filename} + COCO_ANNOTATION_KEYS.images_filename: filename, + } for id, filename in zip(image_ids_images, filenames_metadata) ], COCO_ANNOTATION_KEYS.custom_metadata: [ { COCO_ANNOTATION_KEYS.custom_metadata_image_id: id, - "any_key": "any_value" + "any_key": "any_value", } for id in image_ids_annotations ], } # The annotations must only have image_ids that are also in the images. - custom_metadata_malformatted = \ + custom_metadata_malformatted = ( len(set(image_ids_annotations) - set(image_ids_images)) > 0 + ) # Only custom metadata whose filename is on the server can be uploaded. - metatadata_without_filenames_on_server = \ + metadata_without_filenames_on_server = ( len(set(filenames_metadata) - set(filenames_server)) > 0 + ) - if metatadata_without_filenames_on_server \ - or custom_metadata_malformatted: + if metadata_without_filenames_on_server or custom_metadata_malformatted: with self.assertWarns(InvalidCustomMetadataWarning): - self.api_workflow_client.upload_custom_metadata( - custom_metadata - ) + self.api_workflow_client.upload_custom_metadata(custom_metadata) else: - self.api_workflow_client.upload_custom_metadata( - custom_metadata - ) - + self.api_workflow_client.upload_custom_metadata(custom_metadata) def test_upload_custom_metadata(self): potential_image_ids_images = [[0, 1, 2], [-1, 1], list(range(10)), [-3]] potential_image_ids_annotations = potential_image_ids_images - potential_filenames_server = [[f"img_{id}.jpg" for id in ids] for ids in potential_image_ids_images] + potential_filenames_server = [ + [f"img_{id}.jpg" for id in ids] for ids in potential_image_ids_images + ] self.create_fake_dataset() self.api_workflow_client.upload_dataset(input=self.folder_path) @@ -155,10 +163,5 @@ def test_upload_custom_metadata(self): for image_ids_annotations in potential_image_ids_annotations: for filenames_server in potential_filenames_server: self.subtest_upload_custom_metadata( - image_ids_images, - image_ids_annotations, - filenames_server + image_ids_images, image_ids_annotations, filenames_server ) - - - diff --git a/tests/api_workflow/test_api_workflow_upload_dataset.py b/tests/api_workflow/test_api_workflow_upload_dataset.py index cab49b432..3a8f5da45 100644 --- a/tests/api_workflow/test_api_workflow_upload_dataset.py +++ b/tests/api_workflow/test_api_workflow_upload_dataset.py @@ -1,21 +1,21 @@ import copy import os +import pathlib import random import tempfile -import pathlib +import warnings +import cv2 import numpy as np -from lightly.openapi_generated.swagger_client.models.sample_partial_mode import SamplePartialMode import torchvision from lightly.api.utils import MAXIMUM_FILENAME_LENGTH from lightly.data.dataset 
import LightlyDataset - +from lightly.openapi_generated.swagger_client.models.sample_partial_mode import ( + SamplePartialMode, +) from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup -import cv2 -import warnings - class TestApiWorkflowUploadDataset(MockedApiWorkflowSetup): def setUp(self) -> None: @@ -24,19 +24,23 @@ def setUp(self) -> None: self.n_data = 100 self.create_fake_dataset() self.api_workflow_client._tags_api.no_tags = 0 - def tearDown(self) -> None: warnings.resetwarnings() def create_fake_dataset(self, length_of_filepath: int = -1, sample_names=None): n_data = self.n_data if sample_names is None else len(sample_names) - self.dataset = torchvision.datasets.FakeData(size=n_data, - image_size=(3, 32, 32)) + self.dataset = torchvision.datasets.FakeData( + size=n_data, image_size=(3, 32, 32) + ) self.folder_path = tempfile.mkdtemp() - image_extension = '.jpg' - sample_names = sample_names if sample_names is not None else [f'img_{i}{image_extension}' for i in range(n_data)] + image_extension = ".jpg" + sample_names = ( + sample_names + if sample_names is not None + else [f"img_{i}{image_extension}" for i in range(n_data)] + ) for sample_idx in range(n_data): data = self.dataset[sample_idx] sample_name = sample_names[sample_idx] @@ -45,13 +49,17 @@ def create_fake_dataset(self, length_of_filepath: int = -1, sample_names=None): if length_of_filepath > len(path): assert path.endswith(image_extension) n_missing_chars = length_of_filepath - len(path) - path = path[:-len(image_extension)] + 'x' * n_missing_chars + image_extension + path = ( + path[: -len(image_extension)] + + "x" * n_missing_chars + + image_extension + ) data[0].save(path) def corrupt_fake_dataset(self): n_data = self.n_data - sample_names = [f'img_{i}.jpg' for i in range(n_data)] + sample_names = [f"img_{i}.jpg" for i in range(n_data)] for sample_name in sample_names: pathlib.Path(os.path.join(self.folder_path, sample_name)).touch() @@ -61,7 +69,9 @@ def test_upload_dataset_over_quota(self): def get_quota_reduced(): return str(quota) - self.api_workflow_client._quota_api.get_quota_maximum_dataset_size = get_quota_reduced + self.api_workflow_client._quota_api.get_quota_maximum_dataset_size = ( + get_quota_reduced + ) with self.assertRaises(ValueError): self.api_workflow_client.upload_dataset(input=self.folder_path) @@ -91,23 +101,30 @@ def test_filename_length_lower(self): self.create_fake_dataset(length_of_filepath=MAXIMUM_FILENAME_LENGTH - 1) self.api_workflow_client.upload_dataset(input=self.folder_path) - samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id(dataset_id="does not matter") + samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id( + dataset_id="does not matter" + ) self.assertEqual(self.n_data, len(samples)) def test_filename_length_upper(self): self.create_fake_dataset(length_of_filepath=MAXIMUM_FILENAME_LENGTH + 10) self.api_workflow_client.upload_dataset(input=self.folder_path) - samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id(dataset_id="does not matter") + samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id( + dataset_id="does not matter" + ) self.assertEqual(0, len(samples)) - def create_fake_video_dataset(self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3, extension='avi'): - + def create_fake_video_dataset( + self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3, extension="avi" + ): self.video_input_dir = tempfile.mkdtemp() - self.frames = 
(np.random.randn(n_frames_per_video, w, h, c) * 255).astype(np.uint8) + self.frames = (np.random.randn(n_frames_per_video, w, h, c) * 255).astype( + np.uint8 + ) for i in range(n_videos): - path = os.path.join(self.video_input_dir, f'output-{i}.{extension}') + path = os.path.join(self.video_input_dir, f"output-{i}.{extension}") out = cv2.VideoWriter(path, 0, 1, (w, h)) for frame in self.frames: out.write(frame) @@ -133,7 +150,9 @@ def failing_upload_sample(*args, **kwargs): self.api_workflow_client.upload_dataset(input=self.folder_path) # Ensure that not all samples were uploaded - samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id(dataset_id="does not matter") + samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id( + dataset_id="does not matter" + ) self.assertLess(len(samples), self.n_data) # Upload without failing uploads @@ -141,13 +160,13 @@ def failing_upload_sample(*args, **kwargs): self.api_workflow_client.upload_dataset(input=self.folder_path) # Ensure that now all samples were uploaded exactly once - samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id(dataset_id="does not matter") + samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id( + dataset_id="does not matter" + ) self.assertEqual(self.n_data, len(samples)) - def test_upload_dataset_twice_with_overlap(self): - - all_sample_names = [f'img_upload_twice_{i}.jpg' for i in range(10)] + all_sample_names = [f"img_upload_twice_{i}.jpg" for i in range(10)] # upload first part of the dataset (sample_0 - sample_6) self.create_fake_dataset(sample_names=all_sample_names[:7]) @@ -158,21 +177,24 @@ def test_upload_dataset_twice_with_overlap(self): self.api_workflow_client.upload_dataset(input=self.folder_path) # always returns all samples so dataset_id doesn't matter - samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id(dataset_id='') + samples = self.api_workflow_client._samples_api.get_samples_by_dataset_id( + dataset_id="" + ) # assert the filenames are the same self.assertListEqual( - sorted(all_sample_names), + sorted(all_sample_names), sorted([s.file_name for s in samples]), ) # assert partially getting the samples fileNames returns the same data - samples_file_names = self.api_workflow_client._samples_api.get_samples_partial_by_dataset_id( - dataset_id='', - mode=SamplePartialMode.FULL + samples_file_names = ( + self.api_workflow_client._samples_api.get_samples_partial_by_dataset_id( + dataset_id="", mode=SamplePartialMode.FULL + ) ) self.assertListEqual( - sorted(all_sample_names), + sorted(all_sample_names), sorted([s.file_name for s in samples_file_names]), ) diff --git a/tests/api_workflow/test_api_workflow_upload_embeddings.py b/tests/api_workflow/test_api_workflow_upload_embeddings.py index be8fc183d..7c38bcfa4 100644 --- a/tests/api_workflow/test_api_workflow_upload_embeddings.py +++ b/tests/api_workflow/test_api_workflow_upload_embeddings.py @@ -1,56 +1,62 @@ -from json import load -import os -import io import csv +import io +import os import random import tempfile +from json import load import numpy as np -from lightly.utils.io import save_embeddings, load_embeddings, INVALID_FILENAME_CHARACTERS import lightly -from tests.api_workflow.mocked_api_workflow_client import \ - MockedApiWorkflowSetup, N_FILES_ON_SERVER +from lightly.utils.io import ( + INVALID_FILENAME_CHARACTERS, + load_embeddings, + save_embeddings, +) +from tests.api_workflow.mocked_api_workflow_client import ( + N_FILES_ON_SERVER, + 
MockedApiWorkflowSetup, +) class TestApiWorkflowUploadEmbeddings(MockedApiWorkflowSetup): - - def create_fake_embeddings(self, - n_data, - n_data_start: int = 0, - n_dims: int = 32, - special_name_first_sample: bool = False, - special_char_in_first_filename: str = None): + def create_fake_embeddings( + self, + n_data, + n_data_start: int = 0, + n_dims: int = 32, + special_name_first_sample: bool = False, + special_char_in_first_filename: str = None, + ): # create fake embeddings self.folder_path = tempfile.mkdtemp() - self.path_to_embeddings = os.path.join( - self.folder_path, - 'embeddings.csv' - ) + self.path_to_embeddings = os.path.join(self.folder_path, "embeddings.csv") - self.sample_names = [f'img_{i}.jpg' for i in range(n_data_start, n_data_start + n_data)] + self.sample_names = [ + f"img_{i}.jpg" for i in range(n_data_start, n_data_start + n_data) + ] if special_name_first_sample: self.sample_names[0] = "bliblablub" if special_char_in_first_filename: - self.sample_names[0] = f'_{special_char_in_first_filename}' \ - f'{self.sample_names[0]}' + self.sample_names[0] = ( + f"_{special_char_in_first_filename}" f"{self.sample_names[0]}" + ) labels = [0] * len(self.sample_names) save_embeddings( self.path_to_embeddings, np.random.randn(n_data, n_dims), labels, - self.sample_names + self.sample_names, ) - - def t_ester_upload_embedding(self, - n_data, - n_dims: int = 32, - special_name_first_sample: bool = False, - special_char_in_first_filename: str = None, - name: str = "embedding_xyz" - ): - + def t_ester_upload_embedding( + self, + n_data, + n_dims: int = 32, + special_name_first_sample: bool = False, + special_char_in_first_filename: str = None, + name: str = "embedding_xyz", + ): self.create_fake_embeddings( n_data, n_dims=n_dims, @@ -59,13 +65,17 @@ def t_ester_upload_embedding(self, ) # perform the workflow to upload the embeddings - self.api_workflow_client.upload_embeddings(path_to_embeddings_csv=self.path_to_embeddings, name=name) + self.api_workflow_client.upload_embeddings( + path_to_embeddings_csv=self.path_to_embeddings, name=name + ) self.api_workflow_client.n_dims_embeddings_on_server = n_dims def test_upload_success(self): n_data = len(self.api_workflow_client._mappings_api.sample_names) self.t_ester_upload_embedding(n_data=n_data) - filepath_embeddings_sorted = os.path.join(self.folder_path, "embeddings_sorted.csv") + filepath_embeddings_sorted = os.path.join( + self.folder_path, "embeddings_sorted.csv" + ) self.assertFalse(os.path.isfile(filepath_embeddings_sorted)) def test_upload_wrong_length(self): @@ -84,20 +94,19 @@ def test_upload_comma_filenames(self): with self.subTest(msg=f"invalid_char: {invalid_char}"): with self.assertRaises(ValueError): self.t_ester_upload_embedding( - n_data=n_data, - special_char_in_first_filename=invalid_char) + n_data=n_data, special_char_in_first_filename=invalid_char + ) def test_set_embedding_id_default(self): self.api_workflow_client.set_embedding_id_to_latest() - self.assertEqual(self.api_workflow_client.embedding_id, 'embedding_id_xyz') - + self.assertEqual(self.api_workflow_client.embedding_id, "embedding_id_xyz") + def test_set_embedding_id_no_embeddings(self): self.api_workflow_client._embeddings_api.embeddings = [] with self.assertRaises(RuntimeError): self.api_workflow_client.set_embedding_id_to_latest() def test_upload_existing_embedding(self): - # first upload embeddings n_data = len(self.api_workflow_client._mappings_api.sample_names) self.t_ester_upload_embedding(n_data=n_data) @@ -110,11 +119,10 @@ def 
test_upload_existing_embedding(self): self.api_workflow_client.append_embeddings( self.path_to_embeddings, - 'embedding_id_xyz_2', + "embedding_id_xyz_2", ) def test_append_embeddings_with_overlap(self): - # mock the embeddings on the server n_data_server = len(self.api_workflow_client._mappings_api.sample_names) self.api_workflow_client.n_dims_embeddings_on_server = 32 @@ -122,7 +130,9 @@ def test_append_embeddings_with_overlap(self): # create new local embeddings overlapping with server embeddings n_data_start_local = n_data_server // 3 n_data_local = n_data_server * 2 - self.create_fake_embeddings(n_data=n_data_local, n_data_start=n_data_start_local) + self.create_fake_embeddings( + n_data=n_data_local, n_data_start=n_data_start_local + ) """ Assumptions: @@ -146,12 +156,13 @@ def test_append_embeddings_with_overlap(self): # append the local embeddings to the server embeddings self.api_workflow_client.append_embeddings( self.path_to_embeddings, - 'embedding_id_xyz_2', + "embedding_id_xyz_2", ) # load the new (appended) embeddings - _, labels_appended, filenames_appended = \ - load_embeddings(self.path_to_embeddings) + _, labels_appended, filenames_appended = load_embeddings( + self.path_to_embeddings + ) # define the expected filenames and labels self.create_fake_embeddings(n_data=n_data_local + n_data_start_local) @@ -162,31 +173,27 @@ def test_append_embeddings_with_overlap(self): self.assertListEqual(filenames_appended, filenames_expected) self.assertListEqual(labels_appended, labels_expected) - def test_append_embeddings_different_shape(self): - # first upload embeddings n_data = len(self.api_workflow_client._mappings_api.sample_names) self.t_ester_upload_embedding(n_data=n_data) # create a new set of embeddings - self.create_fake_embeddings(10, n_dims=16) # default is 32 + self.create_fake_embeddings(10, n_dims=16) # default is 32 self.api_workflow_client.n_dims_embeddings_on_server = 32 with self.assertRaises(RuntimeError): self.api_workflow_client.append_embeddings( self.path_to_embeddings, - 'embedding_id_xyz_2', + "embedding_id_xyz_2", ) - def tearDown(self) -> None: for filename in ["embeddings.csv", "embeddings_sorted.csv"]: - if hasattr(self, 'folder_path'): + if hasattr(self, "folder_path"): try: filepath = os.path.join(self.folder_path, filename) os.remove(filepath) except FileNotFoundError: pass - diff --git a/tests/cli/test_cli_crop.py b/tests/cli/test_cli_crop.py index fb9776d54..d28d414fd 100644 --- a/tests/cli/test_cli_crop.py +++ b/tests/cli/test_cli_crop.py @@ -1,8 +1,8 @@ import os +import random import re import sys import tempfile -import random import torchvision import yaml @@ -11,52 +11,65 @@ import lightly from lightly.active_learning.utils import BoundingBox from lightly.data import LightlyDataset -from lightly.utils.cropping.crop_image_by_bounding_boxes import crop_dataset_by_bounding_boxes_and_save +from lightly.utils.cropping.crop_image_by_bounding_boxes import ( + crop_dataset_by_bounding_boxes_and_save, +) from lightly.utils.cropping.read_yolo_label_file import read_yolo_label_file -from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup, MockedApiWorkflowClient +from tests.api_workflow.mocked_api_workflow_client import ( + MockedApiWorkflowClient, + MockedApiWorkflowSetup, +) class TestCLICrop(MockedApiWorkflowSetup): - @classmethod def setUpClass(cls) -> None: - sys.modules["lightly.cli.upload_cli"].ApiWorkflowClient = MockedApiWorkflowClient + sys.modules[ + "lightly.cli.upload_cli" + ].ApiWorkflowClient = 
MockedApiWorkflowClient def setUp(self): MockedApiWorkflowSetup.setUp(self) self.create_fake_dataset() self.create_fake_yolo_labels() with initialize(config_path="../../lightly/cli/config", job_name="test_app"): - self.cfg = compose(config_name="config", overrides=[ - f"input_dir={self.folder_path}", - f"label_dir={self.folder_path_labels}", - f"output_dir={tempfile.mkdtemp()}", - f"label_names_file={self.label_names_file}" - ]) + self.cfg = compose( + config_name="config", + overrides=[ + f"input_dir={self.folder_path}", + f"label_dir={self.folder_path_labels}", + f"output_dir={tempfile.mkdtemp()}", + f"label_names_file={self.label_names_file}", + ], + ) def create_fake_dataset(self): n_data = len(self.api_workflow_client.get_filenames()) - self.dataset = torchvision.datasets.FakeData(size=n_data, image_size=(3, 32, 32)) + self.dataset = torchvision.datasets.FakeData( + size=n_data, image_size=(3, 32, 32) + ) self.folder_path = tempfile.mkdtemp() - sample_names = [f'img_{i}.jpg' for i in range(n_data)] + sample_names = [f"img_{i}.jpg" for i in range(n_data)] self.sample_names = sample_names for sample_idx in range(n_data): data = self.dataset[sample_idx] path = os.path.join(self.folder_path, sample_names[sample_idx]) data[0].save(path) - def create_fake_yolo_labels(self, no_classes: int = 10, objects_per_image: int = 13): + def create_fake_yolo_labels( + self, no_classes: int = 10, objects_per_image: int = 13 + ): random.seed(42) n_data = len(self.api_workflow_client.get_filenames()) self.folder_path_labels = tempfile.mkdtemp() - label_names = [f'img_{i}.txt' for i in range(n_data)] + label_names = [f"img_{i}.txt" for i in range(n_data)] self.label_names = label_names for filename_label in label_names: path = os.path.join(self.folder_path_labels, filename_label) - with open(path, 'a') as the_file: + with open(path, "a") as the_file: for i in range(objects_per_image): class_id = random.randint(0, no_classes - 1) x = random.uniform(0.1, 0.9) @@ -66,8 +79,10 @@ def create_fake_yolo_labels(self, no_classes: int = 10, objects_per_image: int = line = f"{class_id} {x} {y} {w} {h}\n" the_file.write(line) yaml_dict = {"names": [f"class{i}" for i in range(no_classes)]} - self.label_names_file = tempfile.mktemp('.yaml', 'data', dir=self.folder_path_labels) - with open(self.label_names_file, 'w') as file: + self.label_names_file = tempfile.mktemp( + ".yaml", "data", dir=self.folder_path_labels + ) + with open(self.label_names_file, "w") as file: yaml.dump(yaml_dict, file) def parse_cli_string(self, cli_words: str): @@ -77,18 +92,18 @@ def parse_cli_string(self, cli_words: str): dict_keys = cli_words[0::2] dict_values = cli_words[1::2] for key, value in zip(dict_keys, dict_values): - value = value.strip('\"') - value = value.strip('\'') + value = value.strip('"') + value = value.strip("'") self.cfg[key] = value def test_parse_cli_string(self): cli_string = "lightly-crop label_dir=/blub" self.parse_cli_string(cli_string) - self.assertEqual(self.cfg['label_dir'], '/blub') + self.assertEqual(self.cfg["label_dir"], "/blub") def test_read_yolo(self): for f in os.listdir(self.cfg.label_dir): - if f.endswith('.txt'): + if f.endswith(".txt"): filepath = os.path.join(self.cfg.label_dir, f) read_yolo_label_file(filepath, 0.1) @@ -100,28 +115,32 @@ def test_crop_dataset_by_bounding_boxes_and_save(self): class_indices_list_list = [[1]] * no_files class_names = ["class_0", "class_1"] with self.subTest("all_correct"): - crop_dataset_by_bounding_boxes_and_save(dataset, - output_dir, - bounding_boxes_list_list, - 
class_indices_list_list, - class_names) + crop_dataset_by_bounding_boxes_and_save( + dataset, + output_dir, + bounding_boxes_list_list, + class_indices_list_list, + class_names, + ) with self.subTest("wrong length of bounding_boxes_list_list"): with self.assertRaises(ValueError): - crop_dataset_by_bounding_boxes_and_save(dataset, - output_dir, - bounding_boxes_list_list[:-1], - class_indices_list_list, - class_names) + crop_dataset_by_bounding_boxes_and_save( + dataset, + output_dir, + bounding_boxes_list_list[:-1], + class_indices_list_list, + class_names, + ) with self.subTest("wrong internal length of class_indices_list_list"): with self.assertWarns(UserWarning): class_indices_list_list[0] *= 2 - crop_dataset_by_bounding_boxes_and_save(dataset, - output_dir, - bounding_boxes_list_list, - class_indices_list_list, - class_names) - - + crop_dataset_by_bounding_boxes_and_save( + dataset, + output_dir, + bounding_boxes_list_list, + class_indices_list_list, + class_names, + ) def test_crop_with_class_names(self): cli_string = "lightly-crop crop_padding=0.1" @@ -131,5 +150,5 @@ def test_crop_with_class_names(self): def test_crop_without_class_names(self): cli_string = "lightly-crop crop_padding=0.1" self.parse_cli_string(cli_string) - self.cfg['label_names_file'] = '' + self.cfg["label_names_file"] = "" lightly.cli.crop_cli(self.cfg) diff --git a/tests/cli/test_cli_download.py b/tests/cli/test_cli_download.py index a55d43a1e..50cbc1665 100644 --- a/tests/cli/test_cli_download.py +++ b/tests/cli/test_cli_download.py @@ -7,27 +7,31 @@ from hydra.experimental import compose, initialize import lightly -from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup -from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowClient +from tests.api_workflow.mocked_api_workflow_client import ( + MockedApiWorkflowClient, + MockedApiWorkflowSetup, +) class TestCLIDownload(MockedApiWorkflowSetup): - @classmethod def setUpClass(cls) -> None: - sys.modules["lightly.cli.download_cli"].ApiWorkflowClient = MockedApiWorkflowClient + sys.modules[ + "lightly.cli.download_cli" + ].ApiWorkflowClient = MockedApiWorkflowClient def setUp(self): with initialize(config_path="../../lightly/cli/config", job_name="test_app"): self.cfg = compose(config_name="config") def create_fake_dataset(self, n_data: int = 5): - self.dataset = torchvision.datasets.FakeData(size=n_data, - image_size=(3, 32, 32)) + self.dataset = torchvision.datasets.FakeData( + size=n_data, image_size=(3, 32, 32) + ) self.input_dir = tempfile.mkdtemp() - sample_names = [f'img_{i}.jpg' for i in range(n_data)] + sample_names = [f"img_{i}.jpg" for i in range(n_data)] self.sample_names = sample_names for sample_idx in range(n_data): data = self.dataset[sample_idx] @@ -37,19 +41,19 @@ def create_fake_dataset(self, n_data: int = 5): self.output_dir = tempfile.mkdtemp() def parse_cli_string(self, cli_words: str): - cli_words = cli_words.replace('lightly-download ', '') - overrides = cli_words.split(' ') - with initialize(config_path='../../lightly/cli/config/'): + cli_words = cli_words.replace("lightly-download ", "") + overrides = cli_words.split(" ") + with initialize(config_path="../../lightly/cli/config/"): self.cfg = compose( - config_name='config', + config_name="config", overrides=overrides, ) def test_parse_cli_string(self): cli_string = "lightly-download token='123' dataset_id='XYZ'" self.parse_cli_string(cli_string) - assert self.cfg["token"] == '123' - assert self.cfg["dataset_id"] == 'XYZ' + assert 
self.cfg["token"] == "123" + assert self.cfg["dataset_id"] == "XYZ" def test_download_base(self): cli_string = "lightly-download token='123' dataset_id='XYZ'" @@ -57,12 +61,16 @@ def test_download_base(self): lightly.cli.download_cli(self.cfg) def test_download_tag_name(self): - cli_string = "lightly-download token='123' dataset_id='XYZ' tag_name='selected_tag_xyz'" + cli_string = ( + "lightly-download token='123' dataset_id='XYZ' tag_name='selected_tag_xyz'" + ) self.parse_cli_string(cli_string) lightly.cli.download_cli(self.cfg) def test_download_tag_name_nonexisting(self): - cli_string = "lightly-download token='123' dataset_id='XYZ' tag_name='nonexisting_xyz'" + cli_string = ( + "lightly-download token='123' dataset_id='XYZ' tag_name='nonexisting_xyz'" + ) self.parse_cli_string(cli_string) with self.assertRaises(ValueError): lightly.cli.download_cli(self.cfg) @@ -92,15 +100,19 @@ def test_download_no_dataset_id(self): def test_download_copy_from_input_to_output_dir(self): self.create_fake_dataset(n_data=100) - cli_string = f"lightly-download token='123' dataset_id='dataset_1_id' tag_name='selected_tag_xyz' " \ - f"input_dir={self.input_dir} output_dir={self.output_dir}" + cli_string = ( + f"lightly-download token='123' dataset_id='dataset_1_id' tag_name='selected_tag_xyz' " + f"input_dir={self.input_dir} output_dir={self.output_dir}" + ) self.parse_cli_string(cli_string) lightly.cli.download_cli(self.cfg) def test_download_from_tag_with_integer_name(self): """Test to reproduce issue #575.""" # use tag name "1000" - cli_string = "lightly-download token='123' dataset_id='dataset_1_id' tag_name=1000" + cli_string = ( + "lightly-download token='123' dataset_id='dataset_1_id' tag_name=1000" + ) self.parse_cli_string(cli_string) with pytest.warns(None) as record: lightly.cli.download_cli(self.cfg) diff --git a/tests/cli/test_cli_embed.py b/tests/cli/test_cli_embed.py index 61e2b88fc..5165141d1 100644 --- a/tests/cli/test_cli_embed.py +++ b/tests/cli/test_cli_embed.py @@ -7,32 +7,38 @@ from hydra.experimental import compose, initialize import lightly -from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup, MockedApiWorkflowClient +from tests.api_workflow.mocked_api_workflow_client import ( + MockedApiWorkflowClient, + MockedApiWorkflowSetup, +) class TestCLIEmbed(MockedApiWorkflowSetup): - @classmethod def setUpClass(cls) -> None: - sys.modules["lightly.cli.embed_cli"].ApiWorkflowClient = \ - MockedApiWorkflowClient + sys.modules["lightly.cli.embed_cli"].ApiWorkflowClient = MockedApiWorkflowClient def setUp(self): MockedApiWorkflowSetup.setUp(self) self.create_fake_dataset() with initialize(config_path="../../lightly/cli/config", job_name="test_app"): - self.cfg = compose(config_name="config", overrides=[ - "token='123'", - f"input_dir={self.folder_path}", - "trainer.max_epochs=0" - ]) + self.cfg = compose( + config_name="config", + overrides=[ + "token='123'", + f"input_dir={self.folder_path}", + "trainer.max_epochs=0", + ], + ) def create_fake_dataset(self): n_data = 16 - self.dataset = torchvision.datasets.FakeData(size=n_data, image_size=(3, 32, 32)) + self.dataset = torchvision.datasets.FakeData( + size=n_data, image_size=(3, 32, 32) + ) self.folder_path = tempfile.mkdtemp() - sample_names = [f'img_{i}.jpg' for i in range(n_data)] + sample_names = [f"img_{i}.jpg" for i in range(n_data)] self.sample_names = sample_names for sample_idx in range(n_data): data = self.dataset[sample_idx] @@ -41,11 +47,16 @@ def create_fake_dataset(self): def test_embed(self): 
lightly.cli.embed_cli(self.cfg) - self.assertGreater(len(os.getenv( - self.cfg['environment_variable_names'][ - 'lightly_last_embedding_path'] - )), 0) - + self.assertGreater( + len( + os.getenv( + self.cfg["environment_variable_names"][ + "lightly_last_embedding_path" + ] + ) + ), + 0, + ) def tearDown(self) -> None: for filename in ["embeddings.csv", "embeddings_sorted.csv"]: @@ -53,6 +64,3 @@ def tearDown(self) -> None: os.remove(filename) except FileNotFoundError: pass - - - diff --git a/tests/cli/test_cli_magic.py b/tests/cli/test_cli_magic.py index 86787d009..717e7efe6 100644 --- a/tests/cli/test_cli_magic.py +++ b/tests/cli/test_cli_magic.py @@ -7,32 +7,41 @@ from hydra.experimental import compose, initialize from lightly import cli -from tests.api_workflow.mocked_api_workflow_client import \ - MockedApiWorkflowSetup, MockedApiWorkflowClient, N_FILES_ON_SERVER +from tests.api_workflow.mocked_api_workflow_client import ( + N_FILES_ON_SERVER, + MockedApiWorkflowClient, + MockedApiWorkflowSetup, +) class TestCLIMagic(MockedApiWorkflowSetup): - @classmethod def setUpClass(cls) -> None: - sys.modules["lightly.cli.upload_cli"].ApiWorkflowClient = MockedApiWorkflowClient + sys.modules[ + "lightly.cli.upload_cli" + ].ApiWorkflowClient = MockedApiWorkflowClient def setUp(self): MockedApiWorkflowSetup.setUp(self) self.create_fake_dataset() with initialize(config_path="../../lightly/cli/config", job_name="test_app"): - self.cfg = compose(config_name="config", overrides=[ - "token='123'", - f"input_dir={self.folder_path}", - "trainer.max_epochs=0" - ]) - - def create_fake_dataset(self, filename_appendix: str = ''): + self.cfg = compose( + config_name="config", + overrides=[ + "token='123'", + f"input_dir={self.folder_path}", + "trainer.max_epochs=0", + ], + ) + + def create_fake_dataset(self, filename_appendix: str = ""): n_data = len(self.api_workflow_client.get_filenames()) - self.dataset = torchvision.datasets.FakeData(size=n_data, image_size=(3, 32, 32)) + self.dataset = torchvision.datasets.FakeData( + size=n_data, image_size=(3, 32, 32) + ) self.folder_path = tempfile.mkdtemp() - sample_names = [f'img_{i}{filename_appendix}.jpg' for i in range(n_data)] + sample_names = [f"img_{i}{filename_appendix}.jpg" for i in range(n_data)] self.sample_names = sample_names for sample_idx in range(n_data): data = self.dataset[sample_idx] @@ -46,31 +55,32 @@ def parse_cli_string(self, cli_words: str): dict_keys = cli_words[0::2] dict_values = cli_words[1::2] for key, value in zip(dict_keys, dict_values): - value = value.strip('\"') - value = value.strip('\'') + value = value.strip('"') + value = value.strip("'") try: value = int(value) except ValueError: pass - key_subparts = key.split('.') + key_subparts = key.split(".") if len(key_subparts) == 1: self.cfg[key] = value elif len(key_subparts) == 2: self.cfg[key_subparts[0]][key_subparts[1]] = value else: raise ValueError( - f'Keys with more than 2 subparts are not supported,' - f'but you entered {key}.' + f"Keys with more than 2 subparts are not supported, " + f"but you entered {key}." 
) - def test_parse_cli_string(self): - cli_string = "lightly-magic dataset_id='XYZ' upload='thumbnails' trainer.max_epochs=3" + cli_string = ( + "lightly-magic dataset_id='XYZ' upload='thumbnails' trainer.max_epochs=3" + ) self.parse_cli_string(cli_string) - self.assertEqual(self.cfg["dataset_id"], 'XYZ') - self.assertEqual(self.cfg["upload"], 'thumbnails') - self.assertEqual(self.cfg['trainer']['max_epochs'], 3) + self.assertEqual(self.cfg["dataset_id"], "XYZ") + self.assertEqual(self.cfg["upload"], "thumbnails") + self.assertEqual(self.cfg["trainer"]["max_epochs"], 3) def test_magic_new_dataset_name(self): MockedApiWorkflowClient.n_dims_embeddings_on_server = 32 @@ -103,6 +113,3 @@ def tearDown(self) -> None: os.remove(filename) except FileNotFoundError: pass - - - diff --git a/tests/cli/test_cli_train.py b/tests/cli/test_cli_train.py index d80255ab7..aea173862 100644 --- a/tests/cli/test_cli_train.py +++ b/tests/cli/test_cli_train.py @@ -7,31 +7,40 @@ from hydra.experimental import compose, initialize from lightly import cli -from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup, MockedApiWorkflowClient +from tests.api_workflow.mocked_api_workflow_client import ( + MockedApiWorkflowClient, + MockedApiWorkflowSetup, +) class TestCLITrain(MockedApiWorkflowSetup): - @classmethod def setUpClass(cls) -> None: - sys.modules["lightly.cli.upload_cli"].ApiWorkflowClient = MockedApiWorkflowClient + sys.modules[ + "lightly.cli.upload_cli" + ].ApiWorkflowClient = MockedApiWorkflowClient def setUp(self): MockedApiWorkflowSetup.setUp(self) self.create_fake_dataset() with initialize(config_path="../../lightly/cli/config", job_name="test_app"): - self.cfg = compose(config_name="config", overrides=[ - "token='123'", - f"input_dir={self.folder_path}", - "trainer.max_epochs=1" - ]) + self.cfg = compose( + config_name="config", + overrides=[ + "token='123'", + f"input_dir={self.folder_path}", + "trainer.max_epochs=1", + ], + ) def create_fake_dataset(self): n_data = 5 - self.dataset = torchvision.datasets.FakeData(size=n_data, image_size=(3, 32, 32)) + self.dataset = torchvision.datasets.FakeData( + size=n_data, image_size=(3, 32, 32) + ) self.folder_path = tempfile.mkdtemp() - sample_names = [f'img_{i}.jpg' for i in range(n_data)] + sample_names = [f"img_{i}.jpg" for i in range(n_data)] self.sample_names = sample_names for sample_idx in range(n_data): data = self.dataset[sample_idx] @@ -41,7 +50,7 @@ def create_fake_dataset(self): def test_checkpoint_created(self): cli.train_cli(self.cfg) checkpoint_path = os.getenv( - self.cfg['environment_variable_names']['lightly_last_checkpoint_path'] + self.cfg["environment_variable_names"]["lightly_last_checkpoint_path"] ) assert checkpoint_path.endswith(".ckpt") assert os.path.isfile(checkpoint_path) diff --git a/tests/cli/test_cli_upload.py b/tests/cli/test_cli_upload.py index b8da9ddf3..d6644bebf 100644 --- a/tests/cli/test_cli_upload.py +++ b/tests/cli/test_cli_upload.py @@ -1,7 +1,7 @@ +import json import os import re import sys -import json import tempfile import warnings @@ -10,29 +10,32 @@ from hydra.experimental import compose, initialize import lightly -from lightly.api.api_workflow_upload_embeddings import \ - EmbeddingDoesNotExistError +from lightly.api.api_workflow_upload_embeddings import EmbeddingDoesNotExistError from lightly.cli.upload_cli import SUCCESS_RETURN_VALUE from lightly.openapi_generated.swagger_client import DatasetEmbeddingData from lightly.utils.io import save_embeddings -from 
tests.api_workflow.mocked_api_workflow_client import \ - MockedApiWorkflowSetup, MockedApiWorkflowClient, N_FILES_ON_SERVER +from tests.api_workflow.mocked_api_workflow_client import ( + N_FILES_ON_SERVER, + MockedApiWorkflowClient, + MockedApiWorkflowSetup, +) class TestCLIUpload(MockedApiWorkflowSetup): - @classmethod def setUpClass(cls) -> None: - sys.modules["lightly.cli.upload_cli"].ApiWorkflowClient = MockedApiWorkflowClient + sys.modules[ + "lightly.cli.upload_cli" + ].ApiWorkflowClient = MockedApiWorkflowClient - - def set_tags(self, zero_tags: bool=True): + def set_tags(self, zero_tags: bool = True): # make the dataset appear empty def mocked_get_all_tags_zero(*args, **kwargs): if zero_tags: return [] else: return ["Any tag"] + MockedApiWorkflowClient.get_all_tags = mocked_get_all_tags_zero def set_embedding(self, has_embedding: bool): @@ -43,15 +46,11 @@ def mocked_get_embedding_by_name(*args, **kwargs): name="name", is_processed=True, created_at=0, - ) else: raise EmbeddingDoesNotExistError - MockedApiWorkflowClient.get_embedding_by_name = \ - mocked_get_embedding_by_name - - + MockedApiWorkflowClient.get_embedding_by_name = mocked_get_embedding_by_name def setUp(self): # make the API dataset appear empty @@ -62,53 +61,54 @@ def setUp(self): self.create_fake_dataset() def create_fake_dataset( - self, n_data: int = N_FILES_ON_SERVER, - n_rows_embeddings: int = N_FILES_ON_SERVER, - n_dims_embeddings: int = 4 + self, + n_data: int = N_FILES_ON_SERVER, + n_rows_embeddings: int = N_FILES_ON_SERVER, + n_dims_embeddings: int = 4, ): - self.dataset = torchvision.datasets.FakeData(size=n_data, - image_size=(3, 32, 32)) + self.dataset = torchvision.datasets.FakeData( + size=n_data, image_size=(3, 32, 32) + ) self.folder_path = tempfile.mkdtemp() - sample_names = [f'img_{i}.jpg' for i in range(n_data)] + sample_names = [f"img_{i}.jpg" for i in range(n_data)] self.sample_names = sample_names for sample_idx in range(n_data): data = self.dataset[sample_idx] path = os.path.join(self.folder_path, sample_names[sample_idx]) data[0].save(path) - + coco_json = {} - coco_json['images'] = [ - {'id': i, 'file_name': fname} for i, fname in enumerate(self.sample_names) + coco_json["images"] = [ + {"id": i, "file_name": fname} for i, fname in enumerate(self.sample_names) ] - coco_json['metadata'] = [ - {'id': i, 'image_id': i, 'custom_metadata': 0 } for i, _ in enumerate(self.sample_names) + coco_json["metadata"] = [ + {"id": i, "image_id": i, "custom_metadata": 0} + for i, _ in enumerate(self.sample_names) ] - + self.tfile = tempfile.NamedTemporaryFile(mode="w+") json.dump(coco_json, self.tfile) self.tfile.flush() # create fake embeddings - self.path_to_embeddings = os.path.join(self.folder_path, 'embeddings.csv') - sample_names_embeddings = [f'img_{i}.jpg' for i in range(n_rows_embeddings)] + self.path_to_embeddings = os.path.join(self.folder_path, "embeddings.csv") + sample_names_embeddings = [f"img_{i}.jpg" for i in range(n_rows_embeddings)] labels = [0] * len(sample_names_embeddings) save_embeddings( self.path_to_embeddings, np.random.randn(n_rows_embeddings, n_dims_embeddings), labels, - sample_names_embeddings + sample_names_embeddings, ) MockedApiWorkflowClient.n_dims_embeddings_on_server = n_dims_embeddings MockedApiWorkflowClient.n_embedding_rows_on_server = n_rows_embeddings - def parse_cli_string( - self, - cli_words: str, + self, + cli_words: str, ): - with initialize(config_path="../../lightly/cli/config", - job_name="test_app"): + with initialize(config_path="../../lightly/cli/config", 
job_name="test_app"): overrides = [ "token='123'", f"input_dir={self.folder_path}", @@ -120,16 +120,18 @@ def parse_cli_string( self.cfg.merge_with_cli() def test_parse_cli_string(self): - cli_string = f"lightly-upload dataset_id='XYZ' upload='thumbnails' append={True}" + cli_string = ( + f"lightly-upload dataset_id='XYZ' upload='thumbnails' append={True}" + ) self.parse_cli_string(cli_string) - self.assertEqual(self.cfg["dataset_id"], 'XYZ') - self.assertEqual(self.cfg["upload"], 'thumbnails') - self.assertTrue(self.cfg['append']) + self.assertEqual(self.cfg["dataset_id"], "XYZ") + self.assertEqual(self.cfg["upload"], "thumbnails") + self.assertTrue(self.cfg["append"]) def test_upload_no_token(self): cli_string = f"lightly-upload" self.parse_cli_string(cli_string) - self.cfg['token'] = '' + self.cfg["token"] = "" with self.assertWarns(UserWarning): lightly.cli.upload_cli(self.cfg) @@ -138,10 +140,14 @@ def test_upload_new_dataset_name(self): self.parse_cli_string(cli_string) result = lightly.cli.upload_cli(self.cfg) self.assertEqual(result, SUCCESS_RETURN_VALUE) - self.assertGreater(len(os.getenv( - self.cfg['environment_variable_names'][ - 'lightly_last_dataset_id'] - )), 0) + self.assertGreater( + len( + os.getenv( + self.cfg["environment_variable_names"]["lightly_last_dataset_id"] + ) + ), + 0, + ) def test_upload_new_dataset_name_and_embeddings(self): """ @@ -165,16 +171,19 @@ def test_upload_new_dataset_name_and_embeddings(self): with self.subTest( append=append, n_dims_embeddings=n_dims_embeddings, - n_dims_embeddings_server=n_dims_embeddings_server + n_dims_embeddings_server=n_dims_embeddings_server, ): - self.create_fake_dataset( n_data=N_FILES_ON_SERVER, n_rows_embeddings=N_FILES_ON_SERVER, - n_dims_embeddings=n_dims_embeddings + n_dims_embeddings=n_dims_embeddings, + ) + MockedApiWorkflowClient.n_embedding_rows_on_server = ( + n_embedding_rows_on_server + ) + MockedApiWorkflowClient.n_dims_embeddings_on_server = ( + n_dims_embeddings_server ) - MockedApiWorkflowClient.n_embedding_rows_on_server = n_embedding_rows_on_server - MockedApiWorkflowClient.n_dims_embeddings_on_server = n_dims_embeddings_server self.set_embedding(has_embedding=True) cli_string = f"lightly-upload new_dataset_name='new_dataset_name_xyz' embeddings={self.path_to_embeddings} append={append}" self.parse_cli_string(cli_string) @@ -201,7 +210,9 @@ def test_upload_no_dataset(self): lightly.cli.upload_cli(self.cfg) def test_upload_both_dataset(self): - cli_string = "lightly-upload new_dataset_name='new_dataset_name_xyz' dataset_id='xyz'" + cli_string = ( + "lightly-upload new_dataset_name='new_dataset_name_xyz' dataset_id='xyz'" + ) self.parse_cli_string(cli_string) with self.assertWarns(UserWarning): lightly.cli.upload_cli(self.cfg) @@ -248,7 +259,6 @@ def check_upload_dataset_and_embedding( self.assertEqual(result, SUCCESS_RETURN_VALUE) def test_upload_dataset_and_embedding(self): - for input_dir in [True, False]: for existing_dataset in [True, False]: for embeddings_path in [True, False]: diff --git a/tests/conftest.py b/tests/conftest.py index 5854a713d..44cfe3f46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,4 +20,4 @@ def pytest_collection_modifyitems(config, items): skip_slow = pytest.mark.skip(reason="need --runslow option to run") for item in items: if "slow" in item.keywords: - item.add_marker(skip_slow) \ No newline at end of file + item.add_marker(skip_slow) diff --git a/tests/core/test_Core.py b/tests/core/test_Core.py index a534eab9b..0d118b56d 100644 --- a/tests/core/test_Core.py 
+++ b/tests/core/test_Core.py @@ -1,64 +1,59 @@ -import unittest import os import re import shutil +import tempfile +import unittest import numpy as np -import torchvision -import tempfile import pytest +import torchvision from lightly.core import train_model_and_embed_images class TestCore(unittest.TestCase): - def ensure_dir(self, path_to_folder: str): if not os.path.exists(path_to_folder): os.makedirs(path_to_folder) def create_dataset(self, n_subfolders=5, n_samples_per_subfolder=20): n_tot = n_subfolders * n_samples_per_subfolder - dataset = torchvision.datasets.FakeData(size=n_tot, - image_size=(3, 32, 32)) + dataset = torchvision.datasets.FakeData(size=n_tot, image_size=(3, 32, 32)) tmp_dir = tempfile.mkdtemp() - folder_names = [f'folder_{i}' for i in range(n_subfolders)] - sample_names = [f'img_{i}.jpg' for i in range(n_samples_per_subfolder)] + folder_names = [f"folder_{i}" for i in range(n_subfolders)] + sample_names = [f"img_{i}.jpg" for i in range(n_samples_per_subfolder)] for folder_idx in range(n_subfolders): for sample_idx in range(n_samples_per_subfolder): idx = (folder_idx * n_subfolders) + sample_idx data = dataset[idx] - self.ensure_dir(os.path.join(tmp_dir, - folder_names[folder_idx])) + self.ensure_dir(os.path.join(tmp_dir, folder_names[folder_idx])) - data[0].save(os.path.join(tmp_dir, - folder_names[folder_idx], - sample_names[sample_idx])) + data[0].save( + os.path.join( + tmp_dir, folder_names[folder_idx], sample_names[sample_idx] + ) + ) self.dataset_dir = tmp_dir return tmp_dir, folder_names, sample_names - - #@pytest.mark.slow + # @pytest.mark.slow def test_train_and_embed(self): n_subfolders = 3 n_samples_per_subfolder = 3 n_samples = n_subfolders * n_samples_per_subfolder - # embed, no overwrites - dataset_dir, _, _ = self.create_dataset( - n_subfolders, - n_samples_per_subfolder - ) + # embed, no overwrites + dataset_dir, _, _ = self.create_dataset(n_subfolders, n_samples_per_subfolder) - # train, one overwrite + # train, one overwrite embeddings, labels, filenames = train_model_and_embed_images( input_dir=dataset_dir, - trainer={'max_epochs': 1}, - loader={'num_workers': 0}, + trainer={"max_epochs": 1}, + loader={"num_workers": 0}, ) self.assertEqual(len(embeddings), n_samples) self.assertEqual(len(labels), n_samples) @@ -67,10 +62,9 @@ def test_train_and_embed(self): self.assertIsInstance(int(labels[0]), int) # see if casting to int works self.assertIsInstance(filenames[0], str) - def tearDown(self) -> None: shutil.rmtree(self.dataset_dir) - pattern = '(.*)?.ckpt$' + pattern = r"(.*)?\.ckpt$" for root, dirs, files in os.walk(os.getcwd()): for file in filter(lambda x: re.match(pattern, x), files): os.remove(os.path.join(root, file)) diff --git a/tests/data/test_LightlyDataset.py b/tests/data/test_LightlyDataset.py index 3d560506e..c5a1a7c6c 100644 --- a/tests/data/test_LightlyDataset.py +++ b/tests/data/test_LightlyDataset.py @@ -1,27 +1,27 @@ -import re -import unittest import os import random +import re import shutil -from typing import Tuple, List - -import torch -import torchvision import tempfile +import unittest import warnings +from typing import List, Tuple + import numpy as np +import torch +import torchvision from PIL.Image import Image from lightly.data import LightlyDataset - from lightly.data._utils import check_images from lightly.utils.io import INVALID_FILENAME_CHARACTERS try: - from lightly.data._video import VideoDataset import av import cv2 + from lightly.data._video import VideoDataset + VIDEO_DATASET_AVAILABLE = True except 
ModuleNotFoundError: VIDEO_DATASET_AVAILABLE = False @@ -66,7 +66,6 @@ def create_dataset(self, n_subfolders=5, n_samples_per_subfolder=20): return tmp_dir, folder_names, sample_names def create_video_dataset(self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3): - self.n_videos = n_videos self.n_frames_per_video = n_frames_per_video @@ -122,7 +121,6 @@ def test_create_lightly_dataset_from_folder(self): shutil.rmtree(out_dir) def test_create_lightly_dataset_from_folder_nosubdir(self): - # create a dataset n_tot = 100 tmp_dir, sample_names = self.create_dataset_no_subdir(n_tot) @@ -140,7 +138,6 @@ def test_create_lightly_dataset_from_folder_nosubdir(self): sample, target, fname = dataset[i] def test_create_lightly_dataset_with_invalid_char_in_filename(self): - # create a dataset n_tot = 100 dataset = torchvision.datasets.FakeData(size=n_tot, image_size=(3, 32, 32)) @@ -159,7 +156,6 @@ def test_create_lightly_dataset_with_invalid_char_in_filename(self): dataset = LightlyDataset(input_dir=tmp_dir) def test_check_images(self): - # create a dataset tmp_dir = tempfile.mkdtemp() n_healthy = 100 @@ -320,7 +316,6 @@ def filename_img_fits_video(filename_img: str): self.assertIn(filename, filenames_dataset) def test_transform_setter(self, dataset: LightlyDataset = None): - if dataset is None: tmp_dir, _, _ = self.create_dataset() dataset = LightlyDataset(input_dir=tmp_dir) diff --git a/tests/data/test_LightlySubset.py b/tests/data/test_LightlySubset.py index ca9161e0c..e88b897b9 100644 --- a/tests/data/test_LightlySubset.py +++ b/tests/data/test_LightlySubset.py @@ -1,18 +1,18 @@ -import tempfile import random -from typing import Tuple, List +import tempfile import unittest +from typing import List, Tuple from lightly.data.dataset import LightlyDataset from lightly.data.lightly_subset import LightlySubset - from tests.data.test_LightlyDataset import TestLightlyDataset try: - from lightly.data._video import VideoDataset import av import cv2 + from lightly.data._video import VideoDataset + VIDEO_DATASET_AVAILABLE = True except ModuleNotFoundError: VIDEO_DATASET_AVAILABLE = False diff --git a/tests/data/test_VideoDataset.py b/tests/data/test_VideoDataset.py index 03f3e3ded..6e6e36ba1 100644 --- a/tests/data/test_VideoDataset.py +++ b/tests/data/test_VideoDataset.py @@ -1,15 +1,16 @@ import contextlib import io -import warnings -from fractions import Fraction -import unittest import os import shutil +import tempfile +import unittest +import warnings +from fractions import Fraction from typing import List from unittest import mock +import cv2 import numpy as np -import tempfile import PIL import torch import torchvision @@ -17,12 +18,10 @@ from lightly.data import LightlyDataset, NonIncreasingTimestampError from lightly.data._video import ( VideoDataset, - _make_dataset, _find_non_increasing_timestamps, + _make_dataset, ) -import cv2 - try: import av @@ -62,7 +61,6 @@ def create_dataset_specified_frames_per_video( out.release() def create_dataset(self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3): - self.n_videos = n_videos self.n_frames_per_video = n_frames_per_video @@ -119,7 +117,6 @@ def test_video_similar_timestamps_for_different_backends(self): shutil.rmtree(self.input_dir) def test_video_dataset_tqdm_args(self): - self.create_dataset() desc = "test_video_dataset_tqdm_args description asdf" f = io.StringIO() @@ -159,7 +156,6 @@ def test_video_dataset_init_dataloader(self): ) def test_video_dataset_from_folder(self): - self.create_dataset() # iterate through different backends diff --git 
a/tests/data/test_data_collate.py b/tests/data/test_data_collate.py index ae42fe793..c6a6b00a3 100644 --- a/tests/data/test_data_collate.py +++ b/tests/data/test_data_collate.py @@ -5,14 +5,15 @@ import torchvision import torchvision.transforms as transforms -from lightly.transforms import RandomRotate -from lightly.data import collate -from lightly.data import BaseCollateFunction -from lightly.data import ImageCollateFunction -from lightly.data import SimCLRCollateFunction -from lightly.data import MultiCropCollateFunction -from lightly.data import SwaVCollateFunction -from lightly.data import PIRLCollateFunction +from lightly.data import ( + BaseCollateFunction, + ImageCollateFunction, + MultiCropCollateFunction, + PIRLCollateFunction, + SimCLRCollateFunction, + SwaVCollateFunction, + collate, +) from lightly.data.collate import ( DINOCollateFunction, MAECollateFunction, @@ -21,6 +22,7 @@ VICRegCollateFunction, VICRegLCollateFunction, ) +from lightly.transforms import RandomRotate class TestDataCollate(unittest.TestCase): diff --git a/tests/data/test_multi_view_collate.py b/tests/data/test_multi_view_collate.py index 37f9edbab..4dcc2bc5b 100644 --- a/tests/data/test_multi_view_collate.py +++ b/tests/data/test_multi_view_collate.py @@ -1,7 +1,9 @@ -import torch -from torch import Tensor from typing import List, Tuple, Union from warnings import warn + +import torch +from torch import Tensor + from lightly.data.multi_view_collate import MultiViewCollate diff --git a/tests/embedding/test_callbacks.py b/tests/embedding/test_callbacks.py index 5f7917830..dc57bbabe 100644 --- a/tests/embedding/test_callbacks.py +++ b/tests/embedding/test_callbacks.py @@ -1,5 +1,5 @@ -from omegaconf import OmegaConf import pytest +from omegaconf import OmegaConf from lightly.embedding import callbacks diff --git a/tests/embedding/test_embedding.py b/tests/embedding/test_embedding.py index ec1543805..653c9e6b0 100644 --- a/tests/embedding/test_embedding.py +++ b/tests/embedding/test_embedding.py @@ -1,12 +1,12 @@ import os import tempfile import unittest -from typing import Tuple, List +from typing import List, Tuple import numpy as np -import torchvision -from hydra.experimental import initialize, compose import torch +import torchvision +from hydra.experimental import compose, initialize from torch import manual_seed from torch.utils.data import DataLoader @@ -17,13 +17,13 @@ class TestLightlyDataset(unittest.TestCase): def setUp(self): self.folder_path, self.sample_names = self.create_dataset_no_subdir(10) - with initialize(config_path='../../lightly/cli/config', job_name='test_app'): + with initialize(config_path="../../lightly/cli/config", job_name="test_app"): self.cfg = compose( - config_name='config', + config_name="config", overrides=[ 'token="123"', - f'input_dir={self.folder_path}', - 'trainer.max_epochs=0', + f"input_dir={self.folder_path}", + "trainer.max_epochs=0", ], ) @@ -31,7 +31,7 @@ def create_dataset_no_subdir(self, n_samples: int) -> Tuple[str, List[str]]: dataset = torchvision.datasets.FakeData(size=n_samples, image_size=(3, 32, 32)) tmp_dir = tempfile.mkdtemp() - sample_names = [f'img_{i}.jpg' for i in range(n_samples)] + sample_names = [f"img_{i}.jpg" for i in range(n_samples)] for sample_idx in range(n_samples): data = dataset[sample_idx] path = os.path.join(tmp_dir, sample_names[sample_idx]) @@ -44,9 +44,9 @@ def test_embed_correct_order(self): dataset = LightlyDataset(self.folder_path, transform=transform) encoder = get_model_from_config(self.cfg) if torch.cuda.is_available(): - device 
= torch.device('cuda') + device = torch.device("cuda") else: - device = torch.device('cpu') + device = torch.device("cpu") manual_seed(42) dataloader_1_worker = DataLoader( @@ -62,7 +62,7 @@ def test_embed_correct_order(self): dataset, shuffle=True, num_workers=4, batch_size=4 ) embeddings_4_worker, labels_4_worker, filenames_4_worker = encoder.embed( - dataloader_4_worker, + dataloader_4_worker, device=device, ) diff --git a/tests/loss/test_CO2Regularizer.py b/tests/loss/test_CO2Regularizer.py index 90c97b6e1..7b56efcb1 100644 --- a/tests/loss/test_CO2Regularizer.py +++ b/tests/loss/test_CO2Regularizer.py @@ -1,26 +1,25 @@ import unittest + import torch from lightly.loss.regularizer import CO2Regularizer -class TestCO2Regularizer(unittest.TestCase): +class TestCO2Regularizer(unittest.TestCase): def test_forward_pass_no_memory_bank(self): reg = CO2Regularizer(memory_bank_size=0) for bsz in range(1, 20): - batch_1 = torch.randn((bsz, 32)) batch_2 = torch.randn((bsz, 32)) # symmetry l1 = reg(batch_1, batch_2) l2 = reg(batch_2, batch_1) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) + self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) def test_forward_pass_memory_bank(self): reg = CO2Regularizer(memory_bank_size=4096) for bsz in range(1, 20): - batch_1 = torch.randn((bsz, 32)) batch_2 = torch.randn((bsz, 32)) @@ -33,15 +32,13 @@ def test_forward_pass_cuda_no_memory_bank(self): reg = CO2Regularizer(memory_bank_size=0) for bsz in range(1, 20): - batch_1 = torch.randn((bsz, 32)).cuda() batch_2 = torch.randn((bsz, 32)).cuda() # symmetry l1 = reg(batch_1, batch_2) l2 = reg(batch_2, batch_1) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) - + self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) def test_forward_pass_cuda_memory_bank(self): if not torch.cuda.is_available(): @@ -49,11 +46,9 @@ def test_forward_pass_cuda_memory_bank(self): reg = CO2Regularizer(memory_bank_size=4096) for bsz in range(1, 20): - batch_1 = torch.randn((bsz, 32)).cuda() batch_2 = torch.randn((bsz, 32)).cuda() # symmetry l1 = reg(batch_1, batch_2) self.assertGreater(l1.cpu().item(), 0) - diff --git a/tests/loss/test_DCLLoss.py b/tests/loss/test_DCLLoss.py index f8fbfa5dd..efb4fa798 100644 --- a/tests/loss/test_DCLLoss.py +++ b/tests/loss/test_DCLLoss.py @@ -1,4 +1,5 @@ import unittest + import torch from lightly.loss.dcl_loss import DCLLoss, DCLWLoss, negative_mises_fisher_weights @@ -30,7 +31,7 @@ def test_dclloss_forward(self, seed=0): weight_fn=weight_fn, ): criterion = DCLLoss( - temperature=temperature, + temperature=temperature, gather_distributed=gather_distributed, weight_fn=weight_fn, ) @@ -38,7 +39,7 @@ def test_dclloss_forward(self, seed=0): loss1 = criterion(out1, out0) self.assertGreater(loss0, 0) self.assertAlmostEqual(loss0, loss1) - + def test_dclloss_backprop(self, seed=0): torch.manual_seed(seed=seed) out0 = torch.rand(3, 5) diff --git a/tests/loss/test_DINOLoss.py b/tests/loss/test_DINOLoss.py index 4c241c19f..97b550e62 100644 --- a/tests/loss/test_DINOLoss.py +++ b/tests/loss/test_DINOLoss.py @@ -4,8 +4,8 @@ import numpy as np import torch -from torch import nn import torch.nn.functional as F +from torch import nn from lightly.loss import DINOLoss from lightly.models.utils import deactivate_requires_grad @@ -19,11 +19,20 @@ class OriginalDINOLoss(nn.Module): longer assumed. 
Source: https://github.com/facebookresearch/dino/blob/cb711401860da580817918b9167ed73e3eef3dcf/main_dino.py#L363 - + """ - def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, - warmup_teacher_temp_epochs, nepochs, student_temp=0.1, - center_momentum=0.9): + + def __init__( + self, + out_dim, + ncrops, + warmup_teacher_temp, + teacher_temp, + warmup_teacher_temp_epochs, + nepochs, + student_temp=0.1, + center_momentum=0.9, + ): super().__init__() self.student_temp = student_temp self.center_momentum = center_momentum @@ -31,11 +40,14 @@ def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, self.register_buffer("center", torch.zeros(1, out_dim)) # we apply a warm up for the teacher temperature because # a too high temperature makes the training instable at the beginning - self.teacher_temp_schedule = np.concatenate(( - np.linspace(warmup_teacher_temp, - teacher_temp, warmup_teacher_temp_epochs), - np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp - )) + self.teacher_temp_schedule = np.concatenate( + ( + np.linspace( + warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs + ), + np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp, + ) + ) def forward(self, student_output, teacher_output, epoch): """ @@ -72,10 +84,12 @@ def update_center(self, teacher_output): batch_center = torch.sum(teacher_output, dim=0, keepdim=True) batch_center = batch_center / len(teacher_output) # ema update - self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) + self.center = self.center * self.center_momentum + batch_center * ( + 1 - self.center_momentum + ) -class TestDINOLoss(unittest.TestCase): +class TestDINOLoss(unittest.TestCase): def generate_output(self, batch_size=2, n_views=3, output_dim=4, seed=0): """Returns a list of view representations. 
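A reading aid for the teacher_temp_schedule built in the hunk above: it warms the teacher temperature linearly for warmup_teacher_temp_epochs epochs and then holds it constant. A minimal standalone sketch, with illustrative values that are assumptions rather than numbers from this PR:

import numpy as np

warmup_teacher_temp = 0.04  # assumed value for illustration
teacher_temp = 0.07  # assumed value for illustration
warmup_teacher_temp_epochs = 3
nepochs = 10

teacher_temp_schedule = np.concatenate(
    (
        np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
        np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp,
    )
)
print(teacher_temp_schedule)
# [0.04  0.055 0.07  0.07  0.07  0.07  0.07  0.07  0.07  0.07]

Indexing this array with the current epoch, as OriginalDINOLoss.forward does, yields the temperature used to sharpen the teacher output for that epoch.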
@@ -84,7 +98,7 @@ def generate_output(self, batch_size=2, n_views=3, output_dim=4, seed=0): torch.Tensor([img0_view0, img1_view0]), torch.Tensor([img0_view1, img1_view1]) ] - + """ torch.manual_seed(seed) out = [] @@ -93,13 +107,11 @@ def generate_output(self, batch_size=2, n_views=3, output_dim=4, seed=0): out.append(torch.stack(views)) return out - def test_dino_loss_equal_to_original(self): - def test( batch_size=3, - n_global=2, # number of global views - n_local=6, # number of local views + n_global=2, # number of global views + n_local=6, # number of local views output_dim=4, warmup_teacher_temp=0.04, teacher_temp=0.04, @@ -108,17 +120,17 @@ def test( center_momentum=0.9, epoch=0, n_epochs=100, - ): + ): """Runs test with the given input parameters.""" with self.subTest( - f'batch_size={batch_size}, n_global={n_global}, ' - f'n_local={n_local}, output_dim={output_dim}, ' - f'warmup_teacher_temp={warmup_teacher_temp}, ' - f'teacher_temp={teacher_temp}, ' - f'warmup_teacher_temp_epochs={warmup_teacher_temp_epochs}, ' - f'student_temp={student_temp}, ' - f'center_momentum={center_momentum}, epoch={epoch}, ' - f'n_epochs={n_epochs}' + f"batch_size={batch_size}, n_global={n_global}, " + f"n_local={n_local}, output_dim={output_dim}, " + f"warmup_teacher_temp={warmup_teacher_temp}, " + f"teacher_temp={teacher_temp}, " + f"warmup_teacher_temp_epochs={warmup_teacher_temp_epochs}, " + f"student_temp={student_temp}, " + f"center_momentum={center_momentum}, epoch={epoch}, " + f"n_epochs={n_epochs}" ): loss_fn = DINOLoss( output_dim=output_dim, @@ -128,7 +140,7 @@ def test( student_temp=student_temp, center_momentum=center_momentum, ) - + orig_loss_fn = OriginalDINOLoss( out_dim=output_dim, ncrops=n_global + n_local, @@ -139,9 +151,9 @@ def test( student_temp=student_temp, center_momentum=center_momentum, ) - + # Create dummy single layer network. We use this to verify - # that the gradient backprop works properly. + # that the gradient backprop works properly. 
teacher = torch.nn.Linear(output_dim, output_dim) deactivate_requires_grad(teacher) student = torch.nn.Linear(output_dim, output_dim) @@ -151,7 +163,7 @@ def test( optimizer = torch.optim.SGD(student.parameters(), lr=1) orig_optimizer = torch.optim.SGD(orig_student.parameters(), lr=1) - # Create fake output + # Create fake output teacher_out = self.generate_output( batch_size=batch_size, n_views=n_global, @@ -161,31 +173,31 @@ def test( student_out = self.generate_output( batch_size=batch_size, n_views=n_global + n_local, - output_dim=output_dim, + output_dim=output_dim, seed=1, ) - # Clone input tensors + # Clone input tensors orig_teacher_out = torch.cat(teacher_out) orig_teacher_out = orig_teacher_out.detach().clone() orig_student_out = torch.cat(student_out) orig_student_out = orig_student_out.detach().clone() - # Forward pass + # Forward pass teacher_out = [teacher(view) for view in teacher_out] student_out = [student(view) for view in student_out] orig_teacher_out = orig_teacher(orig_teacher_out) orig_student_out = orig_student(orig_student_out) - + # Calculate loss loss = loss_fn( - teacher_out=teacher_out, - student_out=student_out, + teacher_out=teacher_out, + student_out=student_out, epoch=epoch, ) orig_loss = orig_loss_fn( - student_output=orig_student_out, - teacher_output=orig_teacher_out, + student_output=orig_student_out, + teacher_output=orig_teacher_out, epoch=epoch, ) @@ -203,19 +215,22 @@ def test( self.assertTrue(torch.allclose(center, orig_center)) self.assertTrue(torch.allclose(loss, orig_loss)) - # Parameters of network should be equal after backward pass - for param, orig_param in zip(student.parameters(), orig_student.parameters()): + # Parameters of network should be equal after backward pass + for param, orig_param in zip( + student.parameters(), orig_student.parameters() + ): self.assertTrue(torch.allclose(param, orig_param)) - for param, orig_param in zip(teacher.parameters(), orig_teacher.parameters()): + for param, orig_param in zip( + teacher.parameters(), orig_teacher.parameters() + ): self.assertTrue(torch.allclose(param, orig_param)) - def test_all(**kwargs): """Tests all combinations of the input parameters""" parameters = [] for name, values in kwargs.items(): parameters.append([(name, value) for value in values]) - # parameters = [ + # parameters = [ # [(param1, val11), (param1, val12), ..], # [(param2, val21), (param2, val22), ..], # ... @@ -224,14 +239,14 @@ def test_all(**kwargs): for params in itertools.product(*parameters): # params = [(param1, value1), (param2, value2), ...] 
test(**dict(params)) - - # test input sizes + + # test input sizes test_all( - batch_size=np.arange(1,4), + batch_size=np.arange(1, 4), n_local=np.arange(0, 4), output_dim=np.arange(1, 4), ) - # test teacher temp warmup + # test teacher temp warmup test_all( warmup_teacher_temp=[0.01, 0.04, 0.07], teacher_temp=[0.01, 0.04, 0.07], diff --git a/tests/loss/test_HyperSphere.py b/tests/loss/test_HyperSphere.py index d3c6c6841..158fc8f55 100644 --- a/tests/loss/test_HyperSphere.py +++ b/tests/loss/test_HyperSphere.py @@ -1,16 +1,15 @@ import unittest + import torch from lightly.loss.hypersphere_loss import HypersphereLoss class TestHyperSphereLoss(unittest.TestCase): - def test_forward_pass(self): loss = HypersphereLoss() # NOTE: skipping bsz==1 case as its not relevant to this loss, and will produce nan-values for bsz in range(2, 20): - batch_1 = torch.randn((bsz, 32)) batch_2 = torch.randn((bsz, 32)) @@ -18,4 +17,4 @@ def test_forward_pass(self): l1 = loss(batch_1, batch_2) l2 = loss(batch_2, batch_1) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) + self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) diff --git a/tests/loss/test_MSNLoss.py b/tests/loss/test_MSNLoss.py index 34b17aef3..034ef1631 100644 --- a/tests/loss/test_MSNLoss.py +++ b/tests/loss/test_MSNLoss.py @@ -1,16 +1,17 @@ import unittest from unittest import TestCase + import torch -from torch import nn import torch.nn.functional as F +from torch import nn from torch.optim import SGD -from lightly.models.modules.heads import MSNProjectionHead from lightly.loss import msn_loss from lightly.loss.msn_loss import MSNLoss +from lightly.models.modules.heads import MSNProjectionHead -class TestMSNLoss(TestCase): +class TestMSNLoss(TestCase): def test_prototype_probabilitiy(self, seed=0): torch.manual_seed(seed) queries = F.normalize(torch.rand((8, 10)), dim=1) @@ -20,7 +21,7 @@ def test_prototype_probabilitiy(self, seed=0): self.assertLessEqual(prob.max(), 1.0) self.assertGreater(prob.min(), 0.0) - # verify sharpening + # verify sharpening prob1 = msn_loss.prototype_probabilities(queries, prototypes, temperature=0.1) # same prototypes should be assigned regardless of temperature self.assertTrue(torch.all(prob.argmax(dim=1) == prob1.argmax(dim=1))) @@ -34,7 +35,7 @@ def test_sharpen(self, seed=0): p1 = msn_loss.sharpen(prob, temperature=0.1) # indices of max probabilities should be the same regardless of temperature self.assertTrue(torch.all(p0.argmax(dim=1) == p1.argmax(dim=1))) - # max probabilities should be higher for lower temperature + # max probabilities should be higher for lower temperature self.assertTrue(torch.all(p0.max(dim=1)[0] < p1.max(dim=1)[0])) def test_sinkhorn(self, seed=0): @@ -76,7 +77,7 @@ def test_backward(self, seed=0): optimizer = SGD(head.parameters(), lr=0.1) anchors = torch.rand((8 * 4, 5)) targets = torch.rand((8, 5)) - prototypes = nn.Linear(6, 4).weight # 4 prototypes with dim 6 + prototypes = nn.Linear(6, 4).weight # 4 prototypes with dim 6 optimizer.zero_grad() anchors = head(anchors) with torch.no_grad(): @@ -93,12 +94,12 @@ def test_backward(self, seed=0): def test_backward_cuda(self, seed=0): torch.manual_seed(seed) head = MSNProjectionHead(5, 16, 6) - head.to('cuda') + head.to("cuda") criterion = MSNLoss() optimizer = SGD(head.parameters(), lr=0.1) anchors = torch.rand((8 * 4, 5)).cuda() targets = torch.rand((8, 5)).cuda() - prototypes = nn.Linear(6, 4).weight.cuda() # 4 prototypes with dim 6 + prototypes = nn.Linear(6, 4).weight.cuda() # 4 prototypes with dim 6 optimizer.zero_grad() anchors = 
head(anchors) with torch.no_grad(): diff --git a/tests/loss/test_MemoryBank.py b/tests/loss/test_MemoryBank.py index fa89d9278..dc1df7ca0 100644 --- a/tests/loss/test_MemoryBank.py +++ b/tests/loss/test_MemoryBank.py @@ -1,11 +1,11 @@ import unittest + import torch from lightly.loss.memory_bank import MemoryBankModule class TestNTXentLoss(unittest.TestCase): - def test_init__negative_size(self): with self.assertRaises(ValueError): MemoryBankModule(size=-1) @@ -18,7 +18,6 @@ def test_forward_easy(self): ptr = 0 for i in range(0, n, bsz): - output = torch.randn(2 * bsz, dim) output.requires_grad = True out0, out1 = output[:bsz], output[bsz:] @@ -26,8 +25,8 @@ def test_forward_easy(self): _, curr_memory_bank = memory_bank(out1, update=True) next_memory_bank = memory_bank.bank - curr_diff = out0.T - curr_memory_bank[:, ptr:ptr + bsz] - next_diff = out1.T - next_memory_bank[:, ptr:ptr + bsz] + curr_diff = out0.T - curr_memory_bank[:, ptr : ptr + bsz] + next_diff = out1.T - next_memory_bank[:, ptr : ptr + bsz] # the current memory bank should not hold the batch yet self.assertGreater(curr_diff.norm(), 1e-5) @@ -43,7 +42,6 @@ def test_forward(self): memory_bank = MemoryBankModule(size=size) for i in range(0, n, bsz): - # see if there are any problems when the bank size # is no multiple of the batch size output = torch.randn(bsz, dim) @@ -55,11 +53,10 @@ def test_forward__cuda(self): dim, size = 2, 10 n = 33 * bsz memory_bank = MemoryBankModule(size=size) - device = torch.device('cuda') + device = torch.device("cuda") memory_bank.to(device=device) for i in range(0, n, bsz): - # see if there are any problems when the bank size # is no multiple of the batch size output = torch.randn(bsz, dim, device=device) diff --git a/tests/loss/test_NTXentLoss.py b/tests/loss/test_NTXentLoss.py index ea8470db8..22351094b 100644 --- a/tests/loss/test_NTXentLoss.py +++ b/tests/loss/test_NTXentLoss.py @@ -7,7 +7,6 @@ class TestNTXentLoss(unittest.TestCase): - def test_with_values(self): for n_samples in [1, 2, 4]: for dimension in [1, 2, 16, 64]: @@ -15,18 +14,27 @@ def test_with_values(self): for gather_distributed in [False, True]: out0 = np.random.normal(0, 1, size=(n_samples, dimension)) out1 = np.random.normal(0, 1, size=(n_samples, dimension)) - with self.subTest(msg=( + with self.subTest( + msg=( f"out0.shape={out0.shape}, temperature={temperature}, " f"gather_distributed={gather_distributed}" - )): + ) + ): out0 = torch.FloatTensor(out0) out1 = torch.FloatTensor(out1) - loss_function = NTXentLoss(temperature=temperature, gather_distributed=gather_distributed) + loss_function = NTXentLoss( + temperature=temperature, + gather_distributed=gather_distributed, + ) l1 = float(loss_function(out0, out1)) l2 = float(loss_function(out1, out0)) - l1_manual = self.calc_ntxent_loss_manual(out0, out1, temperature=temperature) - l2_manual = self.calc_ntxent_loss_manual(out0, out1, temperature=temperature) + l1_manual = self.calc_ntxent_loss_manual( + out0, out1, temperature=temperature + ) + l2_manual = self.calc_ntxent_loss_manual( + out0, out1, temperature=temperature + ) self.assertAlmostEqual(l1, l2, places=5) self.assertAlmostEqual(l1, l1_manual, places=5) self.assertAlmostEqual(l2, l2_manual, places=5) @@ -45,7 +53,9 @@ def calc_ntxent_loss_manual(self, out0, out1, temperature: float) -> float: s_i_j = np.zeros((2 * len(out0), 2 * len(out1))) for i in range(2 * N): for j in range(2 * N): - sim = np.inner(z[i], z[j]) / (np.linalg.norm(z[i]) * np.linalg.norm(z[j])) + sim = np.inner(z[i], z[j]) / ( + 
np.linalg.norm(z[i]) * np.linalg.norm(z[j]) + ) s_i_j[i, j] = sim exponential_i_j = np.exp(s_i_j / temperature) @@ -63,7 +73,7 @@ def calc_ntxent_loss_manual(self, out0, out1, temperature: float) -> float: loss = 0 for k in range(N): loss += l_i_j[k, k + N] + l_i_j[k + N, k] - loss /= (2 * N) + loss /= 2 * N return loss def test_with_correlated_embedding(self): @@ -79,11 +89,16 @@ def test_with_correlated_embedding(self): out1 = torch.FloatTensor(out1) out0.requires_grad = True - with self.subTest(msg=( + with self.subTest( + msg=( f"n_samples: {n_samples}, memory_bank_size: {memory_bank_size}," f"temperature: {temperature}, gather_distributed: {gather_distributed}" - )): - loss_function = NTXentLoss(temperature=temperature, memory_bank_size=memory_bank_size) + ) + ): + loss_function = NTXentLoss( + temperature=temperature, + memory_bank_size=memory_bank_size, + ) if memory_bank_size > 0: for i in range(int(memory_bank_size / n_samples) + 2): # fill the memory bank over multiple rounds @@ -103,7 +118,7 @@ def test_forward_pass(self): # symmetry l1 = loss(batch_1, batch_2) l2 = loss(batch_2, batch_1) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) + self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) def test_forward_pass_1d(self): loss = NTXentLoss(memory_bank_size=0) @@ -114,10 +129,10 @@ def test_forward_pass_1d(self): # symmetry l1 = loss(batch_1, batch_2) l2 = loss(batch_2, batch_1) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) + self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) def test_forward_pass_neg_temp(self): - loss = NTXentLoss(temperature=-1., memory_bank_size=0) + loss = NTXentLoss(temperature=-1.0, memory_bank_size=0) for bsz in range(1, 20): batch_1 = torch.randn((bsz, 32)) batch_2 = torch.randn((bsz, 32)) @@ -125,7 +140,7 @@ def test_forward_pass_neg_temp(self): # symmetry l1 = loss(batch_1, batch_2) l2 = loss(batch_2, batch_1) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) + self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) def test_forward_pass_memory_bank(self): loss = NTXentLoss(memory_bank_size=64) @@ -134,7 +149,7 @@ def test_forward_pass_memory_bank(self): batch_2 = torch.randn((bsz, 32)) l = loss(batch_1, batch_2) - @unittest.skipUnless(torch.cuda.is_available(), 'No cuda') + @unittest.skipUnless(torch.cuda.is_available(), "No cuda") def test_forward_pass_memory_bank_cuda(self): loss = NTXentLoss(memory_bank_size=64) for bsz in range(1, 20): @@ -142,7 +157,7 @@ def test_forward_pass_memory_bank_cuda(self): batch_2 = torch.randn((bsz, 32)).cuda() l = loss(batch_1, batch_2) - @unittest.skipUnless(torch.cuda.is_available(), 'No cuda') + @unittest.skipUnless(torch.cuda.is_available(), "No cuda") def test_forward_pass_cuda(self): loss = NTXentLoss(memory_bank_size=0) for bsz in range(1, 20): @@ -152,4 +167,4 @@ def test_forward_pass_cuda(self): # symmetry l1 = loss(batch_1, batch_2) l2 = loss(batch_2, batch_1) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) 
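For reference, calc_ntxent_loss_manual above is a direct transcription of the NT-Xent objective from SimCLR (Chen et al., 2020). In the notation of the code, with z the 2N stacked embeddings and \mathrm{sim}(u, v) = u \cdot v / (\lVert u \rVert \lVert v \rVert):

\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k \neq i} \exp(\mathrm{sim}(z_i, z_k)/\tau)},
\qquad
L = \frac{1}{2N} \sum_{k=1}^{N} \left( \ell_{k,\,k+N} + \ell_{k+N,\,k} \right)

which is exactly the row-normalized exponential matrix l_i_j in the helper, summed over the positive pairs (k, k+N) and divided by 2N.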
+ self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) diff --git a/tests/loss/test_NegativeCosineSimilarity.py b/tests/loss/test_NegativeCosineSimilarity.py index fb1edc2f4..44d29522a 100644 --- a/tests/loss/test_NegativeCosineSimilarity.py +++ b/tests/loss/test_NegativeCosineSimilarity.py @@ -1,4 +1,5 @@ import unittest + import torch from lightly.loss import NegativeCosineSimilarity diff --git a/tests/loss/test_SwaVLoss.py b/tests/loss/test_SwaVLoss.py index c4c868362..50d1009b3 100644 --- a/tests/loss/test_SwaVLoss.py +++ b/tests/loss/test_SwaVLoss.py @@ -7,9 +7,7 @@ class TestSwaVLoss(unittest.TestCase): - def test_forward_pass(self): - n = 32 n_high_res = 2 high_res = [torch.eye(32, 32) for i in range(n_high_res)] @@ -18,15 +16,15 @@ def test_forward_pass(self): for sinkhorn_iterations in range(3): criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) low_res = [torch.eye(n, n) for i in range(n_low_res)] - - with self.subTest(msg=f'n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}'): + + with self.subTest( + msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}" + ): loss = criterion(high_res, low_res) # loss should be almost zero for unit matrix self.assertGreater(0.5, loss.cpu().numpy()) - def test_forward_pass_queue(self): - n = 32 n_high_res = 2 high_res = [torch.eye(32, 32) for i in range(n_high_res)] @@ -37,14 +35,15 @@ def test_forward_pass_queue(self): for sinkhorn_iterations in range(3): criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) low_res = [torch.eye(n, n) for i in range(n_low_res)] - - with self.subTest(msg=f'n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}'): + + with self.subTest( + msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}" + ): loss = criterion(high_res, low_res, queue) # loss should be almost zero for unit matrix self.assertGreater(0.5, loss.cpu().numpy()) def test_forward_pass_bsz_1(self): - n = 32 n_high_res = 2 high_res = [torch.eye(1, n) for i in range(n_high_res)] @@ -53,8 +52,10 @@ def test_forward_pass_bsz_1(self): for sinkhorn_iterations in range(3): criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) low_res = [torch.eye(1, n) for i in range(n_low_res)] - - with self.subTest(msg=f'n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}'): + + with self.subTest( + msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}" + ): loss = criterion(high_res, low_res) def test_forward_pass_1d(self): @@ -66,8 +67,10 @@ def test_forward_pass_1d(self): for sinkhorn_iterations in range(3): criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) low_res = [torch.eye(n, 1) for i in range(n_low_res)] - - with self.subTest(msg=f'n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}'): + + with self.subTest( + msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}" + ): loss = criterion(high_res, low_res) # loss should be almost zero for unit matrix self.assertGreater(0.5, loss.cpu().numpy()) @@ -82,8 +85,10 @@ def test_forward_pass_cuda(self): for sinkhorn_iterations in range(3): criterion = SwaVLoss(sinkhorn_iterations=sinkhorn_iterations) low_res = [torch.eye(n, n).cuda() for i in range(n_low_res)] - - with self.subTest(msg=f'n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}'): + + with self.subTest( + msg=f"n_low_res={n_low_res}, sinkhorn_iterations={sinkhorn_iterations}" + ): loss = criterion(high_res, low_res) # loss should be almost zero for unit matrix self.assertGreater(0.5, 
loss.cpu().numpy()) diff --git a/tests/loss/test_SymNegCosineSimilarityLoss.py b/tests/loss/test_SymNegCosineSimilarityLoss.py index c3b052676..05fbf44e9 100644 --- a/tests/loss/test_SymNegCosineSimilarityLoss.py +++ b/tests/loss/test_SymNegCosineSimilarityLoss.py @@ -1,4 +1,5 @@ import unittest + import torch from lightly.loss import SymNegCosineSimilarityLoss @@ -8,7 +9,6 @@ class TestSymNegCosineSimilarityLoss(unittest.TestCase): def test_forward_pass(self): loss = SymNegCosineSimilarityLoss() for bsz in range(1, 20): - z0 = torch.randn((bsz, 32)) p0 = torch.randn((bsz, 32)) z1 = torch.randn((bsz, 32)) @@ -17,8 +17,7 @@ def test_forward_pass(self): # symmetry l1 = loss((z0, p0), (z1, p1)) l2 = loss((z1, p1), (z0, p0)) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) - + self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) def test_forward_pass_cuda(self): if not torch.cuda.is_available(): @@ -26,7 +25,6 @@ def test_forward_pass_cuda(self): loss = SymNegCosineSimilarityLoss() for bsz in range(1, 20): - z0 = torch.randn((bsz, 32)).cuda() p0 = torch.randn((bsz, 32)).cuda() z1 = torch.randn((bsz, 32)).cuda() @@ -35,20 +33,18 @@ def test_forward_pass_cuda(self): # symmetry l1 = loss((z0, p0), (z1, p1)) l2 = loss((z1, p1), (z0, p0)) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) - + self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) def test_neg_cosine_simililarity(self): loss = SymNegCosineSimilarityLoss() for bsz in range(1, 20): - x = torch.randn((bsz, 32)) y = torch.randn((bsz, 32)) # symmetry l1 = loss._neg_cosine_simililarity(x, y) l2 = loss._neg_cosine_simililarity(y, x) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) + self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) def test_neg_cosine_simililarity_cuda(self): if not torch.cuda.is_available(): @@ -56,11 +52,10 @@ def test_neg_cosine_simililarity_cuda(self): loss = SymNegCosineSimilarityLoss() for bsz in range(1, 20): - x = torch.randn((bsz, 32)).cuda() y = torch.randn((bsz, 32)).cuda() # symmetry l1 = loss._neg_cosine_simililarity(x, y) l2 = loss._neg_cosine_simililarity(y, x) - self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.) 
+ self.assertAlmostEqual((l1 - l2).pow(2).item(), 0.0) diff --git a/tests/loss/test_TicoLoss.py b/tests/loss/test_TicoLoss.py index 863189ccc..c06647562 100644 --- a/tests/loss/test_TicoLoss.py +++ b/tests/loss/test_TicoLoss.py @@ -1,8 +1,10 @@ import unittest + import torch from lightly.loss.tico_loss import TiCoLoss + class TestTiCoLoss(unittest.TestCase): def test_forward_pass(self): torch.manual_seed(0) @@ -43,4 +45,4 @@ def test_forward_pass__error_different_shapes(self): x0 = torch.randn((2, 32)) x1 = torch.randn((2, 16)) with self.assertRaises(AssertionError): - loss(x0, x1, update_covariance_matrix=False) \ No newline at end of file + loss(x0, x1, update_covariance_matrix=False) diff --git a/tests/loss/test_VICRegLLoss.py b/tests/loss/test_VICRegLLoss.py index d22d5c860..87755a203 100644 --- a/tests/loss/test_VICRegLLoss.py +++ b/tests/loss/test_VICRegLLoss.py @@ -1,10 +1,11 @@ import unittest + import torch from lightly.loss import VICRegLLoss -class TestVICRegLLoss(unittest.TestCase): +class TestVICRegLLoss(unittest.TestCase): def test_forward_pass(self): loss = VICRegLLoss() x0 = torch.randn((2, 32)) @@ -27,7 +28,6 @@ def test_forward_pass_cuda(self): grid1 = torch.randn((2, 7, 7, 2)).cuda() assert loss(x0, x1, x0_L, x1_L, grid0, grid1) - def test_forward_pass__error_batch_size_1(self): loss = VICRegLLoss() x0 = torch.randn((1, 32)) diff --git a/tests/loss/test_VICRegLoss.py b/tests/loss/test_VICRegLoss.py index 3471ebdee..dea6b480d 100644 --- a/tests/loss/test_VICRegLoss.py +++ b/tests/loss/test_VICRegLoss.py @@ -1,8 +1,10 @@ import unittest + import torch from lightly.loss import VICRegLoss + class TestVICRegLoss(unittest.TestCase): def test_forward_pass(self): loss = VICRegLoss() diff --git a/tests/models/modules/test_masked_autoencoder.py b/tests/models/modules/test_masked_autoencoder.py index 373eb521b..8479ed26a 100644 --- a/tests/models/modules/test_masked_autoencoder.py +++ b/tests/models/modules/test_masked_autoencoder.py @@ -1,11 +1,14 @@ import unittest + import torch import torchvision + from lightly import _torchvision_vit_available from lightly.models import utils if _torchvision_vit_available: - from lightly.models.modules import MAEEncoder, MAEDecoder, MAEBackbone + from lightly.models.modules import MAEBackbone, MAEDecoder, MAEEncoder + @unittest.skipUnless(_torchvision_vit_available, "Torchvision ViT not available") class TestMAEEncoder(unittest.TestCase): @@ -34,15 +37,15 @@ def _test_forward(self, device, batch_size=8, seed=0): expected_shape[1] = idx_keep.shape[1] self.assertListEqual(list(out.shape), expected_shape) - # output must have reasonable numbers + # output must have reasonable numbers self.assertTrue(torch.all(torch.not_equal(out, torch.inf))) def test_forward(self): - self._test_forward(torch.device('cpu')) + self._test_forward(torch.device("cpu")) @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.") def test_forward_cuda(self): - self._test_forward(torch.device('cuda')) + self._test_forward(torch.device("cuda")) @unittest.skipUnless(_torchvision_vit_available, "Torchvision ViT not available") @@ -70,19 +73,19 @@ def _test_forward(self, device, batch_size=8, seed=0): expected_shape = [batch_size, vit.hidden_dim] self.assertListEqual(list(class_tokens.shape), expected_shape) - # output must have reasonable numbers + # output must have reasonable numbers self.assertTrue(torch.all(torch.not_equal(class_tokens, torch.inf))) def test_forward(self): - self._test_forward(torch.device('cpu')) + self._test_forward(torch.device("cpu")) 
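A note on the idx_keep tensor exercised by the MAE encoder test above: the expected sequence length shrinks to idx_keep.shape[1] because a masked autoencoder only feeds the kept tokens through the encoder. A shape-level sketch with hypothetical sizes and a gather-based selection (an assumption; lightly's own masking utilities are not shown in this diff):

import torch

batch_size, seq_length, dim, n_keep = 8, 50, 128, 13  # hypothetical sizes
tokens = torch.randn(batch_size, seq_length, dim)
# draw a random permutation per sequence, then keep the first n_keep indices
idx_keep = torch.rand(batch_size, seq_length).argsort(dim=1)[:, :n_keep]
kept = torch.gather(tokens, 1, idx_keep.unsqueeze(-1).expand(-1, -1, dim))
print(kept.shape)  # torch.Size([8, 13, 128])

This mirrors why the test overwrites expected_shape[1] with idx_keep.shape[1] before comparing against the encoder output.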
@unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.") def test_forward_cuda(self): - self._test_forward(torch.device('cuda')) + self._test_forward(torch.device("cuda")) + @unittest.skipUnless(_torchvision_vit_available, "Torchvision ViT not available") class TestMAEDecoder(unittest.TestCase): - def test_init(self): return MAEDecoder( seq_length=50, @@ -91,14 +94,14 @@ def test_init(self): embed_input_dim=128, hidden_dim=256, mlp_dim=256 * 4, - out_dim=3 * 32 ** 2, + out_dim=3 * 32**2, ) def _test_forward(self, device, batch_size=8, seed=0): torch.manual_seed(seed) seq_length = 50 embed_input_dim = 128 - out_dim = 3 * 32 ** 2 + out_dim = 3 * 32**2 decoder = MAEDecoder( seq_length=seq_length, num_layers=2, @@ -115,12 +118,12 @@ def _test_forward(self, device, batch_size=8, seed=0): expected_shape = [batch_size, seq_length, out_dim] self.assertListEqual(list(predictions.shape), expected_shape) - # output must have reasonable numbers + # output must have reasonable numbers self.assertTrue(torch.all(torch.not_equal(predictions, torch.inf))) def test_forward(self): - self._test_forward(torch.device('cpu')) + self._test_forward(torch.device("cpu")) @unittest.skipUnless(torch.cuda.is_available(), "Cuda not available.") def test_forward_cuda(self): - self._test_forward(torch.device('cuda')) + self._test_forward(torch.device("cuda")) diff --git a/tests/models/test_ModelUtils.py b/tests/models/test_ModelUtils.py index f6ef70a39..ce2b99d18 100644 --- a/tests/models/test_ModelUtils.py +++ b/tests/models/test_ModelUtils.py @@ -1,18 +1,20 @@ -import unittest import copy +import unittest import torch import torch.nn as nn from lightly.models import utils -from lightly.models.utils import batch_shuffle -from lightly.models.utils import batch_unshuffle -from lightly.models.utils import activate_requires_grad -from lightly.models.utils import deactivate_requires_grad -from lightly.models.utils import update_momentum -from lightly.models.utils import normalize_weight -from lightly.models.utils import _no_grad_trunc_normal -from lightly.models.utils import nearest_neighbors +from lightly.models.utils import ( + _no_grad_trunc_normal, + activate_requires_grad, + batch_shuffle, + batch_unshuffle, + deactivate_requires_grad, + nearest_neighbors, + normalize_weight, + update_momentum, +) def has_grad(model: nn.Module): @@ -181,7 +183,7 @@ def test_patchify(self, seed=0): ) # make sure that patches are correctly formed - for (image, img_patches) in zip(images, batch_patches): + for image, img_patches in zip(images, batch_patches): for i in range(height_patches): for j in range(width_patches): # extract patch from original image diff --git a/tests/models/test_ModelsBYOL.py b/tests/models/test_ModelsBYOL.py index 9dc547abe..a27223655 100644 --- a/tests/models/test_ModelsBYOL.py +++ b/tests/models/test_ModelsBYOL.py @@ -1,4 +1,3 @@ - import unittest import torch @@ -6,8 +5,7 @@ import torchvision import lightly -from lightly.models import ResNetGenerator -from lightly.models import BYOL +from lightly.models import BYOL, ResNetGenerator def get_backbone(resnet, num_ftrs=64): @@ -22,12 +20,8 @@ def get_backbone(resnet, num_ftrs=64): class TestModelsBYOL(unittest.TestCase): - def setUp(self): - self.resnet_variants = [ - 'resnet-18', - 'resnet-50' - ] + self.resnet_variants = ["resnet-18", "resnet-50"] self.batch_size = 2 self.input_tensor = torch.rand((self.batch_size, 3, 32, 32)) @@ -38,8 +32,8 @@ def test_create_variations_cpu(self): self.assertIsNotNone(model) def test_create_variations_gpu(self): - 
device = 'cuda' if torch.cuda.is_available() else 'cpu' - if device == 'cuda': + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cuda": for model_name in self.resnet_variants: resnet = ResNetGenerator(model_name) model = BYOL(get_backbone(resnet)).to(device) @@ -48,60 +42,49 @@ def test_create_variations_gpu(self): pass def test_feature_dim_configurable(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" for model_name in self.resnet_variants: for num_ftrs, out_dim in zip([16, 64], [64, 256]): resnet = ResNetGenerator(model_name) - model = BYOL(get_backbone(resnet, num_ftrs=num_ftrs), - num_ftrs=num_ftrs, - out_dim=out_dim).to(device) + model = BYOL( + get_backbone(resnet, num_ftrs=num_ftrs), + num_ftrs=num_ftrs, + out_dim=out_dim, + ).to(device) # check that feature vector has correct dimension with torch.no_grad(): - out_features = model.backbone( - self.input_tensor.to(device) - ) + out_features = model.backbone(self.input_tensor.to(device)) self.assertEqual(out_features.shape[1], num_ftrs) # check that projection head output has right dimension with torch.no_grad(): - out_projection = model.projection_head( - out_features.squeeze() - ) + out_projection = model.projection_head(out_features.squeeze()) self.assertEqual(out_projection.shape[1], out_dim) self.assertIsNotNone(model) def test_variations_input_dimension(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" for model_name in self.resnet_variants: for input_width, input_height in zip([32, 64], [64, 64]): resnet = ResNetGenerator(model_name) - model = BYOL( - get_backbone(resnet, num_ftrs=32), - num_ftrs=32 - ).to(device) + model = BYOL(get_backbone(resnet, num_ftrs=32), num_ftrs=32).to(device) - input_tensor = torch.rand((self.batch_size, - 3, - input_height, - input_width)) + input_tensor = torch.rand( + (self.batch_size, 3, input_height, input_width) + ) with torch.no_grad(): - out, _ = model( - input_tensor.to(device), - input_tensor.to(device) - ) + out, _ = model(input_tensor.to(device), input_tensor.to(device)) self.assertIsNotNone(model) self.assertIsNotNone(out) def test_tuple_input(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' - resnet = ResNetGenerator('resnet-18') - model = BYOL( - get_backbone(resnet, num_ftrs=32), - num_ftrs=32, - out_dim=128 - ).to(device) + device = "cuda" if torch.cuda.is_available() else "cpu" + resnet = ResNetGenerator("resnet-18") + model = BYOL(get_backbone(resnet, num_ftrs=32), num_ftrs=32, out_dim=128).to( + device + ) x0 = torch.rand((self.batch_size, 3, 64, 64)).to(device) x1 = torch.rand((self.batch_size, 3, 64, 64)).to(device) @@ -113,8 +96,7 @@ def test_tuple_input(self): self.assertEqual(p1.shape, (self.batch_size, 128)) def test_raises(self): - - resnet = ResNetGenerator('resnet-18') + resnet = ResNetGenerator("resnet-18") model = BYOL(get_backbone(resnet)) x0 = torch.rand((self.batch_size, 3, 64, 64)) @@ -130,5 +112,5 @@ def test_raises(self): model(x0, x1) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/models/test_ModelsMoCo.py b/tests/models/test_ModelsMoCo.py index 9eb721da6..f94dc33da 100644 --- a/tests/models/test_ModelsMoCo.py +++ b/tests/models/test_ModelsMoCo.py @@ -5,8 +5,7 @@ import torchvision import lightly -from lightly.models import ResNetGenerator -from lightly.models import MoCo +from lightly.models import MoCo, ResNetGenerator def 
get_backbone(resnet, num_ftrs=64): @@ -21,12 +20,8 @@ def get_backbone(resnet, num_ftrs=64): class TestModelsMoCo(unittest.TestCase): - def setUp(self): - self.resnet_variants = [ - 'resnet-18', - 'resnet-50' - ] + self.resnet_variants = ["resnet-18", "resnet-50"] self.batch_size = 2 self.input_tensor = torch.rand((self.batch_size, 3, 32, 32)) @@ -37,8 +32,8 @@ def test_create_variations_cpu(self): self.assertIsNotNone(model) def test_create_variations_gpu(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' - if device == 'cuda': + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cuda": for model_name in self.resnet_variants: resnet = ResNetGenerator(model_name) model = MoCo(get_backbone(resnet)).to(device) @@ -47,40 +42,37 @@ def test_create_variations_gpu(self): pass def test_feature_dim_configurable(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" for model_name in self.resnet_variants: for num_ftrs, out_dim in zip([16, 64], [64, 256]): resnet = ResNetGenerator(model_name) - model = MoCo(get_backbone(resnet, num_ftrs=num_ftrs), - num_ftrs=num_ftrs, - out_dim=out_dim).to(device) + model = MoCo( + get_backbone(resnet, num_ftrs=num_ftrs), + num_ftrs=num_ftrs, + out_dim=out_dim, + ).to(device) # check that feature vector has correct dimension with torch.no_grad(): - out_features = model.backbone( - self.input_tensor.to(device) - ) + out_features = model.backbone(self.input_tensor.to(device)) self.assertEqual(out_features.shape[1], num_ftrs) # check that projection head output has right dimension with torch.no_grad(): - out_projection = model.projection_head( - out_features.squeeze() - ) + out_projection = model.projection_head(out_features.squeeze()) self.assertEqual(out_projection.shape[1], out_dim) self.assertIsNotNone(model) def test_variations_input_dimension(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" for model_name in self.resnet_variants: for input_width, input_height in zip([32, 64], [64, 64]): resnet = ResNetGenerator(model_name) model = MoCo(get_backbone(resnet, num_ftrs=32)).to(device) - input_tensor = torch.rand((self.batch_size, - 3, - input_height, - input_width)) + input_tensor = torch.rand( + (self.batch_size, 3, input_height, input_width) + ) with torch.no_grad(): out = model(input_tensor.to(device)) @@ -88,8 +80,8 @@ def test_variations_input_dimension(self): self.assertIsNotNone(out) def test_tuple_input(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' - resnet = ResNetGenerator('resnet-18') + device = "cuda" if torch.cuda.is_available() else "cpu" + resnet = ResNetGenerator("resnet-18") model = MoCo(get_backbone(resnet, num_ftrs=32), out_dim=128).to(device) x0 = torch.rand((self.batch_size, 3, 64, 64)).to(device) @@ -113,5 +105,5 @@ def test_tuple_input(self): self.assertEqual(f1.shape, (self.batch_size, 32)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/models/test_ModelsNNCLR.py b/tests/models/test_ModelsNNCLR.py index 6217758f4..3760faef3 100644 --- a/tests/models/test_ModelsNNCLR.py +++ b/tests/models/test_ModelsNNCLR.py @@ -9,9 +9,9 @@ def resnet_generator(name: str): - if name == 'resnet18': + if name == "resnet18": return torchvision.models.resnet18() - elif name == 'resnet50': + elif name == "resnet50": return torchvision.models.resnet50() raise NotImplementedError @@ -20,22 +20,22 @@ def get_backbone(model: 
nn.Module): backbone = torch.nn.Sequential(*(list(model.children())[:-1])) return backbone -class TestNNCLR(unittest.TestCase): +class TestNNCLR(unittest.TestCase): def setUp(self): self.resnet_variants = dict( - resnet18 = dict( + resnet18=dict( num_ftrs=512, proj_hidden_dim=512, pred_hidden_dim=128, out_dim=512, ), - resnet50 = dict( + resnet50=dict( num_ftrs=2048, proj_hidden_dim=2048, pred_hidden_dim=512, out_dim=2048, - ) + ), ) self.batch_size = 2 self.input_tensor = torch.rand((self.batch_size, 3, 32, 32)) @@ -52,38 +52,32 @@ def test_create_variations_gpu(self): for model_name, config in self.resnet_variants.items(): resnet = resnet_generator(model_name) - model = NNCLR(get_backbone(resnet), **config).to('cuda') + model = NNCLR(get_backbone(resnet), **config).to("cuda") self.assertIsNotNone(model) def test_feature_dim_configurable(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" for model_name, config in self.resnet_variants.items(): resnet = resnet_generator(model_name) model = NNCLR(get_backbone(resnet), **config).to(device) # check that feature vector has correct dimension with torch.no_grad(): - out_features = model.backbone( - self.input_tensor.to(device) - ) - self.assertEqual(out_features.shape[1], config['num_ftrs']) + out_features = model.backbone(self.input_tensor.to(device)) + self.assertEqual(out_features.shape[1], config["num_ftrs"]) # check that projection head output has right dimension with torch.no_grad(): - out_projection = model.projection_mlp( - out_features.squeeze() - ) - self.assertEqual(out_projection.shape[1], config['out_dim']) + out_projection = model.projection_mlp(out_features.squeeze()) + self.assertEqual(out_projection.shape[1], config["out_dim"]) # check that prediction head output has right dimension with torch.no_grad(): - out_prediction = model.prediction_mlp( - out_projection.squeeze() - ) - self.assertEqual(out_prediction.shape[1], config['out_dim']) + out_prediction = model.prediction_mlp(out_projection.squeeze()) + self.assertEqual(out_prediction.shape[1], config["out_dim"]) def test_tuple_input(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" for model_name, config in self.resnet_variants.items(): resnet = resnet_generator(model_name) model = NNCLR(get_backbone(resnet), **config).to(device) @@ -92,35 +86,35 @@ def test_tuple_input(self): x1 = torch.rand((self.batch_size, 3, 64, 64)).to(device) out = model(x0) - self.assertEqual(out[0].shape, (self.batch_size, config['out_dim'])) - self.assertEqual(out[1].shape, (self.batch_size, config['out_dim'])) + self.assertEqual(out[0].shape, (self.batch_size, config["out_dim"])) + self.assertEqual(out[1].shape, (self.batch_size, config["out_dim"])) out, features = model(x0, return_features=True) - self.assertEqual(out[0].shape, (self.batch_size, config['out_dim'])) - self.assertEqual(out[1].shape, (self.batch_size, config['out_dim'])) - self.assertEqual(features.shape, (self.batch_size, config['num_ftrs'])) + self.assertEqual(out[0].shape, (self.batch_size, config["out_dim"])) + self.assertEqual(out[1].shape, (self.batch_size, config["out_dim"])) + self.assertEqual(features.shape, (self.batch_size, config["num_ftrs"])) out0, out1 = model(x0, x1) - self.assertEqual(out0[0].shape, (self.batch_size, config['out_dim'])) - self.assertEqual(out0[1].shape, (self.batch_size, config['out_dim'])) - self.assertEqual(out1[0].shape, (self.batch_size, config['out_dim'])) 
- self.assertEqual(out1[1].shape, (self.batch_size, config['out_dim'])) + self.assertEqual(out0[0].shape, (self.batch_size, config["out_dim"])) + self.assertEqual(out0[1].shape, (self.batch_size, config["out_dim"])) + self.assertEqual(out1[0].shape, (self.batch_size, config["out_dim"])) + self.assertEqual(out1[1].shape, (self.batch_size, config["out_dim"])) (out0, f0), (out1, f1) = model(x0, x1, return_features=True) - self.assertEqual(out0[0].shape, (self.batch_size, config['out_dim'])) - self.assertEqual(out0[1].shape, (self.batch_size, config['out_dim'])) - self.assertEqual(out1[0].shape, (self.batch_size, config['out_dim'])) - self.assertEqual(out1[1].shape, (self.batch_size, config['out_dim'])) - self.assertEqual(f0.shape, (self.batch_size, config['num_ftrs'])) - self.assertEqual(f1.shape, (self.batch_size, config['num_ftrs'])) + self.assertEqual(out0[0].shape, (self.batch_size, config["out_dim"])) + self.assertEqual(out0[1].shape, (self.batch_size, config["out_dim"])) + self.assertEqual(out1[0].shape, (self.batch_size, config["out_dim"])) + self.assertEqual(out1[1].shape, (self.batch_size, config["out_dim"])) + self.assertEqual(f0.shape, (self.batch_size, config["num_ftrs"])) + self.assertEqual(f1.shape, (self.batch_size, config["num_ftrs"])) def test_memory_bank(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" for model_name, config in self.resnet_variants.items(): resnet = resnet_generator(model_name) model = NNCLR(get_backbone(resnet), **config).to(device) - for nn_size in [2 ** 3, 2 ** 8]: + for nn_size in [2**3, 2**8]: nn_replacer = NNMemoryBankModule(size=nn_size) with torch.no_grad(): @@ -129,4 +123,4 @@ def test_memory_bank(self): x1 = torch.rand((self.batch_size, 3, 64, 64)).to(device) (z0, p0), (z1, p1) = model(x0, x1) z0 = nn_replacer(z0.detach(), update=False) - z1 = nn_replacer(z1.detach(), update=True) \ No newline at end of file + z1 = nn_replacer(z1.detach(), update=True) diff --git a/tests/models/test_ModelsSimCLR.py b/tests/models/test_ModelsSimCLR.py index 749e2595c..037cbdf82 100644 --- a/tests/models/test_ModelsSimCLR.py +++ b/tests/models/test_ModelsSimCLR.py @@ -5,8 +5,7 @@ import torchvision import lightly -from lightly.models import ResNetGenerator -from lightly.models import SimCLR +from lightly.models import ResNetGenerator, SimCLR def get_backbone(resnet, num_ftrs=64): @@ -21,12 +20,8 @@ def get_backbone(resnet, num_ftrs=64): class TestModelsSimCLR(unittest.TestCase): - def setUp(self): - self.resnet_variants = [ - 'resnet-18', - 'resnet-50' - ] + self.resnet_variants = ["resnet-18", "resnet-50"] self.batch_size = 2 self.input_tensor = torch.rand((self.batch_size, 3, 32, 32)) @@ -37,8 +32,8 @@ def test_create_variations_cpu(self): self.assertIsNotNone(model) def test_create_variations_gpu(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' - if device == 'cuda': + device = "cuda" if torch.cuda.is_available() else "cpu" + if device == "cuda": for model_name in self.resnet_variants: resnet = ResNetGenerator(model_name) model = SimCLR(get_backbone(resnet)).to(device) @@ -47,40 +42,37 @@ def test_create_variations_gpu(self): pass def test_feature_dim_configurable(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" for model_name in self.resnet_variants: for num_ftrs, out_dim in zip([16, 64], [64, 256]): resnet = ResNetGenerator(model_name) - model = SimCLR(get_backbone(resnet, num_ftrs=num_ftrs), - 
num_ftrs=num_ftrs, - out_dim=out_dim).to(device) + model = SimCLR( + get_backbone(resnet, num_ftrs=num_ftrs), + num_ftrs=num_ftrs, + out_dim=out_dim, + ).to(device) # check that feature vector has correct dimension with torch.no_grad(): - out_features = model.backbone( - self.input_tensor.to(device) - ) + out_features = model.backbone(self.input_tensor.to(device)) self.assertEqual(out_features.shape[1], num_ftrs) # check that projection head output has right dimension with torch.no_grad(): - out_projection = model.projection_head( - out_features.squeeze() - ) + out_projection = model.projection_head(out_features.squeeze()) self.assertEqual(out_projection.shape[1], out_dim) self.assertIsNotNone(model) def test_variations_input_dimension(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" for model_name in self.resnet_variants: for input_width, input_height in zip([32, 64], [64, 64]): resnet = ResNetGenerator(model_name) model = SimCLR(get_backbone(resnet, num_ftrs=32)).to(device) - input_tensor = torch.rand((self.batch_size, - 3, - input_height, - input_width)) + input_tensor = torch.rand( + (self.batch_size, 3, input_height, input_width) + ) with torch.no_grad(): out = model(input_tensor.to(device)) @@ -88,8 +80,8 @@ def test_variations_input_dimension(self): self.assertIsNotNone(out) def test_tuple_input(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' - resnet = ResNetGenerator('resnet-18') + device = "cuda" if torch.cuda.is_available() else "cpu" + resnet = ResNetGenerator("resnet-18") model = SimCLR(get_backbone(resnet, num_ftrs=32), out_dim=128).to(device) x0 = torch.rand((self.batch_size, 3, 64, 64)).to(device) @@ -113,5 +105,5 @@ def test_tuple_input(self): self.assertEqual(f1.shape, (self.batch_size, 32)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/models/test_ModelsSimSiam.py b/tests/models/test_ModelsSimSiam.py index 2da3a3c94..5672812c4 100644 --- a/tests/models/test_ModelsSimSiam.py +++ b/tests/models/test_ModelsSimSiam.py @@ -8,9 +8,9 @@ def resnet_generator(name: str): - if name == 'resnet18': + if name == "resnet18": return torchvision.models.resnet18() - elif name == 'resnet50': + elif name == "resnet50": return torchvision.models.resnet50() raise NotImplementedError @@ -21,21 +21,20 @@ def get_backbone(model: nn.Module): class TestSimSiam(unittest.TestCase): - def setUp(self): self.resnet_variants = dict( - resnet18 = dict( + resnet18=dict( num_ftrs=512, proj_hidden_dim=512, pred_hidden_dim=128, out_dim=512, ), - resnet50 = dict( + resnet50=dict( num_ftrs=2048, proj_hidden_dim=2048, pred_hidden_dim=512, out_dim=2048, - ) + ), ) self.batch_size = 2 self.input_tensor = torch.rand((self.batch_size, 3, 32, 32)) @@ -52,38 +51,32 @@ def test_create_variations_gpu(self): for model_name, config in self.resnet_variants.items(): resnet = resnet_generator(model_name) - model = SimSiam(get_backbone(resnet), **config).to('cuda') + model = SimSiam(get_backbone(resnet), **config).to("cuda") self.assertIsNotNone(model) def test_feature_dim_configurable(self): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" for model_name, config in self.resnet_variants.items(): resnet = resnet_generator(model_name) model = SimSiam(get_backbone(resnet), **config).to(device) # check that feature vector has correct dimension with torch.no_grad(): - out_features = model.backbone( - self.input_tensor.to(device) 
-                )
-            self.assertEqual(out_features.shape[1], config['num_ftrs'])
+                out_features = model.backbone(self.input_tensor.to(device))
+            self.assertEqual(out_features.shape[1], config["num_ftrs"])
 
             # check that projection head output has right dimension
             with torch.no_grad():
-                out_projection = model.projection_mlp(
-                    out_features.squeeze()
-                )
-            self.assertEqual(out_projection.shape[1], config['out_dim'])
+                out_projection = model.projection_mlp(out_features.squeeze())
+            self.assertEqual(out_projection.shape[1], config["out_dim"])
 
             # check that prediction head output has right dimension
             with torch.no_grad():
-                out_prediction = model.prediction_mlp(
-                    out_projection.squeeze()
-                )
-            self.assertEqual(out_prediction.shape[1], config['out_dim'])
+                out_prediction = model.prediction_mlp(out_projection.squeeze())
+            self.assertEqual(out_prediction.shape[1], config["out_dim"])
 
     def test_tuple_input(self):
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        device = "cuda" if torch.cuda.is_available() else "cpu"
         for model_name, config in self.resnet_variants.items():
             resnet = resnet_generator(model_name)
             model = SimSiam(get_backbone(resnet), **config).to(device)
@@ -92,25 +85,24 @@ def test_tuple_input(self):
             x1 = torch.rand((self.batch_size, 3, 64, 64)).to(device)
 
             out = model(x0)
-            self.assertEqual(out[0].shape, (self.batch_size, config['out_dim']))
-            self.assertEqual(out[1].shape, (self.batch_size, config['out_dim']))
+            self.assertEqual(out[0].shape, (self.batch_size, config["out_dim"]))
+            self.assertEqual(out[1].shape, (self.batch_size, config["out_dim"]))
 
             out, features = model(x0, return_features=True)
-            self.assertEqual(out[0].shape, (self.batch_size, config['out_dim']))
-            self.assertEqual(out[1].shape, (self.batch_size, config['out_dim']))
-            self.assertEqual(features.shape, (self.batch_size, config['num_ftrs']))
+            self.assertEqual(out[0].shape, (self.batch_size, config["out_dim"]))
+            self.assertEqual(out[1].shape, (self.batch_size, config["out_dim"]))
+            self.assertEqual(features.shape, (self.batch_size, config["num_ftrs"]))
 
             out0, out1 = model(x0, x1)
-            self.assertEqual(out0[0].shape, (self.batch_size, config['out_dim']))
-            self.assertEqual(out0[1].shape, (self.batch_size, config['out_dim']))
-            self.assertEqual(out1[0].shape, (self.batch_size, config['out_dim']))
-            self.assertEqual(out1[1].shape, (self.batch_size, config['out_dim']))
+            self.assertEqual(out0[0].shape, (self.batch_size, config["out_dim"]))
+            self.assertEqual(out0[1].shape, (self.batch_size, config["out_dim"]))
+            self.assertEqual(out1[0].shape, (self.batch_size, config["out_dim"]))
+            self.assertEqual(out1[1].shape, (self.batch_size, config["out_dim"]))
 
             (out0, f0), (out1, f1) = model(x0, x1, return_features=True)
-            self.assertEqual(out0[0].shape, (self.batch_size, config['out_dim']))
-            self.assertEqual(out0[1].shape, (self.batch_size, config['out_dim']))
-            self.assertEqual(out1[0].shape, (self.batch_size, config['out_dim']))
-            self.assertEqual(out1[1].shape, (self.batch_size, config['out_dim']))
-            self.assertEqual(f0.shape, (self.batch_size, config['num_ftrs']))
-            self.assertEqual(f1.shape, (self.batch_size, config['num_ftrs']))
-
+            self.assertEqual(out0[0].shape, (self.batch_size, config["out_dim"]))
+            self.assertEqual(out0[1].shape, (self.batch_size, config["out_dim"]))
+            self.assertEqual(out1[0].shape, (self.batch_size, config["out_dim"]))
+            self.assertEqual(out1[1].shape, (self.batch_size, config["out_dim"]))
+            self.assertEqual(f0.shape, (self.batch_size, config["num_ftrs"]))
+            self.assertEqual(f1.shape, (self.batch_size, config["num_ftrs"]))
config["num_ftrs"])) diff --git a/tests/models/test_ProjectionHeads.py b/tests/models/test_ProjectionHeads.py index 59e5f8de1..2518d290c 100644 --- a/tests/models/test_ProjectionHeads.py +++ b/tests/models/test_ProjectionHeads.py @@ -3,26 +3,27 @@ import torch import lightly -from lightly.models.modules.heads import BarlowTwinsProjectionHead -from lightly.models.modules.heads import BYOLProjectionHead -from lightly.models.modules.heads import BYOLPredictionHead -from lightly.models.modules.heads import DINOProjectionHead -from lightly.models.modules.heads import MoCoProjectionHead -from lightly.models.modules.heads import MSNProjectionHead -from lightly.models.modules.heads import NNCLRProjectionHead -from lightly.models.modules.heads import NNCLRPredictionHead -from lightly.models.modules.heads import SimCLRProjectionHead -from lightly.models.modules.heads import SimSiamProjectionHead -from lightly.models.modules.heads import SimSiamPredictionHead -from lightly.models.modules.heads import SwaVProjectionHead -from lightly.models.modules.heads import SwaVPrototypes -from lightly.models.modules.heads import VicRegLLocalProjectionHead -from lightly.models.modules.heads import TiCoProjectionHead +from lightly.models.modules.heads import ( + BarlowTwinsProjectionHead, + BYOLPredictionHead, + BYOLProjectionHead, + DINOProjectionHead, + MoCoProjectionHead, + MSNProjectionHead, + NNCLRPredictionHead, + NNCLRProjectionHead, + SimCLRProjectionHead, + SimSiamPredictionHead, + SimSiamProjectionHead, + SwaVProjectionHead, + SwaVPrototypes, + TiCoProjectionHead, + VicRegLLocalProjectionHead, +) class TestProjectionHeads(unittest.TestCase): def setUp(self): - self.n_features = [ (8, 16, 32), (8, 32, 16), diff --git a/tests/transforms/test_GaussianBlur.py b/tests/transforms/test_GaussianBlur.py index f1cbc24a2..1a4823758 100644 --- a/tests/transforms/test_GaussianBlur.py +++ b/tests/transforms/test_GaussianBlur.py @@ -1,6 +1,7 @@ -from PIL import Image import unittest +from PIL import Image + from lightly.transforms import GaussianBlur diff --git a/tests/transforms/test_Jigsaw.py b/tests/transforms/test_Jigsaw.py index 77487afe9..964e738c5 100644 --- a/tests/transforms/test_Jigsaw.py +++ b/tests/transforms/test_Jigsaw.py @@ -1,6 +1,7 @@ -from PIL import Image import unittest +from PIL import Image + from lightly.transforms import Jigsaw diff --git a/tests/transforms/test_Solarize.py b/tests/transforms/test_Solarize.py index 5bb44e468..4a1cda349 100644 --- a/tests/transforms/test_Solarize.py +++ b/tests/transforms/test_Solarize.py @@ -1,4 +1,5 @@ import unittest + from PIL import Image from lightly.transforms.solarize import RandomSolarization diff --git a/tests/transforms/test_location_to_NxN_grid.py b/tests/transforms/test_location_to_NxN_grid.py index f3b607517..2ec1beb5b 100644 --- a/tests/transforms/test_location_to_NxN_grid.py +++ b/tests/transforms/test_location_to_NxN_grid.py @@ -1,4 +1,5 @@ import torch + import lightly.transforms.random_crop_and_flip_with_grid as test_module diff --git a/tests/transforms/test_moco_transform.py b/tests/transforms/test_moco_transform.py index 8b554ccf0..eef1c651f 100644 --- a/tests/transforms/test_moco_transform.py +++ b/tests/transforms/test_moco_transform.py @@ -11,6 +11,7 @@ def test_moco_v1_multi_view_on_pil_image(): assert output[0].shape == (3, 32, 32) assert output[1].shape == (3, 32, 32) + def test_moco_v2_multi_view_on_pil_image(): multi_view_transform = MoCoV2Transform(input_size=32) sample = Image.new("RGB", (100, 100)) diff --git 
index a47f6831d..6728eb52d 100644
--- a/tests/transforms/test_multi_crop_transform.py
+++ b/tests/transforms/test_multi_crop_transform.py
@@ -1 +1 @@
-from lightly.transforms.multi_crop_transform import MultiCropTranform
\ No newline at end of file
+from lightly.transforms.multi_crop_transform import MultiCropTranform
diff --git a/tests/transforms/test_multi_view_transform.py b/tests/transforms/test_multi_view_transform.py
index 01eef32dd..dddd4db3d 100644
--- a/tests/transforms/test_multi_view_transform.py
+++ b/tests/transforms/test_multi_view_transform.py
@@ -1,7 +1,9 @@
-from PIL import Image
 import unittest
-from lightly.transforms.multi_view_transform import MultiViewTransform
+
 import torchvision.transforms as T
+from PIL import Image
+
+from lightly.transforms.multi_view_transform import MultiViewTransform
 
 
 def test_multi_view_on_pil_image():
diff --git a/tests/transforms/test_pirl_transform.py b/tests/transforms/test_pirl_transform.py
index 46838d9f3..5042a1e8f 100644
--- a/tests/transforms/test_pirl_transform.py
+++ b/tests/transforms/test_pirl_transform.py
@@ -1,4 +1,5 @@
 from PIL import Image
+
 from lightly.transforms.pirl_transform import PIRLTransform
diff --git a/tests/transforms/test_simclr_transform.py b/tests/transforms/test_simclr_transform.py
index 9115dc05f..70fff7ab4 100644
--- a/tests/transforms/test_simclr_transform.py
+++ b/tests/transforms/test_simclr_transform.py
@@ -1,8 +1,6 @@
 from PIL import Image
-from lightly.transforms.simclr_transform import (
-    SimCLRTransform,
-    SimCLRViewTransform,
-)
+
+from lightly.transforms.simclr_transform import SimCLRTransform, SimCLRViewTransform
 
 
 def test_view_on_pil_image():
diff --git a/tests/transforms/test_vicreg_transform.py b/tests/transforms/test_vicreg_transform.py
index 16dd2772a..5a2b0633d 100644
--- a/tests/transforms/test_vicreg_transform.py
+++ b/tests/transforms/test_vicreg_transform.py
@@ -1,8 +1,6 @@
 from PIL import Image
-from lightly.transforms.vicreg_transform import (
-    VICRegTransform,
-    VICRegViewTransform,
-)
+
+from lightly.transforms.vicreg_transform import VICRegTransform, VICRegViewTransform
 
 
 def test_view_on_pil_image():
diff --git a/tests/transforms/test_vicregl_transform.py b/tests/transforms/test_vicregl_transform.py
index f761f11b8..2e9d8abc3 100644
--- a/tests/transforms/test_vicregl_transform.py
+++ b/tests/transforms/test_vicregl_transform.py
@@ -1,8 +1,6 @@
 from PIL import Image
-from lightly.transforms.vicregl_transform import (
-    VICRegLTransform,
-    VICRegLViewTransform,
-)
+
+from lightly.transforms.vicregl_transform import VICRegLTransform, VICRegLViewTransform
 
 
 def test_view_on_pil_image():
diff --git a/tests/utils/test_debug.py b/tests/utils/test_debug.py
index 0d0701131..7d8c419f0 100644
--- a/tests/utils/test_debug.py
+++ b/tests/utils/test_debug.py
@@ -1,14 +1,16 @@
-import unittest
-import torch
 import math
+import unittest
+
 import numpy as np
+import torch
 from PIL import Image
 
-from lightly.utils import debug
 from lightly.data import collate
+from lightly.utils import debug
 
 try:
     import matplotlib.pyplot as plt
+
     MATPLOTLIB_AVAILABLE = True
 except ImportError:
     MATPLOTLIB_AVAILABLE = False
@@ -18,14 +20,13 @@
 
 class TestDebug(unittest.TestCase):
-
     def _generate_random_image(self, w: int, h: int, c: int):
         array = np.random.rand(h, w, c) * 255
-        image = Image.fromarray(array.astype('uint8')).convert('RGB')
+        image = Image.fromarray(array.astype("uint8")).convert("RGB")
         return image
 
     def test_std_of_l2_normalized_collapsed(self):
-        z = torch.ones(BATCH_SIZE, DIMENSION) # collapsed output
+        z = torch.ones(BATCH_SIZE, DIMENSION)  # collapsed output
         self.assertEqual(debug.std_of_l2_normalized(z), 0.0)
 
     def test_std_of_l2_normalized_uniform(self, eps: float = 1e-5):
@@ -45,30 +46,26 @@ def test_std_of_l2_normalized_raises(self):
 
     @unittest.skipUnless(MATPLOTLIB_AVAILABLE, "Matplotlib not installed")
     def test_plot_augmented_images_image_collate_function(self):
-
         # simclr collate function is a subclass of the image collate function
         collate_function = collate.SimCLRCollateFunction()
         for n_images in range(2, 10):
             with self.subTest():
                 images = [
-                    self._generate_random_image(100, 100, 3)
-                    for _ in range(n_images)
+                    self._generate_random_image(100, 100, 3) for _ in range(n_images)
                 ]
                 fig = debug.plot_augmented_images(images, collate_function)
                 self.assertIsNotNone(fig)
 
     @unittest.skipUnless(MATPLOTLIB_AVAILABLE, "Matplotlib not installed")
     def test_plot_augmented_images_multi_view_collate_function(self):
-
         # dion collate function is a subclass of the multi view collate function
         collate_function = collate.DINOCollateFunction()
         for n_images in range(1, 10):
             with self.subTest():
                 images = [
-                    self._generate_random_image(100, 100, 3)
-                    for _ in range(n_images)
+                    self._generate_random_image(100, 100, 3) for _ in range(n_images)
                 ]
                 fig = debug.plot_augmented_images(images, collate_function)
                 self.assertIsNotNone(fig)
@@ -84,5 +81,3 @@ def test_plot_augmented_images_invalid_collate_function(self):
         images = [self._generate_random_image(100, 100, 3)]
         with self.assertRaises(ValueError):
             debug.plot_augmented_images(images, None)
-
-
diff --git a/tests/utils/test_dist.py b/tests/utils/test_dist.py
index 5b035d476..79eb16013 100644
--- a/tests/utils/test_dist.py
+++ b/tests/utils/test_dist.py
@@ -6,9 +6,7 @@
 from lightly.utils import dist
 
 
-
 class TestDist(unittest.TestCase):
-
     def test_eye_rank_undist(self):
         self.assertTrue(torch.all(dist.eye_rank(3) == torch.eye(3)))
 
@@ -18,9 +16,13 @@ def test_eye_rank_dist(self):
         eye = torch.eye(n).bool()
         for world_size in [1, 3]:
             for rank in range(0, world_size):
-                with mock.patch('torch.distributed.is_initialized', lambda: True),\
-                    mock.patch('lightly.utils.dist.world_size', lambda: world_size),\
-                    mock.patch('lightly.utils.dist.rank', lambda: rank):
+                with mock.patch(
+                    "torch.distributed.is_initialized", lambda: True
+                ), mock.patch(
+                    "lightly.utils.dist.world_size", lambda: world_size
+                ), mock.patch(
+                    "lightly.utils.dist.rank", lambda: rank
+                ):
                     expected = []
                     for _ in range(0, rank):
                         expected.append(zeros)
@@ -28,4 +30,4 @@ def test_eye_rank_dist(self):
                     for _ in range(rank, world_size - 1):
                         expected.append(zeros)
                     expected = torch.cat(expected, dim=1)
-                    self.assertTrue(torch.all(dist.eye_rank(n) == expected))
+                    self.assertTrue(torch.all(dist.eye_rank(n) == expected))
diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py
index 71c06069f..6365f0021 100644
--- a/tests/utils/test_io.py
+++ b/tests/utils/test_io.py
@@ -1,36 +1,40 @@
 import csv
-import sys
 import json
+import sys
 import tempfile
 import unittest
 
 import numpy as np
 
 from lightly.utils.io import (
+    check_embeddings,
     check_filenames,
+    save_custom_metadata,
     save_embeddings,
-    check_embeddings,
-    save_tasks,
     save_schema,
-    save_custom_metadata
+    save_tasks,
+)
+from tests.api_workflow.mocked_api_workflow_client import (
+    MockedApiWorkflowClient,
+    MockedApiWorkflowSetup,
 )
-from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup, MockedApiWorkflowClient
 
 
 class TestCLICrop(MockedApiWorkflowSetup):
-
     @classmethod
     def setUpClass(cls) -> None:
-        sys.modules["lightly.cli.upload_cli"].ApiWorkflowClient = MockedApiWorkflowClient
+        sys.modules[
+            "lightly.cli.upload_cli"
+        ].ApiWorkflowClient = MockedApiWorkflowClient
 
     def test_save_metadata(self):
         metadata = [("filename.jpg", {"random_metadata": 42})]
-        metadata_filepath = tempfile.mktemp('.json', 'metadata')
+        metadata_filepath = tempfile.mktemp(".json", "metadata")
         save_custom_metadata(metadata_filepath, metadata)
 
     def test_valid_filenames(self):
-        valid = 'img.png'
-        non_valid = 'img,1.png'
+        valid = "img.png"
+        non_valid = "img,1.png"
         filenames_list = [
             ([valid], True),
             ([valid, valid], True),
@@ -45,14 +49,14 @@ def test_valid_filenames(self):
             with self.assertRaises(ValueError):
                 check_filenames(filenames)
 
-class TestEmbeddingsIO(unittest.TestCase):
 
+class TestEmbeddingsIO(unittest.TestCase):
     def setUp(self):
         # correct embedding file as created through lightly
-        self.embeddings_path = tempfile.mktemp('.csv', 'embeddings')
+        self.embeddings_path = tempfile.mktemp(".csv", "embeddings")
         embeddings = np.random.rand(32, 2)
         labels = [0 for i in range(len(embeddings))]
-        filenames = [f'img_{i}.jpg' for i in range(len(embeddings))]
+        filenames = [f"img_{i}.jpg" for i in range(len(embeddings))]
         save_embeddings(self.embeddings_path, embeddings, labels, filenames)
 
     def test_valid_embeddings(self):
@@ -60,123 +64,112 @@ def test_valid_embeddings(self):
 
     def test_whitespace_in_embeddings(self):
         # should fail because there whitespaces in the header columns
-        lines = ['filenames, embedding_0,embedding_1,labels\n',
-                 'img_1.jpg, 0.351,0.1231']
-        with open(self.embeddings_path, 'w') as f:
+        lines = [
+            "filenames, embedding_0,embedding_1,labels\n",
+            "img_1.jpg, 0.351,0.1231",
+        ]
+        with open(self.embeddings_path, "w") as f:
            f.writelines(lines)
         with self.assertRaises(RuntimeError) as context:
             check_embeddings(self.embeddings_path)
-        self.assertTrue('must not contain whitespaces' in str(context.exception))
+        self.assertTrue("must not contain whitespaces" in str(context.exception))
 
     def test_no_labels_in_embeddings(self):
         # should fail because there is no `labels` column in the header
-        lines = ['filenames,embedding_0,embedding_1\n',
-                 'img_1.jpg,0.351,0.1231']
-        with open(self.embeddings_path, 'w') as f:
+        lines = ["filenames,embedding_0,embedding_1\n", "img_1.jpg,0.351,0.1231"]
+        with open(self.embeddings_path, "w") as f:
             f.writelines(lines)
         with self.assertRaises(RuntimeError) as context:
             check_embeddings(self.embeddings_path)
-        self.assertTrue('has no `labels` column' in str(context.exception))
+        self.assertTrue("has no `labels` column" in str(context.exception))
 
     def test_no_empty_rows_in_embeddings(self):
         # should fail because there are empty rows in the embeddings file
-        lines = ['filenames,embedding_0,embedding_1,labels\n',
-                 'img_1.jpg,0.351,0.1231\n\n'
-                 'img_2.jpg,0.311,0.6231']
-        with open(self.embeddings_path, 'w') as f:
+        lines = [
+            "filenames,embedding_0,embedding_1,labels\n",
+            "img_1.jpg,0.351,0.1231\n\n" "img_2.jpg,0.311,0.6231",
+        ]
+        with open(self.embeddings_path, "w") as f:
             f.writelines(lines)
         with self.assertRaises(RuntimeError) as context:
             check_embeddings(self.embeddings_path)
-        self.assertTrue('must not have empty rows' in str(context.exception))
+        self.assertTrue("must not have empty rows" in str(context.exception))
 
     def test_embeddings_extra_rows(self):
         rows = [
-            ['filenames', 'embedding_0', 'embedding_1', 'labels', 'selected', 'masked'],
-            ['image_0.jpg', '3.4', '0.23', '0', '1', '0'],
-            ['image_1.jpg', '3.4', '0.23', '1', '0', '1']
+            ["filenames", "embedding_0", "embedding_1", "labels", "selected", "masked"],
+            ["image_0.jpg", "3.4", "0.23", "0", "1", "0"],
+            ["image_1.jpg", "3.4", "0.23", "1", "0", "1"],
         ]
-        with open(self.embeddings_path, 'w') as f:
+        with open(self.embeddings_path, "w") as f:
             csv_writer = csv.writer(f)
             csv_writer.writerows(rows)
 
         check_embeddings(self.embeddings_path, remove_additional_columns=True)
 
         with open(self.embeddings_path) as csv_file:
-            csv_reader = csv.reader(csv_file, delimiter=',')
+            csv_reader = csv.reader(csv_file, delimiter=",")
             for row_read, row_original in zip(csv_reader, rows):
                 self.assertListEqual(row_read, row_original[:-2])
 
     def test_embeddings_extra_rows_special_order(self):
         input_rows = [
-            ['filenames', 'embedding_0', 'embedding_1', 'masked', 'labels', 'selected'],
-            ['image_0.jpg', '3.4', '0.23', '0', '1', '0'],
-            ['image_1.jpg', '3.4', '0.23', '1', '0', '1']
+            ["filenames", "embedding_0", "embedding_1", "masked", "labels", "selected"],
+            ["image_0.jpg", "3.4", "0.23", "0", "1", "0"],
+            ["image_1.jpg", "3.4", "0.23", "1", "0", "1"],
         ]
         correct_output_rows = [
-            ['filenames', 'embedding_0', 'embedding_1', 'labels'],
-            ['image_0.jpg', '3.4', '0.23', '1'],
-            ['image_1.jpg', '3.4', '0.23', '0']
+            ["filenames", "embedding_0", "embedding_1", "labels"],
+            ["image_0.jpg", "3.4", "0.23", "1"],
+            ["image_1.jpg", "3.4", "0.23", "0"],
         ]
-        with open(self.embeddings_path, 'w') as f:
+        with open(self.embeddings_path, "w") as f:
             csv_writer = csv.writer(f)
             csv_writer.writerows(input_rows)
 
         check_embeddings(self.embeddings_path, remove_additional_columns=True)
 
         with open(self.embeddings_path) as csv_file:
-            csv_reader = csv.reader(csv_file, delimiter=',')
+            csv_reader = csv.reader(csv_file, delimiter=",")
             for row_read, row_original in zip(csv_reader, correct_output_rows):
                 self.assertListEqual(row_read, row_original)
 
     def test_save_tasks(self):
         tasks = [
-            'task1',
-            'task2',
-            'task3',
+            "task1",
+            "task2",
+            "task3",
         ]
-        with tempfile.NamedTemporaryFile(suffix='.json') as file:
+        with tempfile.NamedTemporaryFile(suffix=".json") as file:
             save_tasks(file.name, tasks)
-            with open(file.name, 'r') as f:
+            with open(file.name, "r") as f:
                 loaded = json.load(f)
         self.assertListEqual(tasks, loaded)
 
     def test_save_schema(self):
-        description = 'classification'
+        description = "classification"
         ids = [1, 2, 3, 4]
-        names = ['name1', 'name2', 'name3', 'name4']
+        names = ["name1", "name2", "name3", "name4"]
 
         expected_format = {
-            'task_type': 'classification',
-            'categories': [
-                {
-                    'id': 1,
-                    'name': 'name1'
-                },
-                {
-                    'id': 2,
-                    'name': 'name2'
-                },
-                {
-                    'id': 3,
-                    'name': 'name3'
-                },
-                {
-                    'id': 4,
-                    'name': 'name4'
-                },
-            ]
+            "task_type": "classification",
+            "categories": [
+                {"id": 1, "name": "name1"},
+                {"id": 2, "name": "name2"},
+                {"id": 3, "name": "name3"},
+                {"id": 4, "name": "name4"},
+            ],
         }
-        with tempfile.NamedTemporaryFile(suffix='.json') as file:
+        with tempfile.NamedTemporaryFile(suffix=".json") as file:
             save_schema(file.name, description, ids, names)
-            with open(file.name, 'r') as f:
+            with open(file.name, "r") as f:
                 loaded = json.load(f)
         self.assertListEqual(sorted(expected_format), sorted(loaded))
 
     def test_save_schema_different(self):
         with self.assertRaises(ValueError):
             save_schema(
-                'name_doesnt_matter',
-                'description_doesnt_matter',
+                "name_doesnt_matter",
+                "description_doesnt_matter",
                 [1, 2],
-                ['name1'],
+                ["name1"],
             )
diff --git a/tests/utils/test_scheduler.py b/tests/utils/test_scheduler.py
index 4686e8a24..4b585e079 100644
--- a/tests/utils/test_scheduler.py
+++ b/tests/utils/test_scheduler.py
@@ -1,4 +1,5 @@
 import unittest
+
 import torch
 from torch import nn
 
@@ -20,13 +21,14 @@ def test_cosine_schedule(self):
         with self.assertRaises(ValueError):
             cosine_schedule(11, 10, 0.0, 1.0)
 
-
     def test_CosineWarmupScheduler(self):
         model = nn.Linear(10, 1)
         optimizer = torch.optim.SGD(
             model.parameters(), lr=1.0, momentum=0.0, weight_decay=0.0
         )
-        scheduler = CosineWarmupScheduler(optimizer, warmup_epochs=3, max_epochs=6, verbose=True)
+        scheduler = CosineWarmupScheduler(
+            optimizer, warmup_epochs=3, max_epochs=6, verbose=True
+        )
 
         # warmup
         self.assertAlmostEqual(scheduler.get_last_lr()[0], 0.333333333)
diff --git a/tests/utils/test_version_compare.py b/tests/utils/test_version_compare.py
index 891e6696e..ce39dbb6a 100644
--- a/tests/utils/test_version_compare.py
+++ b/tests/utils/test_version_compare.py
@@ -4,31 +4,30 @@
 
 
 class TestVersionCompare(unittest.TestCase):
-
     def test_valid_versions(self):
-
         # general test of smaller than version numbers
-        self.assertEqual(version_compare.version_compare('0.1.4', '1.2.0'), -1)
-        self.assertEqual(version_compare.version_compare('1.1.0', '1.2.0'), -1)
+        self.assertEqual(version_compare.version_compare("0.1.4", "1.2.0"), -1)
+        self.assertEqual(version_compare.version_compare("1.1.0", "1.2.0"), -1)
 
         # test bigger than
-        self.assertEqual(version_compare.version_compare('1.2.0', '1.1.0'), 1)
-        self.assertEqual(version_compare.version_compare('1.2.0', '0.1.4'), 1)
+        self.assertEqual(version_compare.version_compare("1.2.0", "1.1.0"), 1)
+        self.assertEqual(version_compare.version_compare("1.2.0", "0.1.4"), 1)
 
         # test equal
-        self.assertEqual(version_compare.version_compare('1.2.0', '1.2.0'), 0)
-
+        self.assertEqual(version_compare.version_compare("1.2.0", "1.2.0"), 0)
 
     def test_invalid_versions(self):
         with self.assertRaises(ValueError):
-            version_compare.version_compare('1.2', '1.1.0')
+            version_compare.version_compare("1.2", "1.1.0")
 
         with self.assertRaises(ValueError):
-            version_compare.version_compare('1.2.0.1', '1.1.0')
+            version_compare.version_compare("1.2.0.1", "1.1.0")
 
         # test within same minor version and with special cases
         with self.assertRaises(ValueError):
-            self.assertEqual(version_compare.version_compare('1.0.7', '1.1.0.dev1'), -1)
+            self.assertEqual(version_compare.version_compare("1.0.7", "1.1.0.dev1"), -1)
 
         with self.assertRaises(ValueError):
-            self.assertEqual(version_compare.version_compare('1.1.0.dev1', '1.1.0rc1'), -1)
\ No newline at end of file
+            self.assertEqual(
+                version_compare.version_compare("1.1.0.dev1", "1.1.0rc1"), -1
+            )