Created December 5, 2018 08:46
-
-
Save Mainvooid/84d117f6d24ceaa171b0a76acd9827a7 to your computer and use it in GitHub Desktop.
keras实现deeplabv3+模型 (Keras implementation of the DeepLabv3+ model) #Keras #Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from keras.layers import Activation,Conv2D,MaxPooling2D,BatchNormalization,Input,DepthwiseConv2D,add,Dropout,AveragePooling2D,Concatenate | |
| from keras.models import Model | |
| import keras.backend as K | |
| from keras.engine import Layer,InputSpec | |
| from keras.utils import conv_utils | |
class BilinearUpsampling(Layer):
    """Resize a 4-D feature map by integer factors with bilinear interpolation.

    The layer treats its input as ``(batch, height, width, channels)``: axes 1
    and 2 are scaled by ``upsampling`` and the channel axis is left untouched.

    NOTE(review): ``data_format`` is normalized and stored in the config, but
    ``call`` and ``compute_output_shape`` always assume channels-last; a
    ``channels_first`` setting is not honoured.

    Args:
        upsampling: int or pair of ints, the (height, width) scale factors.
        data_format: stored in the config only (see note above).
        **kwargs: forwarded to ``keras.engine.Layer``. A ``size`` key (the name
            older versions of ``get_config`` wrote) is accepted as an alias for
            ``upsampling`` so previously saved models can still be reloaded.
    """

    def __init__(self, upsampling=(2, 2), data_format=None, **kwargs):
        # Older get_config() output stored the factors under 'size', which
        # Layer.__init__ rejects as an unknown kwarg; map it onto `upsampling`.
        legacy_size = kwargs.pop('size', None)
        if legacy_size is not None:
            upsampling = legacy_size
        super(BilinearUpsampling, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.upsampling = conv_utils.normalize_tuple(upsampling, 2, 'upsampling')
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        """Scale the two spatial dims; unknown (None) dims stay None."""
        height = (self.upsampling[0] * input_shape[1]
                  if input_shape[1] is not None else None)
        width = (self.upsampling[1] * input_shape[2]
                 if input_shape[2] is not None else None)
        return (input_shape[0], height, width, input_shape[3])

    def call(self, inputs):
        """Bilinearly resize ``inputs`` to ``(H * up_h, W * up_w)``."""
        height, width = K.int_shape(inputs)[1:3]
        if height is not None and width is not None:
            size = (int(height * self.upsampling[0]),
                    int(width * self.upsampling[1]))
        else:
            # Spatial size only known at run time: build the target size from
            # the dynamic shape instead of failing on int(None).
            dynamic_shape = K.shape(inputs)
            size = K.stack([dynamic_shape[1] * self.upsampling[0],
                            dynamic_shape[2] * self.upsampling[1]])
        return K.tf.image.resize_bilinear(inputs, size)

    def get_config(self):
        """Return a config whose keys match ``__init__``'s parameter names."""
        # The key must be 'upsampling' (not 'size') so that
        # from_config(get_config()) round-trips through __init__.
        config = {'upsampling': self.upsampling,
                  'data_format': self.data_format}
        base_config = super(BilinearUpsampling, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def xception_downsample_block(x, channels, top_relu=False):
    """Apply three separable convolutions, the last one with stride 2.

    Each separable conv is DepthwiseConv2D(3x3, "same") -> BatchNormalization
    -> Conv2D(1x1, ``channels``) -> BatchNormalization. A ReLU follows the
    first two; the third (strided) one ends at its BatchNormalization.

    Args:
        x: input tensor.
        channels: filter count of every pointwise (1x1) convolution.
        top_relu: when True, apply a ReLU to ``x`` before the first conv.

    Returns:
        The output tensor, with no trailing activation.
    """
    if top_relu:
        x = Activation("relu")(x)
    depthwise_strides = [(1, 1), (1, 1), (2, 2)]
    last_index = len(depthwise_strides) - 1
    for idx, stride in enumerate(depthwise_strides):
        x = DepthwiseConv2D((3, 3), strides=stride, padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Conv2D(channels, (1, 1), padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
        # Only the final (strided) separable conv is left un-activated.
        if idx != last_index:
            x = Activation("relu")(x)
    return x
def res_xception_downsample_block(x, channels):
    """Strided Xception block with a projected shortcut.

    The shortcut is a 1x1 stride-2 Conv2D followed by BatchNormalization. It
    is added to the output of ``xception_downsample_block`` (without
    ``top_relu``).

    Args:
        x: input tensor.
        channels: output channel count of both branches.

    Returns:
        Sum of the main branch and the shortcut.
    """
    shortcut = Conv2D(channels, (1, 1), strides=(2, 2), padding="same", use_bias=False)(x)
    shortcut = BatchNormalization()(shortcut)
    main_branch = xception_downsample_block(x, channels)
    return add([main_branch, shortcut])
def xception_block(x, channels):
    """Apply three pre-activation separable convolutions at stride 1.

    Each unit is ReLU -> DepthwiseConv2D(3x3, "same") -> BatchNormalization
    -> Conv2D(1x1, ``channels``) -> BatchNormalization.

    Args:
        x: input tensor.
        channels: filter count of every pointwise (1x1) convolution.

    Returns:
        The output of the third unit (ends with BatchNormalization).
    """
    for _ in range(3):
        x = Activation("relu")(x)
        x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Conv2D(channels, (1, 1), padding="same", use_bias=False)(x)
        x = BatchNormalization()(x)
    return x
def res_xception_block(x, channels):
    """Apply ``xception_block`` and add the unmodified input back.

    Because the shortcut is the raw input, ``x`` must already have
    ``channels`` channels for the elementwise ``add`` to be valid.

    Args:
        x: input tensor.
        channels: channel count of the block (must equal ``x``'s).

    Returns:
        ``xception_block(x, channels) + x``.
    """
    return add([xception_block(x, channels), x])
def aspp(x, input_shape, out_stride, atrous_rates=(6, 12, 18)):
    """Atrous Spatial Pyramid Pooling head.

    Five parallel branches over ``x`` are concatenated along the last axis,
    in the order [b4, b0, b1, b2, b3]:
      * b0: Conv2D(1x1, 256) -> BN -> ReLU
      * b1..b3: DepthwiseConv2D(3x3) with dilation ``atrous_rates[i]`` -> BN
        -> ReLU -> Conv2D(1x1, 256) -> BN -> ReLU
      * b4 (image pooling): AveragePooling2D over the whole feature map ->
        Conv2D(1x1, 256) -> BN -> ReLU -> BilinearUpsampling back to the
        feature-map size.

    Args:
        x: encoder feature map (assumed channels-last).
        input_shape: network input shape ``(H, W, C)``. Only used to derive
            the feature-map size when ``x`` has no static spatial shape.
        out_stride: ratio between the input size and ``x``'s size, used
            together with ``input_shape`` for that same fallback.
        atrous_rates: dilation rates of the three atrous branches. The
            default (6, 12, 18) are the DeepLabv3+ values for output stride
            16; the paper doubles them for output stride 8.

    Returns:
        Tensor with 5 * 256 channels and the same spatial size as ``x``.
    """
    b0 = Conv2D(256, (1, 1), padding="same", use_bias=False)(x)
    b0 = BatchNormalization()(b0)
    b0 = Activation("relu")(b0)

    atrous_branches = []
    for rate in atrous_rates:
        branch = DepthwiseConv2D((3, 3), dilation_rate=(rate, rate),
                                 padding="same", use_bias=False)(x)
        branch = BatchNormalization()(branch)
        branch = Activation("relu")(branch)
        branch = Conv2D(256, (1, 1), padding="same", use_bias=False)(branch)
        branch = BatchNormalization()(branch)
        branch = Activation("relu")(branch)
        atrous_branches.append(branch)

    # Pool over the whole feature map. Use its real static size, so that
    # non-square inputs and inputs not divisible by out_stride work.
    feat_h, feat_w = K.int_shape(x)[1:3]
    if feat_h is None or feat_w is None:
        # Fallback: stride-2 "same" convs round up, so use ceil division.
        feat_h = -(-input_shape[0] // out_stride)
        feat_w = -(-input_shape[1] // out_stride)
    b4 = AveragePooling2D(pool_size=(feat_h, feat_w))(x)
    b4 = Conv2D(256, (1, 1), padding="same", use_bias=False)(b4)
    b4 = BatchNormalization()(b4)
    b4 = Activation("relu")(b4)
    b4 = BilinearUpsampling((feat_h, feat_w))(b4)

    x = Concatenate()([b4, b0] + atrous_branches)
    return x
def deeplabv3_plus(input_shape=(512, 512, 3), out_stride=16, num_classes=21,
                   dropout_rate=0.1):
    """Build a DeepLabv3+ segmentation model with a modified-Xception encoder.

    Args:
        input_shape: ``(H, W, C)`` of the input image. The encoder has four
            stride-2 stages and the decoder concatenates a stride-4 skip with
            the 4x-upsampled stride-16 features, so H and W should be
            multiples of 16 for those shapes to line up.
        out_stride: forwarded to ``aspp``. NOTE: the encoder below always
            downsamples by 16 regardless of this value.
        num_classes: channel count of the final 1x1 convolution.
        dropout_rate: fraction of activations dropped after the ASPP
            projection. Keras ``Dropout`` takes the *drop* rate, so 0.1
            corresponds to TensorFlow's keep_prob=0.9.

    Returns:
        A ``keras.models.Model`` mapping the image to per-pixel outputs of the
        final Conv2D (linear, no softmax/sigmoid), upsampled 4x from stride 4.
    """
    img_input = Input(shape=input_shape)

    # ---- Entry flow ------------------------------------------------------
    # Stem: stride-2 3x3 conv (output stride 2), then a stride-1 3x3 conv.
    x = Conv2D(32, (3, 3), strides=(2, 2), padding="same", use_bias=False)(img_input)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(64, (3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    # Block 1: residual downsample -> output stride 4.
    x = res_xception_downsample_block(x, 128)

    # Block 2 is written inline because the decoder's low-level skip is taken
    # after its second separable conv, before the strided one.
    res = Conv2D(256, (1, 1), strides=(2, 2), padding="same", use_bias=False)(x)
    res = BatchNormalization()(res)
    x = Activation("relu")(x)
    x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x)
    skip = BatchNormalization()(x)  # output stride 4, reused by the decoder
    x = Activation("relu")(skip)
    x = DepthwiseConv2D((3, 3), strides=(2, 2), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = add([x, res])  # output stride 8

    # Block 3: downsample -> output stride 16.
    # NOTE(review): unlike blocks 1-2 this block has no projected shortcut.
    x = xception_downsample_block(x, 728, top_relu=True)

    # ---- Middle flow: 16 identity-shortcut blocks -------------------------
    for _ in range(16):
        x = res_xception_block(x, 728)

    # ---- Exit flow (no striding, output stride stays 16) -----------------
    res = Conv2D(1024, (1, 1), padding="same", use_bias=False)(x)
    res = BatchNormalization()(res)
    x = Activation("relu")(x)
    x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Conv2D(728, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Conv2D(1024, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Conv2D(1024, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = add([x, res])
    x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Conv2D(1536, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Conv2D(1536, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Conv2D(2048, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)

    # ---- ASPP + projection ----------------------------------------------
    x = aspp(x, input_shape, out_stride)
    x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    # Keras Dropout's argument is the fraction to DROP; the previous literal
    # 0.9 discarded 90% of the activations.
    x = Dropout(dropout_rate)(x)

    # ---- Decoder ---------------------------------------------------------
    x = BilinearUpsampling((4, 4))(x)  # output stride 16 -> 4, matches `skip`
    dec_skip = Conv2D(48, (1, 1), padding="same", use_bias=False)(skip)
    dec_skip = BatchNormalization()(dec_skip)
    dec_skip = Activation("relu")(dec_skip)
    x = Concatenate()([x, dec_skip])
    x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = DepthwiseConv2D((3, 3), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(256, (1, 1), padding="same", use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(num_classes, (1, 1), padding="same")(x)
    x = BilinearUpsampling((4, 4))(x)  # output stride 4 -> 1

    model = Model(img_input, x)
    return model
if __name__ == "__main__":
    # Smoke check: build a model with a single output channel and print its
    # layer summary.
    net = deeplabv3_plus(num_classes=1)
    net.summary()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment