# Simple Multi Unet Model
# https://youtu.be/XyX5HNuv-xE
# https://youtu.be/q-p8v1Bxvac
"""
Standard Unet
Model not compiled here, instead will be done externally to make it
easy to test various loss functions and optimizers.
"""
################################################################
def multi_unet_model(n_classes=5, IMG_HEIGHT=256, IMG_WIDTH=256, IMG_CHANNELS=1):
    """Build an uncompiled U-Net for per-pixel (semantic) segmentation.

    The network has four encoder stages (16-32-64-128 filters), a
    256-filter bottleneck and four decoder stages that upsample with
    transposed convolutions and concatenate the matching encoder features
    (skip connections). Assumes Keras' default channels-last data format.

    Args:
        n_classes: Number of output channels / classes. One class yields a
            sigmoid output (binary mask); more yield a softmax over channels.
        IMG_HEIGHT: Input height in pixels; a multiple of 16, or None for a
            variable size.
        IMG_WIDTH: Input width in pixels; a multiple of 16, or None for a
            variable size.
        IMG_CHANNELS: Number of input channels.

    Returns:
        An uncompiled Keras ``Model`` mapping inputs of shape
        (batch, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS) to outputs of shape
        (batch, IMG_HEIGHT, IMG_WIDTH, n_classes).

    Raises:
        ValueError: If ``n_classes`` < 1, or a known spatial dimension is
            not a multiple of 16.
    """
    if n_classes < 1:
        raise ValueError(f"n_classes must be >= 1, got {n_classes}")
    # The encoder halves the spatial size four times (2x2 max pooling) and
    # the decoder doubles it back; the skip concatenations only line up when
    # each known spatial dimension is a multiple of 2**4.
    for dim_name, dim in (("IMG_HEIGHT", IMG_HEIGHT), ("IMG_WIDTH", IMG_WIDTH)):
        if dim is not None and dim % 16 != 0:
            raise ValueError(f"{dim_name} must be a multiple of 16, got {dim}")

    # Imported here so the builder is self-contained and every layer comes
    # from a single Keras namespace.
    from keras.layers import (
        Conv2D,
        Conv2DTranspose,
        Dropout,
        Input,
        MaxPooling2D,
        concatenate,
    )
    from keras.models import Model

    def conv_block(x, filters, dropout_rate):
        """Apply two 3x3 ReLU convolutions with dropout between them."""
        x = Conv2D(filters, (3, 3), activation='relu',
                   kernel_initializer='he_normal', padding='same')(x)
        x = Dropout(dropout_rate)(x)
        return Conv2D(filters, (3, 3), activation='relu',
                      kernel_initializer='he_normal', padding='same')(x)

    def up_block(x, skip, filters, dropout_rate):
        """Upsample x by 2, concatenate the encoder skip, then conv_block."""
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, skip])
        return conv_block(x, filters, dropout_rate)

    inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    # Inputs are used as-is: normalize them (e.g. divide by 255) beforehand
    # instead of adding a rescaling layer here.
    s = inputs

    # Contraction path: each stage's output is kept for a skip connection.
    c1 = conv_block(s, 16, 0.1)
    p1 = MaxPooling2D((2, 2))(c1)
    c2 = conv_block(p1, 32, 0.1)
    p2 = MaxPooling2D((2, 2))(c2)
    c3 = conv_block(p2, 64, 0.2)
    p3 = MaxPooling2D((2, 2))(c3)
    c4 = conv_block(p3, 128, 0.2)
    p4 = MaxPooling2D((2, 2))(c4)

    # Bottleneck.
    c5 = conv_block(p4, 256, 0.3)

    # Expansive path: back to full resolution, fusing encoder features.
    c6 = up_block(c5, c4, 128, 0.2)
    c7 = up_block(c6, c3, 64, 0.2)
    c8 = up_block(c7, c2, 32, 0.1)
    c9 = up_block(c8, c1, 16, 0.1)

    # A softmax over a single channel is constantly 1.0 and cannot learn,
    # so a one-class (binary) mask uses a sigmoid instead.
    activation = 'sigmoid' if n_classes == 1 else 'softmax'
    outputs = Conv2D(n_classes, (1, 1), activation=activation)(c9)

    # NOTE: the model is returned uncompiled; compile it in the calling code
    # so different loss functions and optimizers can be tested easily.
    model = Model(inputs=[inputs], outputs=[outputs])
    return model