Loader.py
from PIL import Image
import torchvision.transforms as transforms
import torch.utils.data as data
from os import listdir
import os
def is_image_file(filename):
    # Accept only the common image extensions.
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])

def default_loader(path):
    # Load an image from disk and force 3-channel RGB.
    return Image.open(path).convert('RGB')
class Dataset(data.Dataset):
    def __init__(self, contentPath, stylePath, fineSize):
        super(Dataset, self).__init__()
        self.contentPath = contentPath
        # The content and style directories are expected to contain
        # images with matching filenames.
        self.image_list = [x for x in listdir(contentPath) if is_image_file(x)]
        self.stylePath = stylePath
        self.fineSize = fineSize
        # transforms.Scale was deprecated and later removed from torchvision;
        # Resize is its replacement. This pipeline is defined but __getitem__
        # below resizes manually so content and style share the same shape.
        self.prep = transforms.Compose([
            transforms.Resize(fineSize),
            transforms.ToTensor(),
        ])
    def __getitem__(self, index):
        contentImgPath = os.path.join(self.contentPath, self.image_list[index])
        styleImgPath = os.path.join(self.stylePath, self.image_list[index])
        contentImg = default_loader(contentImgPath)
        styleImg = default_loader(styleImgPath)

        # Resize so the longer side of the content image equals fineSize,
        # preserving aspect ratio; the style image is resized to the same
        # dimensions as the content image.
        if self.fineSize != 0:
            w, h = contentImg.size
            if w > h:
                if w != self.fineSize:
                    neww = self.fineSize
                    newh = int(h * neww / w)
                    contentImg = contentImg.resize((neww, newh))
                    styleImg = styleImg.resize((neww, newh))
            else:
                if h != self.fineSize:
                    newh = self.fineSize
                    neww = int(w * newh / h)
                    contentImg = contentImg.resize((neww, newh))
                    styleImg = styleImg.resize((neww, newh))

        # Convert PIL images to CxHxW float tensors in [0, 1].
        contentImg = transforms.ToTensor()(contentImg)
        styleImg = transforms.ToTensor()(styleImg)
        return contentImg, styleImg, self.image_list[index]
    def __len__(self):
        return len(self.image_list)
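
# --- Usage sketch (not part of the original file) ---
# A minimal example of driving this Dataset, assuming hypothetical
# 'images/content' and 'images/style' directories that hold same-named
# image pairs, and a fineSize of 512. batch_size=1 is used because the
# aspect-preserving resize can give each pair a different shape, so
# images cannot be stacked into larger batches.
if __name__ == '__main__':
    dataset = Dataset('images/content', 'images/style', fineSize=512)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False)
    for contentImg, styleImg, name in loader:
        # name is a tuple of length 1 holding the shared filename.
        print(name[0], contentImg.shape, styleImg.shape)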