Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
hamedrq7 authored Sep 30, 2023
1 parent f05b266 commit e022821
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 0 deletions.
Binary file added exploring_data/all_domains.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 25 additions & 0 deletions exploring_data/dataloader_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from out import get_data_loaders
import torch

def loaders_demo():
full_dataloaders, _ = get_data_loaders(
{'train': './data/1000_train_mnistmnistmsvhnsynusps.npz',
'test': './data/1000_test_mnistmnistmsvhnsynusps.npz',
},
batch_size= 64)
print(full_dataloaders.keys())

for phase in ['train', 'test', 'test_missing']:
print(f'{phase} data...')
for batch_indx, (images, features, domain_labels, digit_labels) in enumerate(full_dataloaders[phase]):
print(f'{batch_indx}-th batch')
print('images shape: ', images.shape)
print('features shape: ', features.shape)
if phase == 'test_missing':
print('in test-missing dataloaders, since the features are not available, features are filled with zeros', torch.sum(features))
print('domain labels freq: ', torch.unique(domain_labels, return_counts=True))
print('digit labels freq: ', torch.unique(digit_labels, return_counts=True))
print()
break

loaders_demo()
94 changes: 94 additions & 0 deletions exploring_data/exploring_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
from out import get_data_loaders

NUM_SAMPLES_FROM_DIGITS = 5
NUM_DOMAINS = 5
NUM_DIGITS = 10
SEED = 141+1

DOMAIN_NAME_DICT = {
0: 'mnist',
1: 'mnistm',
2: 'svhn',
3: 'syn',
4: 'usps'
}

# get data loaders
# 1000_test_mnistmnistmsvhnsynusps.npz
# 1000_train_mnistmnistmsvhnsynusps.npz
full_dataloaders, _ = get_data_loaders(
{'train': './data/1000_train_mnistmnistmsvhnsynusps.npz',
'test': './data/1000_test_mnistmnistmsvhnsynusps.npz',
},
batch_size= 64, init_seed=SEED)

# utils
def get_samples(dataloader, num_samples: int) -> np.ndarray:
samples = [
[ [], [], [], [], [], [], [], [], [], [] ],
[ [], [], [], [], [], [], [], [], [], [] ],
[ [], [], [], [], [], [], [], [], [], [] ],
[ [], [], [], [], [], [], [], [], [], [] ],
[ [], [], [], [], [], [], [], [], [], [] ],
]

for batch_indx, (images, _, domain_labels, digit_labels) in enumerate(dataloader):
# images are normalized using mean=(0.5, 0.5, 0.5) and std=(0.5, 0.5, 0.5),
# so images habve been normalized using: image = image - mean / std
# to plot images we have to undo the normalization
images = images * 0.5 + 0.5

for img_indx, curr_image in enumerate(images):

if len(samples[domain_labels[img_indx]][digit_labels[img_indx]]) < num_samples:
samples[domain_labels[img_indx]][digit_labels[img_indx]].append(curr_image.numpy())

# convert samples to numpy array
return np.array(samples)



# plot NUM_SAMPLES_FROM_DIGITS of each domain together:
samples = get_samples(full_dataloaders['train'], NUM_SAMPLES_FROM_DIGITS)

fig_height = 10
fig_width = fig_height * (NUM_DOMAINS*NUM_SAMPLES_FROM_DIGITS) / NUM_DIGITS
fig, axs = plt.subplots(NUM_DIGITS, NUM_DOMAINS*NUM_SAMPLES_FROM_DIGITS, figsize=(fig_width, fig_height),
# gridspec_kw=dict(hspace=0.0)
gridspec_kw={'height_ratios':[1]*10}
)

for i in range(NUM_DIGITS):
for dom in range(NUM_DOMAINS):
for j in range(NUM_SAMPLES_FROM_DIGITS):
image = np.transpose(samples[dom][i][j], (1, 2, 0)) # Transpose the image to (32, 32, 3)
axs[i, j+dom*NUM_SAMPLES_FROM_DIGITS].imshow(image)
axs[i, j+dom*NUM_SAMPLES_FROM_DIGITS].axis('off')

plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
plt.tight_layout()
plt.savefig('exploring_data/all_domains.jpg')
plt.clf()


# plot 10 samples of each domain separately
samples = get_samples(full_dataloaders['train'], 10)

for domain in DOMAIN_NAME_DICT.keys():
print(f'plotting samples from domain {DOMAIN_NAME_DICT[domain]}')

fig, axs = plt.subplots(10, 10, figsize=(10, 10), gridspec_kw=dict(hspace=0.0))

for i in range(10):
for j in range(10):
image = np.transpose(samples[domain][i][j], (1, 2, 0)) # Transpose the image to (32, 32, 3)
axs[j, i].imshow(image)
axs[j, i].axis('off')

plt.tight_layout()
plt.title(f'{DOMAIN_NAME_DICT[domain]}')
plt.savefig(f'exploring_data/{DOMAIN_NAME_DICT[domain]}.jpg')
plt.clf()
Binary file added exploring_data/mnist.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added exploring_data/mnistm.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added exploring_data/svhn.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added exploring_data/syn.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added exploring_data/usps.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit e022821

Please sign in to comment.