Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alterando para o treinamento com máscaras de oclusão #326

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
19 changes: 13 additions & 6 deletions bin/gen_mask_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ def __init__(self, impl, variants_n=2):
self.impl = impl
self.variants_n = variants_n

def get_masks(self, img):
def get_masks(self, img, indir= None):
img = np.transpose(np.array(img), (2, 0, 1))
return [self.impl(img)[0] for _ in range(self.variants_n)]
return [self.impl(img,indir=indir)[0] for _ in range(self.variants_n)]


def process_images(src_images, indir, outdir, config):
if config.generator_kind == 'segmentation':
mask_generator = SegmentationMask(**config.mask_generator_kwargs)
elif config.generator_kind == 'random':
variants_n = config.mask_generator_kwargs.pop('variants_n', 2)
mask_generator = MakeManyMasksWrapper(MixedMaskGenerator(**config.mask_generator_kwargs),
variants_n=variants_n)
mixed_mask_generator = MixedMaskGenerator(**config.mask_generator_kwargs)
mask_generator = MakeManyMasksWrapper(mixed_mask_generator, variants_n=variants_n)
else:
raise ValueError(f'Unexpected generator kind: {config.generator_kind}')

Expand Down Expand Up @@ -59,7 +59,7 @@ def process_images(src_images, indir, outdir, config):
image = image.resize(out_size, resample=Image.BICUBIC)

# generate and select masks
src_masks = mask_generator.get_masks(image)
src_masks = mask_generator.get_masks(image,indir)

filtered_image_mask_pairs = []
for cur_mask in src_masks:
Expand Down Expand Up @@ -104,7 +104,13 @@ def main(args):
os.makedirs(args.outdir, exist_ok=True)

config = load_yaml(args.config)

if args.occ_indir:
if "occ_mask_indir" in config.mask_generator_kwargs:
config.mask_generator_kwargs["occ_mask_indir"]= args.occ_indir
else:
print("ERROR | Trying to generate using occlusion masks but the config file does not contain the path to them")

print("DEBUG",config)
in_files = list(glob.glob(os.path.join(args.indir, '**', f'*.{args.ext}'), recursive=True))
if args.n_jobs == 0:
process_images(in_files, args.indir, args.outdir, config)
Expand All @@ -124,6 +130,7 @@ def main(args):
aparser.add_argument('config', type=str, help='Path to config for dataset generation')
aparser.add_argument('indir', type=str, help='Path to folder with images')
aparser.add_argument('outdir', type=str, help='Path to folder to store aligned images and masks to')
aparser.add_argument('--occ_indir', type=str,default=None, help ='Path to the occlusion folder')
aparser.add_argument('--n-jobs', type=int, default=0, help='How many processes to use')
aparser.add_argument('--ext', type=str, default='jpg', help='Input image extension')

Expand Down
28 changes: 28 additions & 0 deletions configs/data_gen/random_thin_occ_512.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
generator_kind: random

mask_generator_kwargs:
irregular_proba: 1
irregular_kwargs:
min_times: 4
max_times: 50
max_width: 10
max_angle: 4
max_len: 40
box_proba: 0
segm_proba: 0
squares_proba: 0

occ_mask: True
occ_mask_indir: ${training.location.occ_mask_root_dir} #overwrite when running gen_mask_dataset.py

variants_n: 5

max_masks_per_image: 1

cropping:
out_min_size: 256
handle_small_mode: upscale
out_square_crop: True
crop_min_overlap: 1

max_tamper_area: 0.5
27 changes: 15 additions & 12 deletions configs/training/data/abl-04-256-mh-dist.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# @package _group_

batch_size: 10
batch_size: 8
val_batch_size: 2
num_workers: 3
num_workers: 3

train:
indir: ${location.data_root_dir}/train
Expand All @@ -11,18 +11,21 @@ train:
irregular_proba: 1
irregular_kwargs:
max_angle: 4
max_len: 200
max_width: 100
max_times: 5
max_len: 100
max_width: 20
max_times: 3
min_times: 1

