Skip to content

Commit

Permalink
Fix shape issue
Browse files Browse the repository at this point in the history
  • Loading branch information
a00012025 committed Jun 11, 2019
1 parent 4bcfd73 commit 07a5d44
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/cifar10/general_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,13 +291,13 @@ def _enas_layer(self, layer_id, prev_layers, start_idx, out_filters, is_training
y = self._pool_branch(inputs, is_training, out_filters, "max",
start_idx=0)
branches[tf.equal(count, 5)] = lambda: y
out = tf.case(branches, default=lambda: tf.constant(0, tf.float32),
exclusive=True)

if self.data_format == "NHWC":
out.set_shape([None, inp_h, inp_w, out_filters])
out_shape = [self.batch_size, inp_h, inp_w, out_filters]
elif self.data_format == "NCHW":
out.set_shape([None, out_filters, inp_h, inp_w])
out_shape = [self.batch_size, out_filters, inp_h, inp_w]
out = tf.case(branches, default=lambda: tf.constant(0, tf.float32, shape=out_shape),
exclusive=True)
else:
count = self.sample_arc[start_idx:start_idx + 2 * self.num_branches]
branches = []
Expand Down

0 comments on commit 07a5d44

Please sign in to comment.