Tensorflow学习实战之VGG16猫狗大战迁移训练.docx

上传人:安*** 文档编号:73281941 上传时间:2023-02-17 格式:DOCX 页数:7 大小:18.24KB
返回 下载 相关 举报
Tensorflow学习实战之VGG16猫狗大战迁移训练.docx_第1页
第1页 / 共7页
Tensorflow学习实战之VGG16猫狗大战迁移训练.docx_第2页
第2页 / 共7页
点击查看更多>>
资源描述

《Tensorflow学习实战之VGG16猫狗大战迁移训练.docx》由会员分享,可在线阅读,更多相关《Tensorflow学习实战之VGG16猫狗大战迁移训练.docx(7页珍藏版)》请在taowenge.com淘文阁网|工程机械CAD图纸|机械工程制图|CAD装配图下载|SolidWorks_CaTia_CAD_UG_PROE_设计图分享下载上搜索。

1、Tensorflow学习实战之VGG16猫狗大战迁移训练Tensorflow学习实战之VGG16猫狗大战迁移训练继续跑深度学习Tensorflow的实战今天的是VGG16的猫狗大战整体构造框架由13层卷积以及3层全连接层组成对于对pthon不熟悉的我来讲整个搭建调试经过在于数据的写入处理维度的变换和不同函数的灵敏使用。整个经过在两个地方卡住第一个是训练时中途进程中断修改了不少参数均没有效果最终找到的问题是数据集出现了问题最开场用的是微软官网下载的一个猫狗数据集后来换数据集之后没有问题。第二个是最后测试的时候出现的问题载入图片后总提示维度有问题。所谓迁移训练呢就是讲在别人训练好的权重根底上呢进展

2、自己的训练进而获得适应自己所需任务的模型结果如下整体的框架是这样的model.py是用来装载搭建的卷积层框架test.py用来最终结果的预测猫狗大战用来跑训练首先来看看16层卷积的搭建吧定义卷积、池化、全连接层然后是搭建一个13层的卷积以及一个三层的全连接层微调训练经过中因为最终结果只需要预测猫狗class为2所以只修改了最终一层全连接层trainableTureimporttensorflowastfimportnumpyasnpimportos#模型定义classvgg16:def_init_(self,imgs):#参加全局列表把所需参数加载进类里self.parameters#初始化s

3、elf.imgsimgsself.convlayers()self.fc_layers()#输出类别的概率self.probstf.nn.softmax(self.fc8)defsaver(self):returntf.train.Saver()#定义卷积层defconv(self,name,input_data,out_channel,trainableFalse):#获取通道数in_channelinput_data.get_shape()-1withtf.variable_scope(name):#初始化kerneltf.get_variable(weights,3,3,in_chann

4、el,out_channel,dtypetf.float32,trainableFalse)biasestf.get_variable(baises,out_channel,dtypetf.float32,trainableFalse)conv_restf.nn.conv2d(input_data,kernel,1,1,1,1,paddingSAME)restf.nn.bias_add(conv_res,biases)outtf.nn.relu(res,namename)self.parameterskernel,biasesreturnout#定义全连接层deffc(self,name,in

5、put_data,out_channel,trainableTrue):#获取维度shapeinput_data.get_shape().as_list()iflen(shape)4:sizeshape-1*shape-2*shape-3else:sizeshape1#数据展开input_data_flattf.reshape(input_data,-1,size)withtf.variable_scope(name):#初始化weightstf.get_variable(weights,shapesize,out_channel,dtypetf.float32,trainabletraina

6、ble)biasestf.get_variable(baises,shapeout_channel,dtypetf.float32,trainabletrainable)restf.matmul(input_data_flat,weights)outtf.nn.relu(tf.nn.bias_add(res,biases)self.parametersweights,biasesreturnout#定义池化defmaxpool(self,name,input_data):outtf.nn.max_pool(input_data,1,2,2,1,1,2,2,1,paddingSAME,namen

7、ame)returnout#卷积堆叠defconvlayers(self):#cov1self.conv1_1self.conv(conv1_1,self.imgs,64,trainableFalse)self.conv1_2self.conv(conv1_2,self.conv1_1,64,trainableFalse)self.pool1self.maxpool(poolre1,self.conv1_2)#cov2self.conv2_1self.conv(conv2_1,self.pool1,128,trainableFalse)self.conv2_2self.conv(conv2_2

8、,self.conv2_1,128,trainableFalse)self.pool2self.maxpool(pool1,self.conv2_2)#cov3self.conv3_1self.conv(conv3_1,self.pool2,256,trainableFalse)self.conv3_2self.conv(conv3_2,self.conv3_1,256,trainableFalse)self.conv3_3self.conv(conv3_3,self.conv3_2,256,trainableFalse)self.pool3self.maxpool(pool3,self.co

9、nv3_3)#cov4self.conv4_1self.conv(conv4_1,self.pool3,512,trainableFalse)self.conv4_2self.conv(conv4_2,self.conv4_1,512,trainableFalse)self.conv4_3self.conv(conv4_3,self.conv4_2,512,trainableFalse)self.pool4self.maxpool(pool4,self.conv4_3)#cov5self.conv5_1self.conv(conv5_1,self.pool4,512,trainableFals

10、e)self.conv5_2self.conv(conv5_2,self.conv5_1,512,trainableFalse)self.conv5_3self.conv(conv5_3,self.conv5_2,512,trainableFalse)self.pool5self.maxpool(pool5,self.conv5_3)#全连接层deffc_layers(self):self.fc6self.fc(fc1,self.pool5,4096,trainableFalse)self.fc7self.fc(fc2,self.fc6,4096,trainableFalse)#要进展微调tr

11、ainable为trueself.fc8self.fc(fc3,self.fc7,2,trainableTrue)#载入权重defload_weights(self,weight_file,sess):weightsnp.load(weight_file)keyssorted(weights.keys()fori,kinenumerate(keys):ifinotin30,31:sess.run(self.parametersi.assign(weightsk)print(weightsloading.)接下来是训练经过分成了数据集的读取预处理通过不同文件夹进展分类独热编码转换和参数设置及进展

12、迭代训练#VGG16importtensorflowastfimportnumpyasnpimportosfromtimeimporttimeimportutilsimportmodel#数据输入defget_file(file_dir):imagestemp#对不同文件夹分类forroot,sub_folders,filesinos.walk(file_dir):fornameinfiles:images.append(os.path.join(root,name)fornameinsub_folders:temp.append(os.path.join(root,name)labelsfo

13、rone_folderintemp:n_imglen(os.listdir(one_folder)letterone_folder.split(/)-1ifletterCat:labelsnp.append(labels,n_img*0)else:labelsnp.append(labels,n_img*1)tempnp.array(images,labels)temptemp.transpose()np.random.shuffle(temp)image_listlist(temp:,0)label_listlist(temp:,1)label_listint(float(i)foriinl

14、abel_listreturnimage_list,label_listdefget_batch(image_list,label_list,img_width,img_height,batch_size,capacity):imagetf.cast(image_list,tf.string)labeltf.cast(label_list,tf.int32)input_queuetf.train.slice_input_producer(image,label)labelinput_queue1image_contentstf.read_file(input_queue0)imagetf.im

15、age.decode_jpeg(image_contents,channels3)imagetf.image.resize_images(image,224,224,methodtf.image.ResizeMethod.NEAREST_NEIGHBOR)image_batch,label_batchtf.train.batch(image,label,batch_sizebatch_size,num_threads64,capacitycapacity)label_batchtf.reshape(label_batch,batch_size)returnimage_batch,label_b

16、atch#转换独热编码defonehot(labels):n_samplelen(labels)n_classmax(labels)1onehot_labelsnp.zeros(n_sample,n_class)onehot_labelsnp.arange(n_sample),labels1returnonehot_labelsbatch_size10capacity256#存储容量#VGG预训练是减掉的均值means123.68,116.779,103.939img_width224img_height224start_timetime()xs,ysget_file(D:/demo/tens

17、orflow_猫狗大战/tensorflow_猫狗大战/data/)image_batch,label_batchget_batch(xs,ys,img_width,img_height,batch_size,capacity)print(len(xs),len(ys)xtf.placeholder(tf.float32,None,224,224,3)ytf.placeholder(tf.int32,None,2)vggmodel.vgg16(x)fc8_finetuiningvgg.probsloss_functiontf.reduce_mean(tf.nn.softmax_cross_en

18、tropy_with_logits(logitsfc8_finetuining,labelsy)optimizertf.train.GradientDescentOptimizer(learning_rate0.001).minimize(loss_function)sesstf.Session()sess.run(tf.global_variables_initializer()vgg.load_weights(vgg16_weights.npz,sess)savertf.train.Saver()coordtf.train.Coordinator()threadstf.train.star

19、t_queue_runners(coordcoord,sesssess)epoch_start_timetime()foriinrange(100):print(开场训练.)images,labelssess.run(image_batch,label_batch)labelsonehot(labels)sess.run(optimizer,feed_dictx:images,y:labels)losssess.run(loss_function,feed_dictx:images,y:labels)print(loss:%f%loss)epoch_end_timetime()print(ti

20、me:,(epoch_end_time-epoch_start_time)epoch_start_timeepoch_end_timeif(i1)%200:saver.save(sess,os.path.join(./model/,epoch:06d.ckpt.format(i1)print(_epoch%dfinish%(i1)saver.save(sess,./model/)print(finish,alltime:,time()-start_time()coord.request_stop()coord.join(threads)最后是检测局部的函数。importtensorflowas

21、tfimportnumpyasnpfromimageioimportimreadimportmodelasmodelfromPILimportImageimportmatplotlib.pyplotasplttf.reset_default_graph()means123.68,116.779,103.939xtf.placeholder(tf.float32,None,224,224,3)withtf.Session()assess:vggmodel.vgg16(x)fc8_finetuiningvgg.probssavertf.train.Saver()print(restoring.)s

22、aver.restore(sess,./model/)filepath./1116.jpgimage_raw_datatf.gfile.FastGFile(filepath,rb).read()img_datatf.image.decode_jpeg(image_raw_data)plt.imshow(img_data.eval()imagetf.image.resize_images(img_data,224,224,methodtf.image.ResizeMethod.NEAREST_NEIGHBOR)imgimage.eval()print(img.shape)imagetf.expand_dims(image,0)print(image.shape)#forcinrange(3):#image:,:,c-meanscprobsess.run(fc8_finetuining,feed_dictx:img)max_indexnp.argmax(prob)ifmax_index0:plt.title(Theresult:cat%.6f%prob:,0)plt.show()else:plt.title(Theresult:dog%.6f%prob:,1)plt.show()全程手敲几近崩溃趁年度轻读点书不想做小工了算了不讲了我今晚还有两车砖要搬爆炒小肥牛

展开阅读全文
相关资源
相关搜索

当前位置:首页 > 技术资料 > 工程图纸

本站为文档C TO C交易模式,本站只提供存储空间、用户上传的文档直接被用户下载,本站只是中间服务平台,本站所有文档下载所得的收益归上传人(含作者)所有。本站仅对用户上传内容的表现方式做保护处理,对上载内容本身不做任何修改或编辑。若文档所含内容侵犯了您的版权或隐私,请立即通知淘文阁网,我们立即给予删除!客服QQ:136780468 微信:18945177775 电话:18904686070

工信部备案号:黑ICP备15003705号© 2020-2023 www.taowenge.com 淘文阁