box_proba: 1
box_kwargs:
margin: 10
bbox_min_size: 30
bbox_max_size: 150
max_times: 4
min_times: 1
# box_proba: 1
# box_kwargs:
# margin: 10
# bbox_min_size: 30
# bbox_max_size: 150
# max_times: 4
# min_times: 1

occ_mask: True
occ_mask_indir: ${location.occ_mask_root_dir}/train

segm_proba: 0

Expand Down
1 change: 1 addition & 0 deletions configs/training/generator/ffc_resnet_075.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ ngf: 64
n_downsampling: 3
n_blocks: 9
add_out_act: sigmoid
conv_kind: depthwise

init_conv_kwargs:
ratio_gin: 0
Expand Down
1 change: 1 addition & 0 deletions configs/training/lama-fourier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ training_model:
visualize_each_iters: 1000
concat_mask: true
store_discr_outputs_for_vis: true

losses:
l1:
weight_missing: 0
Expand Down
6 changes: 6 additions & 0 deletions configs/training/location/places_standard.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# @package _group_
data_root_dir: /home/isaacfs/datasets/places_standard_dataset/
occ_mask_root_dir: /home/isaacfs/occlusions_mask/
out_root_dir: /home/isaacfs/lama/experiments/
tb_dir: /home/isaacfs/lama/tb_logs/
pretrained_models: /home/isaacfs/lama/
10 changes: 7 additions & 3 deletions configs/training/trainer/any_gpu_large_ssim_ddp_final.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
kwargs:
gpus: -1
accelerator: ddp
max_epochs: 40
max_epochs: 20
gradient_clip_val: 1
log_gpu_memory: None # set to min_max or all for debug
limit_train_batches: 25000
limit_train_batches: 10
val_check_interval: ${trainer.kwargs.limit_train_batches}
# fast_dev_run: True # uncomment for faster debug
# track_grad_norm: 2 # uncomment to track L2 gradients norm
Expand All @@ -22,10 +22,14 @@ kwargs:
# limit_val_batches: 1000000
replace_sampler_ddp: False

logs:
log_on_epoch: true
log_on_step: false

checkpoint_kwargs:
verbose: True
save_top_k: 5
save_last: True
period: 1
monitor: val_ssim_fid100_f1_total_mean
mode: max
mode: max
14 changes: 8 additions & 6 deletions fetch_data/places_standard_test_val_gen_masks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ mkdir -p places_standard_dataset/visual_test/


python3 bin/gen_mask_dataset.py \
$(pwd)/configs/data_gen/random_thick_512.yaml \
places_standard_dataset/val_hires/ \
places_standard_dataset/val/
$(pwd)/configs/data_gen/random_thin_occ_512.yaml \
/home/isaacfs/places_standard_dataset/val_hires/ \
/home/isaacfs/places_standard_dataset/val/ \
--occ_indir /home/isaacfs/occlusions_mask/original/test/test_large/

python3 bin/gen_mask_dataset.py \
$(pwd)/configs/data_gen/random_thick_512.yaml \
places_standard_dataset/visual_test_hires/ \
places_standard_dataset/visual_test/
$(pwd)/configs/data_gen/random_thin_occ_512.yaml \
/home/isaacfsplaces_standard_dataset/visual_test_hires/ \
/home/isaacfs/places_standard_dataset/visual_test/ \
--occ_indir /home/isaacfs/occlusions_mask/original/val/val_large/
4 changes: 2 additions & 2 deletions fetch_data/places_standard_train_prepare.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mkdir -p places_standard_dataset/train
#mkdir -p places_standard_dataset/train

# untar without folder structure
tar -xvf train_large_places365standard.tar -C places_standard_dataset/train
#tar -xvf train_large_places365standard.tar -C places_standard_dataset/train

# create location config places.yaml
PWD=$(pwd)
Expand Down
4 changes: 2 additions & 2 deletions fetch_data/sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import random

