# 第五讲 卷积神经网络 - ResNet -- CIFAR-10
# (Lecture 5: Convolutional Neural Networks — ResNet-18 on CIFAR-10)
#
# georgesale 2020-05-10

import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Dropout, Flatten, Dense, GlobalAveragePooling2D
from tensorflow.keras import Model

# Print arrays in full (no "..." summarisation); the weight dump later in
# this script relies on this to write every value to weights.txt.
np.set_printoptions(threshold=np.inf)


# CIFAR-10 through Keras: load_data() downloads/caches the dataset and returns
# (train, test) splits of (images, labels).
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Rescale pixel intensities from [0, 255] to floats in [0, 1].
x_train = x_train / 255.0
x_test = x_test / 255.0



class ResnetBlock(Model):
    """Basic residual unit: 3x3 conv-BN-ReLU -> 3x3 conv-BN, add shortcut, ReLU.

    Args:
        filters: number of output channels of both 3x3 convolutions (and of
            the 1x1 projection when ``residual_path`` is True).
        strides: stride of the first 3x3 convolution and of the 1x1
            projection; 2 halves the spatial size.
        residual_path: when True the shortcut is a 1x1 convolution + batch
            norm, so the output is ``relu(F(x) + Wx)``; when False the shortcut
            is the identity and the output is ``relu(F(x) + x)``.

    NOTE: with ``residual_path=False`` the input must already have ``filters``
    channels and ``strides`` must be 1; otherwise ``F(x) + x`` fails with a
    shape mismatch.
    """

    def __init__(self, filters, strides=1, residual_path=False):
        super().__init__()
        self.filters = filters
        self.strides = strides
        self.residual_path = residual_path

        # Main path, first half: 3x3 conv (carries the stride) -> BN -> ReLU.
        # No conv bias: the following BatchNormalization supplies an offset.
        self.c1 = Conv2D(filters, (3, 3), strides=strides, padding='same', use_bias=False)
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')

        # Main path, second half: 3x3 conv -> BN. The ReLU comes after the addition.
        self.c2 = Conv2D(filters, (3, 3), strides=1, padding='same', use_bias=False)
        self.b2 = BatchNormalization()

        # When residual_path is True, down-sample the input with a 1x1 conv using
        # the same stride, so that x has the same shape as F(x) and they can be added.
        if residual_path:
            self.down_c1 = Conv2D(filters, (1, 1), strides=strides, padding='same', use_bias=False)
            self.down_b1 = BatchNormalization()

        self.a2 = Activation('relu')

    def call(self, inputs):
        """Return ``relu(F(inputs) + shortcut(inputs))``."""
        # Shortcut branch defaults to the identity: residual = x.
        residual = inputs
        x = self.c1(inputs)
        x = self.b1(x)
        x = self.a1(x)

        x = self.c2(x)
        y = self.b2(x)

        if self.residual_path:
            residual = self.down_c1(inputs)
            residual = self.down_b1(residual)

        # The output is the sum of both branches (F(x) + x, or F(x) + Wx),
        # passed through the final activation.
        out = self.a2(y + residual)
        return out



class ResNet18(Model):
    """ResNet classifier built from stacked ``ResnetBlock`` units.

    Args:
        block_list: number of ``ResnetBlock`` units in each stage, e.g.
            ``[2, 2, 2, 2]`` for ResNet-18. Every stage after the first starts
            with a stride-2 unit that has a 1x1 projection shortcut.
        initial_filters: channel count of the stem convolution and of the
            first stage; each following stage doubles it.
        num_classes: number of output classes of the softmax layer.
    """

    def __init__(self, block_list, initial_filters=64, num_classes=10):
        super().__init__()
        self.num_blocks = len(block_list)  # number of stages
        self.block_list = block_list
        self.out_filters = initial_filters

        # Stem: 3x3 conv (stride 1) -> BN -> ReLU.
        self.c1 = Conv2D(self.out_filters, (3, 3), strides=1, padding='same', use_bias=False)
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')

        # Build the residual stages into a single Sequential container.
        self.blocks = tf.keras.models.Sequential()
        for block_id, num_units in enumerate(block_list):
            for layer_id in range(num_units):
                if block_id != 0 and layer_id == 0:
                    # First unit of every stage except the first: down-sample
                    # the input and project the shortcut to the new width.
                    block = ResnetBlock(self.out_filters, strides=2, residual_path=True)
                else:
                    block = ResnetBlock(self.out_filters, residual_path=False)
                self.blocks.add(block)
            # The next stage uses twice as many filters as this one.
            self.out_filters *= 2

        self.p1 = tf.keras.layers.GlobalAveragePooling2D()
        # Softmax classifier with L2 kernel regularisation.
        self.f1 = tf.keras.layers.Dense(num_classes, activation='softmax',
                                        kernel_regularizer=tf.keras.regularizers.l2())

    def call(self, inputs):
        """Return class probabilities for a batch of images."""
        x = self.c1(inputs)
        x = self.b1(x)
        x = self.a1(x)
        x = self.blocks(x)
        x = self.p1(x)
        y = self.f1(x)
        return y



# ResNet-18: four stages with two ResnetBlock units each.
model = ResNet18([2, 2, 2, 2])

# The classifier ends in softmax, so the loss receives probabilities
# (from_logits=False); labels are integer class ids, hence the sparse loss/metric.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])


# Checkpoint named after this model. The previous "Inception10.ckpt" name
# made this ResNet script restore/overwrite a checkpoint of a different
# architecture.
checkpoint_save_path = "./checkpoint/ResNet18.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model---------------')
    model.load_weights(checkpoint_save_path)

# Save weights only, and only when the monitored quantity improves
# (ModelCheckpoint monitors "val_loss" by default).
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=32, epochs=5,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
# Called after fit(): a subclassed model has no variables until it has been
# called on data, so summary() would fail before that.
model.summary()



# Dump every trainable variable (name, shape, values) to a text file.
# Values are written in full because of np.set_printoptions(threshold=np.inf).
with open('./weights.txt', 'w', encoding='utf-8') as f:
    for v in model.trainable_variables:
        f.write(str(v.name) + '\n')
        f.write(str(v.shape) + '\n')
        f.write(str(v.numpy()) + '\n')


def plot_acc_loss_curve(history, metric='sparse_categorical_accuracy'):
    """Show training/validation accuracy and loss curves side by side.

    Args:
        history: the ``History`` object returned by ``model.fit``. Its
            ``history`` dict must contain ``metric``, ``'val_' + metric``,
            ``'loss'`` and ``'val_loss'``, so fit must have been run with
            validation data.
        metric: name of the accuracy metric passed to ``model.compile``.
    """
    curves = history.history
    acc = curves[metric]
    val_acc = curves['val_' + metric]
    loss = curves['loss']
    val_loss = curves['val_loss']

    plt.figure(figsize=(15, 5))

    # Left panel: accuracy per epoch.
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    # Right panel: loss per epoch.
    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()

# Visualise the accuracy/loss curves recorded by model.fit above.
plot_acc_loss_curve(history=history)

# 相关推荐 ("related recommendations" — leftover text from the source web page)