Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
devanshpratapsingh authored Apr 4, 2021
1 parent db62f03 commit e93ccf6
Showing 1 changed file with 23 additions and 4 deletions.
27 changes: 23 additions & 4 deletions architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
#a basic convolution block having two 3x3 convolution where each conv.
#is followed by batch normalizaton and ReLU then a max pooling operation
def conv_block(input, filters, pool=True):
x = conv2D(filters, 3, padding="same")(input)
x = Conv2D(filters, 3, padding="same")(input)
x = BatchNormalization()(x)
x = Activation("relu")(x)

x = conv2D(filters, 3, padding="same")(x)
x = Conv2D(filters, 3, padding="same")(x)
x = BatchNormalization()(x)
x = Activation("relu")(x)

Expand Down Expand Up @@ -38,6 +38,25 @@ def build_unet(shape, num_classes):
u1 = UpSampling2D((2, 2), interpolation="bilinear")(b1)
#contatenation with the feature man from encoder
c1 = Concatenate()([u1, x4])
x5 = conv_block(c1, 64, pool=False)

if __name__ == "__main__":
build_unet((256, 256, 3), 3)
u2 = UpSampling2D((2, 2), interpolation="bilinear")(x5)
c2 = Concatenate()([u2, x3])
x6 = conv_block(c2, 48, pool=False)

u3 = UpSampling2D((2, 2), interpolation="bilinear")(x6)
c3 = Concatenate()([u3, x2])
x7 = conv_block(c3, 32, pool=False)

u4 = UpSampling2D((2, 2), interpolation="bilinear")(x7)
c4 = Concatenate()([u4, x1])
x8 = conv_block(c4, 16, pool=False)

"""Output Layer"""
output = Conv2D(num_classes, 1, padding ="same", activation="softmax")(x8)

return Model(inputs, output)

if __name__ == "__main__":
model = build_unet((256, 256, 3), 10)
model.summary()

0 comments on commit e93ccf6

Please sign in to comment.