如何使用tensorflow实现VGG网络,训练mnist数据集-创新互联
小编这次要给大家分享的是如何使用tensorflow实现VGG网络,训练mnist数据集,文章内容丰富,感兴趣的小伙伴可以来了解一下,希望大家阅读完这篇文章之后能够有所收获。
成都创新互联公司专注于企业网络营销推广、网站重做改版、灞桥网站定制设计、自适应品牌网站建设、HTML5建站、购物商城网站建设、集团公司官网建设、成都外贸网站建设公司、高端网站制作、响应式网页设计等建站业务,价格优惠性价比高,为灞桥等各大城市提供网站开发制作服务。VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。
先介绍下VGG
ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。
他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。
VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。
模型结构:
本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:
weights ={ 'wc1':tf.Variable(tf.random_normal([3,3,1,64])), 'wc2':tf.Variable(tf.random_normal([3,3,64,64])), 'wc3':tf.Variable(tf.random_normal([3,3,64,128])), 'wc4':tf.Variable(tf.random_normal([3,3,128,128])), 'wc5':tf.Variable(tf.random_normal([3,3,128,256])), 'wc6':tf.Variable(tf.random_normal([3,3,256,256])), 'wc7':tf.Variable(tf.random_normal([3,3,256,256])), 'wc8':tf.Variable(tf.random_normal([3,3,256,256])), 'wc9':tf.Variable(tf.random_normal([3,3,256,512])), 'wc10':tf.Variable(tf.random_normal([3,3,512,512])), 'wc11':tf.Variable(tf.random_normal([3,3,512,512])), 'wc12':tf.Variable(tf.random_normal([3,3,512,512])), 'wc13':tf.Variable(tf.random_normal([3,3,512,512])), 'wc14':tf.Variable(tf.random_normal([3,3,512,512])), 'wc15':tf.Variable(tf.random_normal([3,3,512,512])), 'wc16':tf.Variable(tf.random_normal([3,3,512,256])), 'wd1':tf.Variable(tf.random_normal([4096,4096])), 'wd2':tf.Variable(tf.random_normal([4096,4096])), 'out':tf.Variable(tf.random_normal([4096,nn_classes])), } biases ={ 'bc1':tf.Variable(tf.zeros([64])), 'bc2':tf.Variable(tf.zeros([64])), 'bc3':tf.Variable(tf.zeros([128])), 'bc4':tf.Variable(tf.zeros([128])), 'bc5':tf.Variable(tf.zeros([256])), 'bc6':tf.Variable(tf.zeros([256])), 'bc7':tf.Variable(tf.zeros([256])), 'bc8':tf.Variable(tf.zeros([256])), 'bc9':tf.Variable(tf.zeros([512])), 'bc10':tf.Variable(tf.zeros([512])), 'bc11':tf.Variable(tf.zeros([512])), 'bc12':tf.Variable(tf.zeros([512])), 'bc13':tf.Variable(tf.zeros([512])), 'bc14':tf.Variable(tf.zeros([512])), 'bc15':tf.Variable(tf.zeros([512])), 'bc16':tf.Variable(tf.zeros([256])), 'bd1':tf.Variable(tf.zeros([4096])), 'bd2':tf.Variable(tf.zeros([4096])), 'out':tf.Variable(tf.zeros([nn_classes])), }
另外有需要云服务器可以了解下创新互联scvps.cn,海内外云服务器15元起步,三天无理由+7*72小时售后在线,公司持有idc许可证,提供“云服务器、裸金属服务器、高防服务器、香港服务器、美国服务器、虚拟主机、免备案服务器”等云主机租用服务以及企业上云的综合解决方案,具有“安全稳定、简单易用、服务可用性高、性价比高”等特点与优势,专为企业上云打造定制,能够满足用户丰富、多元化的应用场景需求。
当前文章:如何使用tensorflow实现VGG网络,训练mnist数据集-创新互联
本文地址:http://pwwzsj.com/article/ccdios.html