test_files_path = os.path.abspath('.') + '/places_standard_dataset/original/test/'
test_files_path = os.path.abspath('.') + '/places_standard_dataset/original/test/test_large/'
list_of_random_test_files = os.path.abspath('.') + '/places_standard_dataset/original/test_random_files.txt'

test_files = [
Expand All @@ -22,7 +22,7 @@

# --------------------------------

val_files_path = os.path.abspath('.') + '/places_standard_dataset/original/val/'
val_files_path = os.path.abspath('.') + '/places_standard_dataset/original/val/val_large/'
list_of_random_val_files = os.path.abspath('.') + '/places_standard_dataset/original/val_random_files.txt'

val_files = [
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ pyyaml
tqdm
numpy
easydict==1.9.0
scikit-image==0.17.2
scikit-learn==0.24.2
scikit-image
scikit-learn
opencv-python
tensorflow
joblib
Expand All @@ -16,5 +16,5 @@ tabulate
kornia==0.5.0
webdataset
packaging
scikit-learn==0.24.2
wldhx.yadisk-direct

3 changes: 2 additions & 1 deletion saicinpainting/training/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, indir, mask_generator, transform):
self.mask_generator = mask_generator
self.transform = transform
self.iter_i = 0
self.indir = indir

def __len__(self):
return len(self.in_files)
Expand All @@ -39,7 +40,7 @@ def __getitem__(self, item):
img = self.transform(image=img)['image']
img = np.transpose(img, (2, 0, 1))
# TODO: maybe generate mask before augmentations? slower, but better for segmentation-based masks
mask = self.mask_generator(img, iter_i=self.iter_i)
mask = self.mask_generator(img, path=path, iter_i=self.iter_i, indir=self.indir)
self.iter_i += 1
return dict(image=img,
mask=mask)
Expand Down
32 changes: 30 additions & 2 deletions saicinpainting/training/data/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import hashlib
import logging
from enum import Enum

import cv2
import numpy as np
import os

from saicinpainting.evaluation.masks.mask import SegmentationMask
from saicinpainting.utils import LinearRamp
Expand Down Expand Up @@ -256,9 +256,12 @@ def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
squares_proba=0, squares_kwargs=None,
superres_proba=0, superres_kwargs=None,
outpainting_proba=0, outpainting_kwargs=None,
occ_mask=False, occ_mask_indir=None,
invert_proba=0):
self.probas = []
self.gens = []
self.occ_mask = occ_mask
self.occ_mask_indir = occ_mask_indir

if irregular_proba > 0:
self.probas.append(irregular_proba)
Expand Down Expand Up @@ -306,12 +309,37 @@ def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
self.probas /= self.probas.sum()
self.invert_proba = invert_proba

def __call__(self, img, iter_i=None, raw_image=None):
def __call__(self, img, path=None, iter_i=None, raw_image=None, indir=None):
kind = np.random.choice(len(self.probas), p=self.probas)
gen = self.gens[kind]
result = gen(img, iter_i=iter_i, raw_image=raw_image)
if self.invert_proba > 0 and random.random() < self.invert_proba:
result = 1 - result

# Training for parallax tasks
if self.occ_mask:

if path is None:
raise Exception("Trying to use occlusion mask but no path is provided!\nTroubleshoot-Idea: check the dataset call for the mask generation function")

# Deriving the occlusion mask path from the image path
filename = os.path.basename(path)
occ_mask_path = path.replace(indir,self.occ_mask_indir)
occ_mask_path = occ_mask_path.replace(filename, "")
occ_mask_path = os.path.join(occ_mask_path, f"occlusion_{filename}")

occ_mask = cv2.imread(occ_mask_path)
occ_mask = cv2.cvtColor(occ_mask, cv2.COLOR_BGR2GRAY)

occ_mask = np.expand_dims(occ_mask, axis=0)
# Convert mask2 to 0 and 1
occ_mask = occ_mask / np.max(occ_mask)
occ_mask = (occ_mask > 0).astype('float32')

# Blend the masks
result = np.logical_or(result, occ_mask).astype(result.dtype)


return result


Expand Down
Loading