在Caffe中训练时,如果要使用多GPU,直接在运行程序时指定GPU的index即可;但是在PyTorch中,则需要在声明模型之后,用DataParallel对声明的模型进行包装(初始化),如:
cnn = DataParallel(AlexNet())
之后直接运行PyTorch程序,默认会使用所有的GPU。为了说明上述初始化的作用,我用了一组畸变图像的数据集,写了一个ResNet的模块,跑了50个epoch,对比一下实验耗时的差别,代码如下:
# -*- coding: utf-8 -*-
# Implementation of https://arxiv.org/pdf/1512.03385.pdf
# See section 4.2 for model architecture on CIFAR-10.
# Some part of the code was referenced below.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
import os
from PIL import Image
import time
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.utils.data as data
from torch.nn import DataParallel
# Shared keyword options; the names match torch.utils.data.DataLoader
# parameters, so presumably they are forwarded there - TODO confirm, the
# call site is not visible in this part of the file.
kwargs = dict(num_workers=1, pin_memory=True)
# def my dataloader, return the data and corresponding label
def default_loader(path):
    """Load the image stored at ``path`` and return it converted to RGB.

    Args:
        path: location of the image file on disk.

    Returns:
        The image returned by ``Image.open(...).convert('RGB')``.
    """
    # Image.open() is lazy and keeps the underlying file open. Opening the
    # file ourselves guarantees the handle is released as soon as convert()
    # has produced the in-memory copy, instead of whenever the garbage
    # collector gets to it (matters when a dataset loads thousands of files).
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
class myImageFloder(data.Dataset):
    """Dataset of images described by a whitespace-separated label file.

    Label file layout, as parsed by ``__init__``:

    * first line: the space-separated names of the label columns;
    * every other line: ``<file name> <v1> <v2> ...`` where the file name is
      relative to ``root`` and each ``v`` is parsed with ``float``.

    Entries whose image file does not exist under ``root`` are dropped.

    Args:
        root: directory the file names in the label file are relative to.
        label: path of the label file.
        transform: optional callable applied to each loaded image.
        target_transform: optional callable applied to each label tensor.
        loader: callable that turns an image path into a loaded image.

    Raises:
        ValueError: if a label value cannot be parsed as a float.
    """

    def __init__(self, root, label, transform=None, target_transform=None, loader=default_loader):
        imgs = []  # list of (file name, tuple of float labels)
        class_names = []
        # ``with`` closes the label file even if parsing fails.
        with open(label) as fh:
            for line_no, line in enumerate(fh):
                if line_no == 0:
                    # Header line: names of the label columns.
                    class_names = [n.strip() for n in line.rstrip().split(' ')]
                    continue
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline) instead
                    # of failing on an empty field list.
                    continue
                fn = fields[0]
                if os.path.isfile(os.path.join(root, fn)):
                    imgs.append((fn, tuple(float(v) for v in fields[1:])))
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, label tensor)`` for the entry at ``index``."""
        fn, label = self.imgs[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        target = torch.Tensor(label)
        # target_transform was accepted and stored but never applied; honour
        # it here. With the default (None) the result is unchanged.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of entries kept from the label file."""
        return len(self.imgs)

    def getName(self):
        """Return the label column names read from the header line."""
        return self.classes
# Image preprocessing pipeline: ToTensor is the only step, so no resizing,
# cropping or normalisation transform is configured here.
mytransform = transforms.Compose([transforms.ToTensor()])
# Absolute path on the author's machine; presumably the root directory of
# the training images (usage is not visible in this part of the file).
train_data_root = "/home/ying/shiyongjie/rjp/generate_distortion_image_2016_03_15/0_Distorted_Image/Training"
test_data_root = "/home/ying/shiyongjie/rjp/generate_distortion_image_2