[WIP] Support for stacking dataset #3

Status: Open
Wants to merge 235 commits into base: new_datasets

Commits (235)
0567b92
[WIP] Changes to add support for stacking dataset in micro search
j-varun Jul 2, 2018
cc446f7
[WIP] Bug Fixes
j-varun Jul 6, 2018
32e0820
Minor Fixes
j-varun Jul 6, 2018
e74d467
Merge pull request #3 from j-varun/good_fixes
ahundt Jul 6, 2018
1ae4b6c
Added missing files
j-varun Jul 6, 2018
6281d10
Merge pull request #3 from j-varun/good_fixes
ahundt Jul 6, 2018
b648cfe
block_stacking_reader.py better reading test loop
ahundt Jul 6, 2018
a4acbbc
add costar stacking search script
ahundt Jul 8, 2018
d4757ac
[WIP] Fixes for generator errors and other modifications to support s…
j-varun Jul 9, 2018
85a95c0
Merge commit 'd4757ac5df3e1d4a7e100dfcbabc9101d3d6372b' into stacking…
ahundt Jul 10, 2018
85fe169
Merge branch 'stacking_search' of github.com:ahundt/enas into stackin…
ahundt Jul 10, 2018
799282a
val_test_size to 200 for stacking
ahundt Jul 10, 2018
5a4975c
block stacking search batch size and num epochs updated
ahundt Jul 10, 2018
4982c72
Merge commit 'b704f400aa11c580b5a6d432a55840870c898d0e' into stacking…
ahundt Jul 10, 2018
3e2b432
[WIP] Changed Loss and Accuracy fn for regression task
j-varun Jul 11, 2018
72bf0d5
[WIP] Added more metrics and arguments + minor bug fixes
j-varun Jul 12, 2018
af4d86d
Minor Fixes
j-varun Jul 12, 2018
b774b87
[WIP] Minor changes to micro_child
j-varun Jul 12, 2018
faf5646
Merge commit 'b774b87beae87c0af469d234e4ac4041f403185f' into stacking…
ahundt Jul 13, 2018
fb9af5f
models.py pep8
ahundt Jul 13, 2018
2c53c5e
[WIP] Fixed hardcoded shapes
j-varun Jul 13, 2018
45dc0ca
[WIP] Minor Bug Fixes
j-varun Jul 13, 2018
0387a3a
Merged with stacking_search
j-varun Jul 13, 2018
2f11da0
correct validation size
ahundt Jul 13, 2018
a5f6352
bug fix
j-varun Jul 13, 2018
cf44641
Merge commit 'a5f63525eb0bba501dc4ab4d194f0239984857aa' into stacking…
ahundt Jul 13, 2018
e2d29c9
models.py reduce images per example to train on for temp bug workaround.
ahundt Jul 13, 2018
5658c27
models.py hack to reduce time in a batch
ahundt Jul 14, 2018
51609a7
cifar10 main.py add option for default child and controller optimizers
ahundt Jul 20, 2018
2b5611a
Miscellaneous Bug Fixes
j-varun Jul 21, 2018
9e3dc29
[WIP] Changes to support micro search based on loss and other fixes
j-varun Jul 21, 2018
114619a
[WIP] Changes for using validation loss as reward for controller
j-varun Jul 23, 2018
c96ce47
Merge commit '9e3dc292dd485d65bf053e587d4b486190c553ed' into stacking…
ahundt Jul 24, 2018
3b26af0
models.py indentation fix
ahundt Jul 24, 2018
bab03cf
models.py fix indentation
ahundt Jul 24, 2018
72c9b9a
Merge branch 'stacking_search' of https://github.com/ahundt/enas into…
j-varun Jul 24, 2018
ca3b58d
[WIP] Minor Fix for initializing child class
j-varun Jul 24, 2018
e1ac057
[WIP] Removed unnecessary print statements
j-varun Jul 24, 2018
7471d32
increase training # of workers and max queue size
ahundt Jul 25, 2018
291e982
costar_block_stacking_search.sh attempt at tuning parameters
ahundt Jul 25, 2018
60ab83a
costar_block_stacking_search.sh attempt at tuning parameters
ahundt Jul 25, 2018
ead2194
global_avg_pool -> global_max_pool
ahundt Jul 25, 2018
edb73f1
costar_block_stacking_search.sh reduce batch size due to memory const…
ahundt Jul 25, 2018
d5c87c5
lower child out filters to fit in memory
ahundt Jul 25, 2018
5857b05
[WIP] Fixes for fashion_mnist and stacking dataset
j-varun Jul 26, 2018
821874d
Minor bug fix
j-varun Jul 26, 2018
200a61d
Merge branch 'stacking_search' of https://github.com/ahundt/enas into…
j-varun Jul 26, 2018
cad34c7
models.py artificially lower number of steps in an epoch
ahundt Jul 28, 2018
c111f9e
micro_child.py lambda to def
ahundt Jul 28, 2018
d06b58f
models.py make stacking epoch longer by batch_size
ahundt Jul 28, 2018
585e6e9
Added additional evaluation metrics
j-varun Jul 31, 2018
463d143
[WIP] Minor bug fix
j-varun Jul 31, 2018
11f5108
[WIP] Minor fixes
j-varun Jul 31, 2018
f74d71d
Merge branch 'stacking_search' of https://github.com/ahundt/enas into…
j-varun Jul 31, 2018
5245c81
option for translation only training added
j-varun Aug 5, 2018
f465bce
[WIP] Minor Fixes
j-varun Aug 6, 2018
8a5e22c
Fixes for translation only run
j-varun Aug 7, 2018
f969e3f
Added Metrics for eval
j-varun Aug 7, 2018
0ca1112
Bug Fix
j-varun Aug 7, 2018
de02eeb
Fix for fashion-mnist
j-varun Aug 8, 2018
e955322
Added option to switch to alternate reward for stacking dataset
j-varun Aug 8, 2018
6ef5ae7
Added flag for reward
j-varun Aug 8, 2018
6721f97
Changes to support training of only the rotation component of the sta…
j-varun Aug 8, 2018
0480bb9
Minor Fixes
j-varun Aug 8, 2018
b0d56c2
Merge pull request #5 from j-varun/good_fixes
ahundt Aug 8, 2018
f0f440a
Minor Bug Fix
j-varun Aug 8, 2018
ffe9435
configure search parameters for rotation and translation only search
ahundt Aug 8, 2018
c2c1bc0
Merge commit 'f0f440a58a225f17c6b2457f1d0b0d763c47e802' into stacking…
ahundt Aug 8, 2018
62b9ede
Added Configurable max loss for reward
j-varun Aug 8, 2018
0aebeb4
Merge branch 'stacking_search' of https://github.com/ahundt/enas into…
j-varun Aug 8, 2018
b6bf3b5
Miscellaneous Fixes
j-varun Aug 8, 2018
79ac6bb
Merge pull request #6 from j-varun/good_fixes
ahundt Aug 8, 2018
050b7c7
costar_block_stacking_translation_search.sh last training run only fo…
ahundt Aug 9, 2018
aacd81c
models.py increase estimated images per example from 1 to 15
ahundt Aug 10, 2018
cf1c7fa
models.py estimated images per example = 16
ahundt Aug 10, 2018
79c27cc
costar_block_stacking_rotation_search.sh configure to match translati…
ahundt Aug 10, 2018
dc5e0f7
rotation search try increasing batch size to 32
ahundt Aug 10, 2018
da64b09
configure new search directories
ahundt Aug 10, 2018
cfce2b2
rotation search batch size 16
ahundt Aug 10, 2018
0cb5e68
costar stacking lower estimated images per example to 4
ahundt Aug 10, 2018
7be1e2b
models.py lower estimated_images_per_example to 2
ahundt Aug 10, 2018
6e232e5
Fixes for accuracy values being printed
j-varun Aug 11, 2018
a4a1087
batch_norm -> group_norm with exception of masked batch norm...
ahundt Aug 12, 2018
d0d309d
set max_loss to 2 to avoid negative losses
ahundt Aug 12, 2018
8e7f8e2
Merge commit '6e232e573515826ce172a3d0bb45154193c409ff' into stacking…
ahundt Aug 12, 2018
54a1cda
image_ops.py attempt group_norm fixes
ahundt Aug 12, 2018
861694f
group_norm try to fix shape problem
ahundt Aug 12, 2018
28f7764
image_ops.py group_norm attempt to fix shape
ahundt Aug 12, 2018
b2dc520
image_ops.py group_norm attempt at shape fix
ahundt Aug 12, 2018
9b0912e
image_ops.py more attempts at group norm fix
ahundt Aug 12, 2018
21673e7
image_ops.py group_norm fix attempt
ahundt Aug 12, 2018
69ed6a6
image_ops.py group_norm another shape fix attempt
ahundt Aug 12, 2018
8cc60ba
image_ops.py group_norm another shape fix attempt
ahundt Aug 12, 2018
1a803c8
image_ops.py group_norm another shape fix attempt
ahundt Aug 12, 2018
3f65d52
[1, C, 1, 1]
ahundt Aug 12, 2018
c4d2490
image_ops.py debug shape
ahundt Aug 12, 2018
f80a78a
image_ops.py group_norm another shape fix attempt
ahundt Aug 12, 2018
d147251
image_ops.py group_norm another shape fix attempt
ahundt Aug 12, 2018
9c6062c
image_ops.py group_norm another shape fix attempt
ahundt Aug 12, 2018
a3a69a6
micro_child.py fixes to is_training parameters
ahundt Aug 12, 2018
d367e52
image_ops.py get some additional debug info
ahundt Aug 12, 2018
2a7c846
image_ops.py group_norm another shape fix attempt
ahundt Aug 12, 2018
14c10c9
image_ops.py norm prints traceback if verbose
ahundt Aug 12, 2018
c049898
typo fix
ahundt Aug 12, 2018
cb1f5e8
image_ops.py group_norm another shape fix attempt
ahundt Aug 12, 2018
bdbef1e
micro_child.py add shape asserts
ahundt Aug 12, 2018
7467a7c
micro_child.py more explicit params
ahundt Aug 12, 2018
9dbeffc
image_ops.py print clearer separator
ahundt Aug 12, 2018
3d2a1d6
micro_child.py clearly set is training as named parameter
ahundt Aug 12, 2018
c986190
micro_child.py better setting of is_training
ahundt Aug 12, 2018
01b5f0d
micro_child.py typo fix
ahundt Aug 12, 2018
400a174
micro_child.py fix some hard coded variables
ahundt Aug 12, 2018
9622384
micro_child.py typo fix
ahundt Aug 12, 2018
edfbd00
micro_child.py typo fix
ahundt Aug 12, 2018
2d06cc8
micro_child.py add missing import
ahundt Aug 12, 2018
b4b183f
micro_child.py add more shape debugging
ahundt Aug 12, 2018
83c2056
micro_child.py more shape debugging
ahundt Aug 12, 2018
e14fb57
micro_child.py more shape debug output
ahundt Aug 12, 2018
1dc79a0
micro_child.py more shape debug output
ahundt Aug 12, 2018
a409e46
micro_child.py more shape debug outputs
ahundt Aug 12, 2018
b21f70e
micro_child.py try using integer shape components where possible
ahundt Aug 12, 2018
b73e176
micro_child.py typo fix
ahundt Aug 12, 2018
09006ad
image_ops.py micro_child.py set verbose shape debug output to 0
ahundt Aug 12, 2018
995593c
models.py CRITICAL BUGFIX: TRAIN GENERATOR WAS BEING USED FOR VAL + T…
ahundt Aug 13, 2018
3c970f6
micro_child.py fix is_training parameters
ahundt Aug 13, 2018
ce4fe15
cifar10 parameterize pool distance
ahundt Aug 13, 2018
07b82df
double num_layers so receptive field size is increased substantially
ahundt Aug 13, 2018
a6a748b
main.py set pool distance based on flag
ahundt Aug 13, 2018
60fddb5
block_stacking_reader.py merge code from costar_plan
ahundt Aug 13, 2018
bd46afc
block_stacking_reader.py incorporate upstream changes and add new inp…
ahundt Aug 13, 2018
2aaad4c
micro search add missing flags and an error message.
ahundt Aug 13, 2018
60d9c81
models.py translation_only add xy grid and remove input rotations
ahundt Aug 13, 2018
686c0db
micro_child.py change aux head training config
ahundt Aug 13, 2018
0cd3844
costar_block_stacking_translation_search.sh 64 filters
ahundt Aug 13, 2018
6638067
general_child.py, micro_child.py ALL RELU -> ELU
ahundt Aug 13, 2018
de1ed0f
micro_child.py typo fix
ahundt Aug 13, 2018
43909f8
block_stacking_reader.py fix major bugs in data loading code for data…
ahundt Aug 13, 2018
787e5c2
block_stacking_reader.py fix assert bug
ahundt Aug 13, 2018
8c9283c
block_stacking_reader.py disable debug code
ahundt Aug 13, 2018
5ebb658
rotation_search.sh batch_size=16
ahundt Aug 13, 2018
83b74cd
Changed output for test and validation to include arc
j-varun Aug 13, 2018
22308b5
Added new case with reward estimates in block_stacking_reader.py
j-varun Aug 14, 2018
54f7c3a
Merge commit '83b74cd0530fa2ce3cc58f98aae25b809944e965' into stacking…
ahundt Aug 14, 2018
676f082
Added support for reward_estimates training case
j-varun Aug 14, 2018
59d7d90
main.py typo fix
ahundt Aug 14, 2018
3fd3dc9
set batch size and filters to work on titan xp
ahundt Aug 14, 2018
bf8506b
add requirements.txt
ahundt Aug 14, 2018
71265b2
Merge commit '676f082dbdf7d5055a74df996e097f53a944d458' into stacking…
ahundt Aug 14, 2018
2a81e99
Merge commit 'bf8506b94c7eeda1f42d5d0ba2888c13b036c3ac' into stacking…
ahundt Aug 14, 2018
455dc9c
Add reward critic search script
ahundt Aug 14, 2018
9b7e3fd
set reward critic output directory
ahundt Aug 14, 2018
874494b
reward_estimate -> stacking_reward so it can be told apart from the e…
ahundt Aug 14, 2018
ed69b79
TRANSLATION SEARCH MAJOR REWARD CHANGE TO 1/MSE
ahundt Aug 14, 2018
19b4504
block_stacking_reader.py only check for length of goal ids if it is l…
ahundt Aug 14, 2018
b63bda2
rename stacking reward variable
ahundt Aug 14, 2018
6dfb557
rotation search reduce batch size for GTX 1080 Ti
ahundt Aug 14, 2018
23a590a
micro rotation search partial reversion of settings (filters & input …
ahundt Aug 15, 2018
1f12b22
attempt 2 at printing eval architecture
ahundt Aug 16, 2018
49004a4
micro_child.py convert arc data to tensor so it will work in fixed an…
ahundt Aug 16, 2018
de49c43
build_valid() is_training = True
ahundt Aug 16, 2018
b16bbbb
arc print fix
ahundt Aug 16, 2018
73cc9bb
micro_child.py disable printing arcs continuously, but make it easy t…
ahundt Aug 16, 2018
a5e7228
add and enable random augmentation for translation search only
ahundt Aug 16, 2018
c55e4d1
typo fix
ahundt Aug 16, 2018
0d30145
main.py add random_augmentation parameter to child class call
ahundt Aug 16, 2018
313c963
create rotation final script
ahundt Aug 20, 2018
6babeec
rotation final fix nchw -> nhwc
ahundt Aug 20, 2018
96afa21
rotation search and final command fixes
ahundt Aug 20, 2018
f557120
Temporary fix for group_norm issues
j-varun Aug 21, 2018
ca67c1a
Another temporary fix for norm issues
j-varun Aug 21, 2018
4dd5515
Switch to batch_norm until group_norm is fixed
j-varun Aug 21, 2018
bef61b2
Changes to hyperparameters for rotation
j-varun Aug 22, 2018
283bcb4
Fix Typo
j-varun Aug 22, 2018
4821dba
Changes in rotation search script
j-varun Aug 31, 2018
649795f
Changes for costar_block_stacking_v0.3, load files from train-test-va…
j-varun Sep 5, 2018
39a58eb
Metrics being saved to csv
j-varun Sep 5, 2018
cbb1f1a
Changes in script for costar_block_stacking_v0.3
j-varun Sep 5, 2018
9bf5b11
updated names of text files
j-varun Sep 5, 2018
ced06db
block_stacking_reader.py expand user path to filename
ahundt Sep 5, 2018
1997ec6
Fix in csv file write mode for Python2 and minor changes in scripts
j-varun Sep 6, 2018
4475bbf
Minor bug fix
j-varun Sep 6, 2018
91cd02b
fix for csv file handling
j-varun Sep 6, 2018
c9ef761
micro_child.py attempt to add printing of arc during eval and test
ahundt Sep 6, 2018
661db8e
Bug fixes
j-varun Sep 6, 2018
6bf5009
Merge commit '91cd02b5660ad3d85602faa685cc89084ac1faaf' into stacking…
ahundt Sep 6, 2018
d8ed25f
Merge commit '661db8e89ee9609f04b4001b6588c44c11939161' into stacking…
ahundt Sep 6, 2018
6c4040c
micro_child.py print_arc handle feed dict None case
ahundt Sep 6, 2018
b2010b5
Additional metrics + Arc print for eval
j-varun Sep 7, 2018
6b76317
Changed rotation_weight to 1 for encoding and decoding in grasp_metri…
j-varun Sep 7, 2018
bbc37d8
Merge commit '6b7631796bc5d3c5c96205af1374f924e90f722a' into stacking…
ahundt Sep 7, 2018
c97b145
Fixes for typos
j-varun Sep 7, 2018
eb27128
Merge commit 'c97b14544ff6d316fbaac3071ee16eba6a58616f' into stacking…
ahundt Sep 7, 2018
237a080
More typo fixes
j-varun Sep 7, 2018
e2336be
Commenting unsuccessful printing of arcs
j-varun Sep 7, 2018
ee8c0ba
Merge commit 'e2336be9d9320dcfbba9fbb67671bf31f1a1edfa' into stacking…
ahundt Sep 8, 2018
c46e27d
Added support for inference in reader
j-varun Sep 9, 2018
a53d1fd
Merge commit 'c46e27d346185ec98f603a5d4176f15182929545' into stacking…
ahundt Sep 9, 2018
d284789
Bug fixes for metrics
j-varun Sep 9, 2018
ba8b18f
Bug fixes for metrics
j-varun Sep 9, 2018
78f6b1e
inference mode function integrated into generator in block_stacking_r…
j-varun Sep 9, 2018
53d73c5
Option to switch to one hot encoding in block_stacking_reader.py and …
j-varun Sep 10, 2018
b19806d
Added ELU activation to convolutions in new root
j-varun Sep 10, 2018
fdd1a08
Merge branch 'stacking_search' of https://github.com/ahundt/enas into…
j-varun Sep 10, 2018
6628d6b
Merge commit 'b19806d129e9dd0a696d9fcd96a3e70c4e070a52' into stacking…
ahundt Sep 10, 2018
5eb698b
Commented print_arc in eval
j-varun Sep 10, 2018
08190d7
Merge commit '5eb698b676b435d98112fdb80027e46fe6d1b581' into stacking…
ahundt Sep 10, 2018
0cb8aa6
Bug fix and updated scripts
j-varun Sep 10, 2018
bfa984b
Merge commit '0cb8aa65e2676418f90a7617bc3ae3bde93edc8e' into stacking…
ahundt Sep 10, 2018
2e84b1e
merge changes from costar_plan
ahundt Sep 10, 2018
0ed01f3
block_stacking_reader.py CRITICAL MERGE OF SPLIT OUT ENCODING API
ahundt Sep 10, 2018
1a2fa82
Changes in stem_conv
j-varun Sep 10, 2018
f749a00
Merge branch 'stacking_search' of https://github.com/ahundt/enas into…
j-varun Sep 10, 2018
92010cd
set new rotation final run, make sure val and test is collected
ahundt Sep 12, 2018
6ac9176
Create block stacking translation final script
ahundt Sep 12, 2018
2217e64
Merge pull request #8 from j-varun/stacking_search
ahundt Sep 12, 2018
cf735c9
Option to use MSLE as primary loss function
j-varun Sep 13, 2018
fe0c13b
Merge branch 'stacking_search' of https://github.com/ahundt/enas into…
j-varun Sep 13, 2018
9f35cf4
Changes in rotation final run
j-varun Sep 13, 2018
bcff854
Path fixes
j-varun Sep 13, 2018
5270e30
Additional condition check in reader
j-varun Sep 14, 2018
c224322
Batch size change in rotation final script
j-varun Sep 14, 2018
e5b93bb
Use model batch with low angular error and 100% acc
ahundt Sep 14, 2018
e2f50ff
enas/cifar10/block_stacking_reader.py fix bad merge
ahundt Sep 14, 2018
1bc0155
Fix for redundant condition
j-varun Sep 14, 2018
a40fcfe
enas/cifar10/block_stacking_reader.py fix another merge problem
ahundt Sep 14, 2018
3d81da8
Merge branch 'stacking_search' into good_fixes
ahundt Sep 14, 2018
c378936
Merge pull request #9 from j-varun/good_fixes
ahundt Sep 14, 2018
b03832c
Fix some details for starting the run
ahundt Sep 14, 2018
40d3508
batch size 32
ahundt Sep 14, 2018
08abed4
fix csv path
ahundt Sep 14, 2018
5398e8c
fix for running with msle
j-varun Sep 14, 2018
ff91ed7
Typo fix
j-varun Sep 14, 2018
bf5edbc
final rotation search settings, add no root search scripts
ahundt Sep 15, 2018
61bbfd7
New scripts for final run and added headers for csv
j-varun Sep 17, 2018
b156c95
Timestamp update in scripts
j-varun Sep 17, 2018
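
Note on the reward commits above (e955322, 62b9ede, d0d309d, ed69b79): for the regression tasks the controller reward is derived from validation loss rather than accuracy, first as a reward bounded by a configurable max loss, later as 1/MSE for the translation search. A minimal sketch of both mappings; the function and parameter names here are hypothetical, not taken from the code:

def reward_from_loss(valid_loss, max_loss=2.0, use_inverse=False, epsilon=1e-6):
    """Map a validation loss to a positive controller reward (hypothetical helper).

    max_loss=2.0 mirrors commit d0d309d ("set max_loss to 2 to avoid negative
    losses"); use_inverse=True mirrors commit ed69b79 ("TRANSLATION SEARCH
    MAJOR REWARD CHANGE TO 1/MSE").
    """
    if use_inverse:
        # Inverse-MSE reward: smaller loss means larger reward, kept finite.
        return 1.0 / max(valid_loss, epsilon)
    # Clipped-difference reward: stays non-negative even for large losses.
    return max(max_loss - valid_loss, 0.0)
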
815 changes: 815 additions & 0 deletions enas/cifar10/block_stacking_reader.py

Large diffs are not rendered by default.
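
Since block_stacking_reader.py is too large for GitHub to render inline, here is a heavily reduced sketch of its overall shape, inferred only from the commit messages (a batch generator over the CoSTAR stacking attempts, with an inference mode and optional one-hot encoding) and assuming HDF5 files, which is how the CoSTAR block stacking data is distributed. The class name and dataset keys are assumptions, not the file's actual API:

import h5py
import numpy as np

class BlockStackingSequence(object):
    """Minimal sketch of an HDF5-backed batch generator (hypothetical API).

    The real block_stacking_reader.py layers pose encoding, augmentation,
    inference mode, and reward estimates on top of this basic pattern.
    """
    def __init__(self, filenames, batch_size=32, inference_mode=False):
        self.filenames = filenames
        self.batch_size = batch_size
        self.inference_mode = inference_mode

    def __len__(self):
        return int(np.ceil(len(self.filenames) / float(self.batch_size)))

    def __getitem__(self, index):
        batch = self.filenames[index * self.batch_size:(index + 1) * self.batch_size]
        images, labels = [], []
        for name in batch:
            with h5py.File(name, 'r') as data:
                images.append(np.asarray(data['image'][0]))       # assumed dataset key
                if not self.inference_mode:
                    labels.append(np.asarray(data['pose'][0]))    # assumed dataset key
        if self.inference_mode:
            return np.array(images)
        return np.array(images), np.array(labels)
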

4 changes: 4 additions & 0 deletions enas/cifar10/data_utils.py
@@ -53,6 +53,7 @@ def _read_fmnist_data(data_path):
   labels["train"] = np.array(data.train.labels, dtype = np.int32)
   labels["test"] = np.array(data.test.labels, dtype = np.int32)
   print("Read and processed data..")
+  print(labels["test"])

   return images, labels

@@ -80,6 +81,9 @@ def read_data(data_path, num_valids=5000, dataset = "cifar"):
     images, labels = valid_split_data(images, labels, num_valids)
     return images, labels

+  if dataset == "stacking":
+    images["path"] = data_path
+    return images, labels
   else:
     train_files = [
       "data_batch_1",
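
With the addition above, the stacking dataset takes a short-circuit path through read_data: it only records the dataset location and defers all actual loading to the stacking reader. A hypothetical call site (the path is illustrative):

from enas.cifar10.data_utils import read_data

# No image arrays are loaded here; images merely carries the dataset path.
images, labels = read_data("/path/to/costar_block_stacking", dataset="stacking")
print(images["path"])  # -> "/path/to/costar_block_stacking"
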
53 changes: 29 additions & 24 deletions enas/cifar10/general_child.py
@@ -12,10 +12,11 @@
 from enas.cifar10.image_ops import conv
 from enas.cifar10.image_ops import fully_connected
 from enas.cifar10.image_ops import batch_norm
+from enas.cifar10.image_ops import norm
 from enas.cifar10.image_ops import batch_norm_with_mask
 from enas.cifar10.image_ops import relu
 from enas.cifar10.image_ops import max_pool
-from enas.cifar10.image_ops import global_avg_pool
+from enas.cifar10.image_ops import global_max_pool

 from enas.utils import count_model_params
 from enas.utils import get_train_ops
@@ -101,8 +102,10 @@ def _get_C(self, x):
       x: tensor of shape [N, H, W, C] or [N, C, H, W]
     """
     if self.data_format == "NHWC":
+      assert x.get_shape().as_list()[3] is not None
       return x.get_shape()[3].value
     elif self.data_format == "NCHW":
+      assert x.get_shape().as_list()[1] is not None
       return x.get_shape()[1].value
     else:
       raise ValueError("Unknown data_format '{0}'".format(self.data_format))
@@ -112,6 +115,7 @@ def _get_HW(self, x):
     Args:
       x: tensor of shape [N, H, W, C] or [N, C, H, W]
     """
+    assert x.get_shape().as_list()[2] is not None
     return x.get_shape()[2].value

   def _get_strides(self, stride):
@@ -136,7 +140,7 @@ def _factorized_reduction(self, x, out_filters, stride, is_training):
         w = create_weight("w", [1, 1, inp_c, out_filters])
         x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME",
                          data_format=self.data_format)
-        x = batch_norm(x, is_training, data_format=self.data_format)
+        x = norm(x, is_training, data_format=self.data_format)
         return x

     stride_spec = self._get_strides(stride)
stride_spec = self._get_strides(stride)
Expand Down Expand Up @@ -171,7 +175,7 @@ def _factorized_reduction(self, x, out_filters, stride, is_training):

# Concat and apply BN
final_path = tf.concat(values=[path1, path2], axis=concat_axis)
final_path = batch_norm(final_path, is_training,
final_path = norm(final_path, is_training,
data_format=self.data_format)

return final_path
@@ -194,11 +198,11 @@ def _model(self, images, is_training, reuse=False):
       layers = []

       out_filters = self.out_filters
-      C = self._get_C(images)
+      C = self._get_C(images)
       with tf.variable_scope("stem_conv"):
         w = create_weight("w", [C, C, C, out_filters])
         x = tf.nn.conv2d(images, w, [1, 1, 1, 1], "SAME", data_format=self.data_format)
-        x = batch_norm(x, is_training, data_format=self.data_format)
+        x = norm(x, is_training, data_format=self.data_format)
         layers.append(x)

       if self.whole_channels:
if self.whole_channels:
Expand Down Expand Up @@ -229,7 +233,7 @@ def _model(self, images, is_training, reuse=False):
start_idx += 2 * self.num_branches + layer_id
print(layers[-1])

x = global_avg_pool(x, data_format=self.data_format)
x = global_max_pool(x, data_format=self.data_format)
if is_training:
x = tf.nn.dropout(x, self.keep_prob)
with tf.variable_scope("fc"):
@@ -351,8 +355,8 @@ def _enas_layer(self, layer_id, prev_layers, start_idx, out_filters, is_training
         branches = tf.reshape(branches, [N, -1, H, W])
       out = tf.nn.conv2d(
         branches, w, [1, 1, 1, 1], "SAME", data_format=self.data_format)
-      out = batch_norm(out, is_training, data_format=self.data_format)
-      out = tf.nn.relu(out)
+      out = norm(out, is_training, data_format=self.data_format)
+      out = tf.nn.elu(out)

     if layer_id > 0:
       if self.whole_channels:
@@ -368,7 +372,7 @@
               lambda: tf.zeros_like(prev_layers[i])))
         res_layers.append(out)
         out = tf.add_n(res_layers)
-        out = batch_norm(out, is_training, data_format=self.data_format)
+        out = norm(out, is_training, data_format=self.data_format)

     return out

@@ -396,17 +400,17 @@ def _fixed_layer(
       filter_size = size[count]
       with tf.variable_scope("conv_1x1"):
         w = create_weight("w", [1, 1, inp_c, out_filters])
-        out = tf.nn.relu(inputs)
+        out = tf.nn.elu(inputs)
         out = tf.nn.conv2d(out, w, [1, 1, 1, 1], "SAME",
                            data_format=self.data_format)
-        out = batch_norm(out, is_training, data_format=self.data_format)
+        out = norm(out, is_training, data_format=self.data_format)

       with tf.variable_scope("conv_{0}x{0}".format(filter_size)):
         w = create_weight("w", [filter_size, filter_size, out_filters, out_filters])
-        out = tf.nn.relu(out)
+        out = tf.nn.elu(out)
         out = tf.nn.conv2d(out, w, [1, 1, 1, 1], "SAME",
                            data_format=self.data_format)
-        out = batch_norm(out, is_training, data_format=self.data_format)
+        out = norm(out, is_training, data_format=self.data_format)
     elif count == 4:
       pass
     elif count == 5:
@@ -449,10 +453,10 @@ def _fixed_layer(
         branches = tf.concat(branches, axis=3)
       elif self.data_format == "NCHW":
         branches = tf.concat(branches, axis=1)
-      out = tf.nn.relu(branches)
+      out = tf.nn.elu(branches)
       out = tf.nn.conv2d(out, w, [1, 1, 1, 1], "SAME",
                          data_format=self.data_format)
-      out = batch_norm(out, is_training, data_format=self.data_format)
+      out = norm(out, is_training, data_format=self.data_format)

     if layer_id > 0:
       if self.whole_channels:
@@ -477,10 +481,10 @@
       with tf.variable_scope("skip"):
         w = create_weight(
           "w", [1, 1, total_skip_channels * out_filters, out_filters])
-        out = tf.nn.relu(out)
+        out = tf.nn.elu(out)
         out = tf.nn.conv2d(
           out, w, [1, 1, 1, 1], "SAME", data_format=self.data_format)
-        out = batch_norm(out, is_training, data_format=self.data_format)
+        out = norm(out, is_training, data_format=self.data_format)

     return out

@@ -504,8 +508,8 @@ def _conv_branch(self, inputs, filter_size, is_training, count, out_filters,
     with tf.variable_scope("inp_conv_1"):
       w = create_weight("w", [1, 1, inp_c, out_filters])
       x = tf.nn.conv2d(inputs, w, [1, 1, 1, 1], "SAME", data_format=self.data_format)
-      x = batch_norm(x, is_training, data_format=self.data_format)
-      x = tf.nn.relu(x)
+      x = norm(x, is_training, data_format=self.data_format)
+      x = tf.nn.elu(x)

     with tf.variable_scope("out_conv_{}".format(filter_size)):
       if start_idx is None:
@@ -515,12 +519,13 @@
         w_point = create_weight("w_point", [1, 1, out_filters * ch_mul, count])
         x = tf.nn.separable_conv2d(x, w_depth, w_point, strides=[1, 1, 1, 1],
                                    padding="SAME", data_format=self.data_format)
-        x = batch_norm(x, is_training, data_format=self.data_format)
+        x = norm(x, is_training, data_format=self.data_format)
       else:
         w = create_weight("w", [filter_size, filter_size, inp_c, count])
         x = tf.nn.conv2d(x, w, [1, 1, 1, 1], "SAME", data_format=self.data_format)
-        x = batch_norm(x, is_training, data_format=self.data_format)
+        x = norm(x, is_training, data_format=self.data_format)
     else:
+      print('TODO(ahundt) batch_norm_with_mask is definitely called... make a group norm version!')
       if separable:
         w_depth = create_weight("w_depth", [filter_size, filter_size, out_filters, ch_mul])
         w_point = create_weight("w_point", [out_filters, out_filters * ch_mul])
@@ -544,7 +549,7 @@
         mask = tf.logical_and(start_idx <= mask, mask < start_idx + count)
         x = batch_norm_with_mask(
           x, is_training, mask, out_filters, data_format=self.data_format)
-        x = tf.nn.relu(x)
+        x = tf.nn.elu(x)
     return x

   def _pool_branch(self, inputs, is_training, count, avg_or_max, start_idx=None):
@@ -566,8 +571,8 @@ def _pool_branch(self, inputs, is_training, count, avg_or_max, start_idx=None):
     with tf.variable_scope("conv_1"):
       w = create_weight("w", [1, 1, inp_c, self.out_filters])
       x = tf.nn.conv2d(inputs, w, [1, 1, 1, 1], "SAME", data_format=self.data_format)
-      x = batch_norm(x, is_training, data_format=self.data_format)
-      x = tf.nn.relu(x)
+      x = norm(x, is_training, data_format=self.data_format)
+      x = tf.nn.elu(x)

     with tf.variable_scope("pool"):
       if self.data_format == "NHWC":
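
The norm() calls introduced throughout this diff dispatch to group normalization (commit a4a1087, "batch_norm -> group_norm with exception of masked batch norm"), and a long run of image_ops.py commits above wrestles with its parameter shapes, including the [1, C, 1, 1] form from commit 3f65d52. Below is a minimal TF1-style group_norm sketch under those assumptions; it illustrates the technique, not the repository's actual implementation:

import tensorflow as tf

def group_norm(x, groups=32, eps=1e-5, data_format="NHWC", scope="group_norm"):
    """Minimal group normalization sketch (assumed semantics, TF1 API)."""
    with tf.variable_scope(scope):
        if data_format == "NHWC":
            # Normalize in NCHW, then transpose back at the end.
            x = tf.transpose(x, [0, 3, 1, 2])
        _, c, h, w = x.get_shape().as_list()
        g = min(groups, c)
        # Split channels into g groups and normalize within each group.
        x = tf.reshape(x, [-1, g, c // g, h, w])
        mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
        x = (x - mean) / tf.sqrt(var + eps)
        x = tf.reshape(x, [-1, c, h, w])
        # Learned scale and shift, shaped [1, C, 1, 1] as in the commit log.
        gamma = tf.get_variable("gamma", [1, c, 1, 1], initializer=tf.ones_initializer())
        beta = tf.get_variable("beta", [1, c, 1, 1], initializer=tf.zeros_initializer())
        x = x * gamma + beta
        if data_format == "NHWC":
            x = tf.transpose(x, [0, 2, 3, 1])
    return x

The group count must divide the channel count; the min(groups, c) guard is one simple way to keep narrow stem layers from failing, which matches the kind of shape errors the fix attempts above were chasing.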