设为首页 加入收藏

TOP

Pytorch使用多GPU(一)
2017-09-30 13:01:39 】 浏览:3006
Tags:Pytorch 使用 GPU

在caffe中训练的时候如果使用多GPU则直接在运行程序的时候指定GPU的index即可,但是在Pytorch中则需要在声明模型之后,对声明的模型进行初始化,如:


cnn = DataParallel(AlexNet())


之后直接运行Pytorch之后则默认使用所有的GPU,为了说明上述初始化的作用,我用了一组畸变图像的数据集,写了一个Resent的模块,过了50个epoch,对比一下实验耗时的差别,代码如下:


# -*- coding: utf-8 -*-
# Implementation of https://arxiv.org/pdf/1512.03385.pdf/
# See section 4.2 for model architecture on CIFAR-10.
# Some part of the code was referenced below.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py


import os
from PIL import Image
import time


import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.utils.data as data
from torch.nn import DataParallel



kwargs = {'num_workers': 1, 'pin_memory': True}
# def my dataloader, return the data and corresponding label



def default_loader(path):
    return Image.open(path).convert('RGB')



class myImageFloder(data.Dataset):  # Class inheritance
    def __init__(self, root, label, transform=None, target_transform=None, loader=default_loader):
        fh = open(label)
        c = 0
        imgs = []
        class_names = []
        for line in fh.readlines():
            if c == 0:
                class_names = [n.strip() for n in line.rstrip().split('    ')]
            else:
                cls = line.split()  # cls is a list
                fn = cls.pop(0)
                if os.path.isfile(os.path.join(root, fn)):
                    imgs.append((fn, tuple([float(v) for v in cls])))  # imgs is the list,and the content is the tuple
                    # we can use the append way to append the element for list
            c = c + 1
        self.root = root
        self.imgs = imgs
        self.classes = class_names
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader


    def __getitem__(self, index):
        fn, label = self.imgs[index]  # eventhough the imgs is just a list, it can return the elements of is
        # in a proper way
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.Tensor(label)


    def __len__(self):
        return len(self.imgs)


    def getName(self):
        return self.classes


mytransform = transforms.Compose([transforms.ToTensor()])  # almost dont do any operation
train_data_root = "/home/ying/shiyongjie/rjp/generate_distortion_image_2016_03_15/0_Distorted_Image/Training"
test_data_root = "/home/ying/shiyongjie/rjp/generate_distortion_image_2

首页 上一页 1 2 3 4 5 6 7 下一页 尾页 1/15/15
】【打印繁体】【投稿】【收藏】 【推荐】【举报】【评论】 【关闭】 【返回顶部
上一篇Thrift框架快速入门 下一篇Java中的方法和方法重载

最新文章

热门文章

Hot 文章

Python

C 语言

C++基础

大数据基础

linux编程基础

C/C++面试题目