《Tensorflow学习实战之VGG16猫狗大战迁移训练.docx》由会员分享,可在线阅读,更多相关《Tensorflow学习实战之VGG16猫狗大战迁移训练.docx(7页珍藏版)》请在taowenge.com淘文阁网|工程机械CAD图纸|机械工程制图|CAD装配图下载|SolidWorks_CaTia_CAD_UG_PROE_设计图分享下载上搜索。
1、Tensorflow学习实战之VGG16猫狗大战迁移训练Tensorflow学习实战之VGG16猫狗大战迁移训练继续跑深度学习Tensorflow的实战今天的是VGG16的猫狗大战整体构造框架由13层卷积以及3层全连接层组成对于对pthon不熟悉的我来讲整个搭建调试经过在于数据的写入处理维度的变换和不同函数的灵敏使用。整个经过在两个地方卡住第一个是训练时中途进程中断修改了不少参数均没有效果最终找到的问题是数据集出现了问题最开场用的是微软官网下载的一个猫狗数据集后来换数据集之后没有问题。第二个是最后测试的时候出现的问题载入图片后总提示维度有问题。所谓迁移训练呢就是讲在别人训练好的权重根底上呢进展
2、自己的训练进而获得适应自己所需任务的模型结果如下整体的框架是这样的model.py是用来装载搭建的卷积层框架test.py用来最终结果的预测猫狗大战用来跑训练首先来看看16层卷积的搭建吧定义卷积、池化、全连接层然后是搭建一个13层的卷积以及一个三层的全连接层微调训练经过中因为最终结果只需要预测猫狗class为2所以只修改了最终一层全连接层trainableTureimporttensorflowastfimportnumpyasnpimportos#模型定义classvgg16:def_init_(self,imgs):#参加全局列表把所需参数加载进类里self.parameters#初始化s
3、elf.imgsimgsself.convlayers()self.fc_layers()#输出类别的概率self.probstf.nn.softmax(self.fc8)defsaver(self):returntf.train.Saver()#定义卷积层defconv(self,name,input_data,out_channel,trainableFalse):#获取通道数in_channelinput_data.get_shape()-1withtf.variable_scope(name):#初始化kerneltf.get_variable(weights,3,3,in_chann
4、el,out_channel,dtypetf.float32,trainableFalse)biasestf.get_variable(baises,out_channel,dtypetf.float32,trainableFalse)conv_restf.nn.conv2d(input_data,kernel,1,1,1,1,paddingSAME)restf.nn.bias_add(conv_res,biases)outtf.nn.relu(res,namename)self.parameterskernel,biasesreturnout#定义全连接层deffc(self,name,in
5、put_data,out_channel,trainableTrue):#获取维度shapeinput_data.get_shape().as_list()iflen(shape)4:sizeshape-1*shape-2*shape-3else:sizeshape1#数据展开input_data_flattf.reshape(input_data,-1,size)withtf.variable_scope(name):#初始化weightstf.get_variable(weights,shapesize,out_channel,dtypetf.float32,trainabletraina
6、ble)biasestf.get_variable(baises,shapeout_channel,dtypetf.float32,trainabletrainable)restf.matmul(input_data_flat,weights)outtf.nn.relu(tf.nn.bias_add(res,biases)self.parametersweights,biasesreturnout#定义池化defmaxpool(self,name,input_data):outtf.nn.max_pool(input_data,1,2,2,1,1,2,2,1,paddingSAME,namen
7、ame)returnout#卷积堆叠defconvlayers(self):#cov1self.conv1_1self.conv(conv1_1,self.imgs,64,trainableFalse)self.conv1_2self.conv(conv1_2,self.conv1_1,64,trainableFalse)self.pool1self.maxpool(poolre1,self.conv1_2)#cov2self.conv2_1self.conv(conv2_1,self.pool1,128,trainableFalse)self.conv2_2self.conv(conv2_2
8、,self.conv2_1,128,trainableFalse)self.pool2self.maxpool(pool1,self.conv2_2)#cov3self.conv3_1self.conv(conv3_1,self.pool2,256,trainableFalse)self.conv3_2self.conv(conv3_2,self.conv3_1,256,trainableFalse)self.conv3_3self.conv(conv3_3,self.conv3_2,256,trainableFalse)self.pool3self.maxpool(pool3,self.co
9、nv3_3)#cov4self.conv4_1self.conv(conv4_1,self.pool3,512,trainableFalse)self.conv4_2self.conv(conv4_2,self.conv4_1,512,trainableFalse)self.conv4_3self.conv(conv4_3,self.conv4_2,512,trainableFalse)self.pool4self.maxpool(pool4,self.conv4_3)#cov5self.conv5_1self.conv(conv5_1,self.pool4,512,trainableFals
10、e)self.conv5_2self.conv(conv5_2,self.conv5_1,512,trainableFalse)self.conv5_3self.conv(conv5_3,self.conv5_2,512,trainableFalse)self.pool5self.maxpool(pool5,self.conv5_3)#全连接层deffc_layers(self):self.fc6self.fc(fc1,self.pool5,4096,trainableFalse)self.fc7self.fc(fc2,self.fc6,4096,trainableFalse)#要进展微调tr
11、ainable为trueself.fc8self.fc(fc3,self.fc7,2,trainableTrue)#载入权重defload_weights(self,weight_file,sess):weightsnp.load(weight_file)keyssorted(weights.keys()fori,kinenumerate(keys):ifinotin30,31:sess.run(self.parametersi.assign(weightsk)print(weightsloading.)接下来是训练经过分成了数据集的读取预处理通过不同文件夹进展分类独热编码转换和参数设置及进展
12、迭代训练#VGG16importtensorflowastfimportnumpyasnpimportosfromtimeimporttimeimportutilsimportmodel#数据输入defget_file(file_dir):imagestemp#对不同文件夹分类forroot,sub_folders,filesinos.walk(file_dir):fornameinfiles:images.append(os.path.join(root,name)fornameinsub_folders:temp.append(os.path.join(root,name)labelsfo
13、rone_folderintemp:n_imglen(os.listdir(one_folder)letterone_folder.split(/)-1ifletterCat:labelsnp.append(labels,n_img*0)else:labelsnp.append(labels,n_img*1)tempnp.array(images,labels)temptemp.transpose()np.random.shuffle(temp)image_listlist(temp:,0)label_listlist(temp:,1)label_listint(float(i)foriinl
14、abel_listreturnimage_list,label_listdefget_batch(image_list,label_list,img_width,img_height,batch_size,capacity):imagetf.cast(image_list,tf.string)labeltf.cast(label_list,tf.int32)input_queuetf.train.slice_input_producer(image,label)labelinput_queue1image_contentstf.read_file(input_queue0)imagetf.im
15、age.decode_jpeg(image_contents,channels3)imagetf.image.resize_images(image,224,224,methodtf.image.ResizeMethod.NEAREST_NEIGHBOR)image_batch,label_batchtf.train.batch(image,label,batch_sizebatch_size,num_threads64,capacitycapacity)label_batchtf.reshape(label_batch,batch_size)returnimage_batch,label_b
16、atch#转换独热编码defonehot(labels):n_samplelen(labels)n_classmax(labels)1onehot_labelsnp.zeros(n_sample,n_class)onehot_labelsnp.arange(n_sample),labels1returnonehot_labelsbatch_size10capacity256#存储容量#VGG预训练是减掉的均值means123.68,116.779,103.939img_width224img_height224start_timetime()xs,ysget_file(D:/demo/tens
17、orflow_猫狗大战/tensorflow_猫狗大战/data/)image_batch,label_batchget_batch(xs,ys,img_width,img_height,batch_size,capacity)print(len(xs),len(ys)xtf.placeholder(tf.float32,None,224,224,3)ytf.placeholder(tf.int32,None,2)vggmodel.vgg16(x)fc8_finetuiningvgg.probsloss_functiontf.reduce_mean(tf.nn.softmax_cross_en
18、tropy_with_logits(logitsfc8_finetuining,labelsy)optimizertf.train.GradientDescentOptimizer(learning_rate0.001).minimize(loss_function)sesstf.Session()sess.run(tf.global_variables_initializer()vgg.load_weights(vgg16_weights.npz,sess)savertf.train.Saver()coordtf.train.Coordinator()threadstf.train.star
19、t_queue_runners(coordcoord,sesssess)epoch_start_timetime()foriinrange(100):print(开场训练.)images,labelssess.run(image_batch,label_batch)labelsonehot(labels)sess.run(optimizer,feed_dictx:images,y:labels)losssess.run(loss_function,feed_dictx:images,y:labels)print(loss:%f%loss)epoch_end_timetime()print(ti
20、me:,(epoch_end_time-epoch_start_time)epoch_start_timeepoch_end_timeif(i1)%200:saver.save(sess,os.path.join(./model/,epoch:06d.ckpt.format(i1)print(_epoch%dfinish%(i1)saver.save(sess,./model/)print(finish,alltime:,time()-start_time()coord.request_stop()coord.join(threads)最后是检测局部的函数。importtensorflowas
21、tfimportnumpyasnpfromimageioimportimreadimportmodelasmodelfromPILimportImageimportmatplotlib.pyplotasplttf.reset_default_graph()means123.68,116.779,103.939xtf.placeholder(tf.float32,None,224,224,3)withtf.Session()assess:vggmodel.vgg16(x)fc8_finetuiningvgg.probssavertf.train.Saver()print(restoring.)s
22、aver.restore(sess,./model/)filepath./1116.jpgimage_raw_datatf.gfile.FastGFile(filepath,rb).read()img_datatf.image.decode_jpeg(image_raw_data)plt.imshow(img_data.eval()imagetf.image.resize_images(img_data,224,224,methodtf.image.ResizeMethod.NEAREST_NEIGHBOR)imgimage.eval()print(img.shape)imagetf.expand_dims(image,0)print(image.shape)#forcinrange(3):#image:,:,c-meanscprobsess.run(fc8_finetuining,feed_dictx:img)max_indexnp.argmax(prob)ifmax_index0:plt.title(Theresult:cat%.6f%prob:,0)plt.show()else:plt.title(Theresult:dog%.6f%prob:,1)plt.show()全程手敲几近崩溃趁年度轻读点书不想做小工了算了不讲了我今晚还有两车砖要搬爆炒小肥牛