pytorch 数据工程(待更新)

目录

1、爬取图像

2、转换为png图像

3、划分训练集、验证集和测试集

4、数据增强


1、爬取图像

# -*- coding: utf-8 -*-

import requests
import time
import os
import sys
import importlib
import json
importlib.reload(sys)


def getManyPages(keyword,pages):
    params=[]
    for i in range(30,30*pages+30,30):
        params.append({
                      'tn': 'resultjson_com',
                      'ipn': 'rj',
                      'ct': 201326592,
                      'is': '',
                      'fp': 'result',
                      'queryWord': keyword,
                      'cl': 2,
                      'lm': -1,
                      'ie': 'utf-8',
                      'oe': 'utf-8',
                      'adpicid': '',
                      'st': -1,
                      'z': '',
                      'ic': 0,
                      'word': keyword,
                      's': '',
                      'se': '',
                      'tab': '',
                      'width': '',
                      'height': '',
                      'face': 0,
                      'istype': 2,
                      'qc': '',
                      'nc': 1,
                      'fr': '',
                      'pn': i,
                      'rn': 30,
                      'gsm': '1e',
                      '1488942260214': ''
                  })
    url = 'https://image.baidu.com/search/acjson'
    urls = []
    for i in params:
         try:
            urls.append(requests.get(url, params=i).json().get('data'))
         except json.decoder.JSONDecodeError:
            print("解析出错")
    return urls


def getImg(dataList, localPath):
    if not os.path.exists(localPath):
        os.mkdir(localPath)

    x = 1
    for list in dataList:
        for i in list:
            if i.get('thumbURL') != None:
                print('正在下载:%s' % i.get('thumbURL'))
                ir = requests.get(i.get('thumbURL'))
                string = localPath + '/' + '%04d.jpg' % x
                with open(string, 'wb') as f:
                    f.write(ir.content)
                x += 1
                time.sleep(0.5)
            else:
                print('图片链接不存在')

if __name__ == '__main__':
    keyword = input("输入关键词:")
    dataList = getManyPages(keyword, 100)
    getImg(dataList,'E:\\darknet-master\\Complied-darknet-master\\darknet-master\\data\\Mydata\\'+keyword)
        

参考资料:http://www.mamicode.com/info-detail-2185644.html

 

2、转换为png图像

把 cifar-10 的测试集转换成了 png 图片,充当实验的原始数据。

# -*- coding: utf-8 -*-
"""
Created on Thu Jun 27 11:09:32 2019

@author: xiaoxiaoke
"""
import cv2
import numpy as np
import os

savePath = '.\\data_batch1' 
srcPath='data_batch_1'

filepath = savePath.strip()    #去掉.\符号
isExists=os.path.exists(filepath)
if not isExists:
        os.makedirs(filepath) 
        print(filepath+'创建成功')
else:
        print(filepath+'目录已存在')
       
def unpickle(srcPath):
    import pickle
    with open(srcPath, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
dict1 = unpickle(srcPath)

for i in range(dict1[b'data'].shape[0]):
    img = dict1[b'data'][i]
    img = np.reshape(img, (3, 32,32))      #转为三维图片数组
    img = img.transpose((1,2,0))           #通道转换为CV2的要求形式,CV是BGR
    img_name = str(dict1[b'filenames'][i]) #图片的名字
    img_label = str(dict1[b'labels'][i])   #图片的标签
    cv2.imwrite(".\\cifarData\\"+img_label+"\\"+img_label+"_"+img_name[2:len(img_name)-1],img)#保存
    print(".\\cifarData\\"+img_label+"\\"+img_label+"_"+img_name[2:len(img_name)-1])
    

3、划分训练集、验证集和测试集

      将上述数据分为实验所需的训练集、验证集和测试集。

# -*- coding: utf-8 -*-
"""
Created on Thu Jul 25 14:41:01 2019

@author: xiaoxiaoke 
"""
import random
import numpy as np
import glob
import shutil #提供了多个针对文件或文件集合的高等级操作

datapath="F:\DeepLearning\pytorch\cifar10-master\cifar-10-batches-py\data_batch1\\"

imgs_list = glob.glob(datapath+'/*.png')
imageSum=len(imgs_list)
items = np.arange(imageSum)
numImage=np.random.shuffle(imgs_list)

rateTrain=int(0.7*imageSum)
rateVaild=int(0.85*imageSum)

#创建文件夹
trainPath=datapath+'trainPath'
testPath=datapath+'testPath'
vaildPath=datapath+'vaildPath'

if not os.path.exists(trainPath):
    os.makedirs(trainPath) 
if not os.path.exists(testPath):
    os.makedirs(testPath) 
if not os.path.exists(vaildPath):
    os.makedirs(vaildPath) 

for i in range(imageSum):
    if i<rateTrain:
        shutil.copy(imgs_list[i], trainPath)
    elif i<rateVaild:
        shutil.copy(imgs_list[i], testPath)
    else:
        shutil.copy(imgs_list[i], vaildPath)
        

4、数据增强

 常见的数据增强技术:

1、标准化(减去均均值,除以标准差,均值为0,标准差为1)

2、归一化(除以255,像素值归一化至[0 1])

3、裁剪(中心裁剪、随机裁剪、随机长宽比裁剪、上下左右中心裁剪、填充、resize)

4、旋转(随机旋转)

5、镜像翻转(依照概率P水平翻转、垂直翻转)

6、变换(修改亮度、饱和度和对比度、线性变换、转为灰度图、仿射变换)

上述操作的按照一定概率随机排列和关闭。

标准化和归一化的目标在于:去除不同维度上的量纲影响,有利于模型的快速收敛。 

计算某一个文件夹下的图像的均值:

# -*- coding: utf-8 -*-
"""
Created on Mon Jul  8 14:07:52 2019
计算图像的均值,应当计算图像在各类变换以后的均值
@author: xiaoxiaoke 
"""
import cv2 as cv
import numpy as np
import os
import sys
import matplotlib .pyplot as plt

def computeMeanDirImage(trainPath,width,height):
    matSum=np.zeros([width,height,3])
    cv.namedWindow("image",0)
    cv.resizeWindow("image",400,400)
    for root, dirs, files in os.walk(trainPath):
        print(root,dirs)
       
        matSum=np.zeros((width,height,3))
        for imageNum in range(0,len(os.listdir(root))):     #range(0,100): 
            print(files[imageNum])
            matImage=cv.imread(trainPath+files[imageNum])
            if matImage.shape[2]!=3:
                print("The image is gray!");
                return 0
            matImageNp = np.array(matImage) 
            matSum=matSum+matImageNp;
            cv.imshow("image",matImage)
            cv.waitKey(100)
            
    avergeMatFloat=matSum/len(os.listdir(root));
    avergeMatInt=avergeMatFloat.astype(np.int16)
    sc = plt.imshow(avergeMatInt)
    sc.set_cmap('hot')# 这里可以设置多种模式

    return avergeMatFloat

 

       

 

 

 

 


更多精彩内容