TensorFlow:tensorflow之CIFAR10与VGG13实战

   日期:2020-11-03     浏览:92    评论:0    
核心提示:文章目录CIFAR10与VGG13实战1. 准备数据2.构建网络模型3.训练模型CIFAR10与VGG13实战1. 准备数据CIFAR10 数据集由加拿大 Canadian Institute For Advanced Research 发布,它包含了飞机、汽车、鸟、猫等共 10 大类物体的彩色图片,每个种类收集了 6000 张 32x32 大小图片,共 60K 张图片。其中 50K 作为训练数据集,10K 作为测试数据集。import tensorflow as tffrom tensorflo

文章目录

  • CIFAR10与VGG13实战
    • 1. 准备数据
    • 2.构建网络模型
    • 3.训练模型

CIFAR10与VGG13实战

1. 准备数据

CIFAR10 数据集由加拿大 Canadian Institute For Advanced Research 发布,它包含了飞机、汽车、鸟、猫等共 10 大类物体的彩色图片,每个种类收集了 6000 张 32x32 大小图片,共 60K 张图片。其中 50K 作为训练数据集,10K 作为测试数据集。

import tensorflow as tf
from tensorflow.keras import datasets,layers,losses,optimizers,Sequential
import  os

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

使用datasets.cifar10.load_data()加载数据集,并对数据集进行预处理

def preprocess(x, y):
    # [0~1]
    x = 2*tf.cast(x, dtype=tf.float32) / 255.-1
    y = tf.cast(y, dtype=tf.int32)
    return x,y


(x,y), (x_test, y_test) = datasets.cifar10.load_data()
y = tf.squeeze(y, axis=1)
y_test = tf.squeeze(y_test, axis=1)
print(x.shape, y.shape, x_test.shape, y_test.shape)
(50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,)
train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.shuffle(1000).map(preprocess).batch(128)

test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = test_db.map(preprocess).batch(64)

sample = next(iter(train_db))
print('sample:', sample[0].shape, sample[1].shape,
      tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))
sample: (128, 32, 32, 3) (128,) tf.Tensor(-1.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32)

2.构建网络模型

CIFAR10 图片识别任务并不简单,这主要是由于 CIFAR10 的图片内容需要大量细节才能呈现,而保存的图片分辨率仅有 32x32,使得部分主体信息较为模糊,甚至人眼都很难分辨。浅层的神经网络表达能力有限,很难训练优化到较好的性能,本节将基于表达能力更强的 VGG13 网络,根据我们的数据集特点修改部分网络结构,完成 CIFAR10 图片识别。

我们将网络实现为 2 个子网络:卷积子网络和全连接子网络。卷积子网络由 5 个子模块构成,每个子模块包含了 Conv-Conv-MaxPooling 单元结构:

conv_layers = [ # 先创建包含多层的列表
    # unit 1
    # 64 个 3x3 卷积核, 输入输出同大小
    layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(64, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),

    # unit 2
    #输出通道提升至 128,高宽大小减半
    layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(128, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),

    # unit 3
    #,输出通道提升至 256,高宽大小减半
    layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(256, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),

    # unit 4
    #输出通道提升至 512,高宽大小减半
    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same'),

    # unit 5
    #输出通道提升至 512,高宽大小减半
    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.Conv2D(512, kernel_size=[3, 3], padding="same", activation=tf.nn.relu),
    layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same')
]
# 利用前面创建的层列表构建网络容器
conv_net = Sequential(conv_layers)

全连接子网络包含了 3 个全连接层,每层添加 ReLU 非线性激活函数,最后一层除外

# 创建 3 层全连接层子网络
fc_net = Sequential([
    layers.Dense(256, activation=tf.nn.relu),
    layers.Dense(128, activation=tf.nn.relu),
    layers.Dense(10, activation=None),
])
# build2 个子网络,并打印网络参数信息
conv_net.build(input_shape=[4, 32, 32, 3])
fc_net.build(input_shape=[4, 512])
conv_net.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (4, 32, 32, 64)           1792      
_________________________________________________________________
conv2d_1 (Conv2D)            (4, 32, 32, 64)           36928     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (4, 16, 16, 64)           0         
_________________________________________________________________
conv2d_2 (Conv2D)            (4, 16, 16, 128)          73856     
_________________________________________________________________
conv2d_3 (Conv2D)            (4, 16, 16, 128)          147584    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (4, 8, 8, 128)            0         
_________________________________________________________________
conv2d_4 (Conv2D)            (4, 8, 8, 256)            295168    
_________________________________________________________________
conv2d_5 (Conv2D)            (4, 8, 8, 256)            590080    
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (4, 4, 4, 256)            0         
_________________________________________________________________
conv2d_6 (Conv2D)            (4, 4, 4, 512)            1180160   
_________________________________________________________________
conv2d_7 (Conv2D)            (4, 4, 4, 512)            2359808   
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (4, 2, 2, 512)            0         
_________________________________________________________________
conv2d_8 (Conv2D)            (4, 2, 2, 512)            2359808   
_________________________________________________________________
conv2d_9 (Conv2D)            (4, 2, 2, 512)            2359808   
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (4, 1, 1, 512)            0         
=================================================================
Total params: 9,404,992
Trainable params: 9,404,992
Non-trainable params: 0
_________________________________________________________________

3.训练模型

下面对数据进行训练,并对测试数据集进行测试

def main():

    # [b, 32, 32, 3] => [b, 1, 1, 512]
    conv_net = Sequential(conv_layers)

    fc_net = Sequential([
        layers.Dense(256, activation=tf.nn.relu),
        layers.Dense(128, activation=tf.nn.relu),
        layers.Dense(10, activation=None),
    ])

    conv_net.build(input_shape=[None, 32, 32, 3])
    fc_net.build(input_shape=[None, 512])
# conv_net.summary()
# fc_net.summary()
    optimizer = optimizers.Adam(lr=1e-4)

    # [1, 2] + [3, 4] => [1, 2, 3, 4]
    variables = conv_net.trainable_variables + fc_net.trainable_variables

    for epoch in range(50):

        for step, (x,y) in enumerate(train_db):

            with tf.GradientTape() as tape:
                # [b, 32, 32, 3] => [b, 1, 1, 512]
                out = conv_net(x)
                # flatten, => [b, 512]
                out = tf.reshape(out, [-1, 512])
                # [b, 512] => [b, 10]
                logits = fc_net(out)
                # [b] => [b, 10]
                y_onehot = tf.one_hot(y, depth=10)
                # compute loss
                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)

            grads = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(grads, variables))

            if step %100 == 0:
                print(epoch, step, 'loss:', float(loss))



        total_num = 0
        total_correct = 0
        for x,y in test_db:

            out = conv_net(x)
            out = tf.reshape(out, [-1, 512])
            logits = fc_net(out)
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)

            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
            correct = tf.reduce_sum(correct)

            total_num += x.shape[0]
            total_correct += int(correct)

        acc = total_correct / total_num
        print(epoch, 'acc:', acc)



if __name__ == '__main__':
    main()
 
打赏
 本文转载自:网络 
所有权利归属于原作者,如文章来源标示错误或侵犯了您的权利请联系微信13520258486
更多>最近资讯中心
更多>最新资讯中心
0相关评论

推荐图文
推荐资讯中心
点击排行
最新信息
新手指南
采购商服务
供应商服务
交易安全
关注我们
手机网站:
新浪微博:
微信关注:

13520258486

周一至周五 9:00-18:00
(其他时间联系在线客服)

24小时在线客服