xdq0 2018-12-24
pytorch是一个优秀的深度学习框架,它不仅更加灵活,支持动态图,而且它是一个以Python优先的深度学习框架,不仅能够实现强大的GPU加速,同时还支持动态神经网络,我们可以方便快速地使用它搭建深度学习网络。
深度学习框架
VGGNet是牛津大学计算机视觉组和Google DeepMind公司的研究员一起研发的深度卷积神经网络。vgg卷积神经网络是第一个真正意义上的深层网络结构,它是 ImageNet2014年的冠军。vgg 的网络结构非常简单,就是不断地堆叠卷积层和池化层,下面是一个简单的图示:
VGG的网络结构
VGG分为16层和19层两种情况,其中16层称为VGG16,而19层称为VGG19。vgg 几乎全部使用 3 x 3 的卷积核以及 2 x 2 的池化层,使用小的卷积核进行多层的堆叠和一个大的卷积核的感受野是相同的,并且小的卷积核还能减少参数,同时可以有更深的结构。vgg 的一个关键就是使用很多层 3 x 3 的卷积然后再使用一个最大池化层。
def vgg_block(num_convs, in_channels, out_channels): net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)] for i in range(num_convs-1): net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)) net.append(nn.ReLU(True)) net.append(nn.MaxPool2d(2, 2)) return nn.Sequential(*net) def vgg_stack(num_convs, channels): net = [] for n, c in zip(num_convs, channels): in_c = c[0] out_c = c[1] net.append(vgg_block(n, in_c, out_c)) return nn.Sequential(*net) class vgg(nn.Module): def __init__(self): super(vgg, self).__init__() self.feature = vgg_net self.fc = nn.Sequential( nn.Linear(512, 100), nn.ReLU(True), nn.Linear(100, 10) ) def forward(self, x): x = self.feature(x) x = x.view(x.shape[0], -1) x = self.fc(x) return x
我们使用一个经典的图片集cifar 10 ,这个数据集一共有 50000 张训练集,10000 张测试集,两个数据集里面的图片都是 png 彩色图片,图片大小是 32 x 32 x 3,一共是 10 分类问题,分别为飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。
图片集cifar 10
train_set = CIFAR10('./data', train=True, transform=data_tf) train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True) test_set = CIFAR10('./data', train=False, transform=data_tf) test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False) net = vgg() optimizer = torch.optim.SGD(net.parameters(), lr=1e-1) criterion = nn.CrossEntropyLoss() for epoch in range(num_epochs): net = net.train() for im, label in train_data: output = net(im) loss = criterion(output, label) optimizer.zero_grad() loss.backward() optimizer.step()
只要运行训练代码,我们就可以使用我们创建的VGG卷积神经网络完成图片分类的任务了。神经网络会自动加载该数据集,将该数据集传入到VGG神经网络中就可以开启训练了,这就是使用pytorch深度学习框架构建VGG卷积神经网络完成图片分类的任务。