-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
86 lines (59 loc) · 2.18 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import numpy as np
import torch
import torch.nn as nn
## Data loader implementation
class Dataset(torch.utils.data.Dataset):
    """Dataset of paired ``input*`` / ``label*`` ``.npy`` files in one directory.

    Every file in ``data_dir`` whose name starts with ``'label'`` or
    ``'input'`` is collected, and each group is sorted by filename. The i-th
    label file is paired with the i-th input file, so both groups must use
    matching naming schemes (e.g. ``label_000.npy`` / ``input_000.npy``).

    Args:
        data_dir: Directory holding the ``.npy`` files.
        transform: Optional callable applied to each sample dict returned by
            ``__getitem__``; its return value is returned instead.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        filenames = os.listdir(self.data_dir)
        # Sorting both groups the same way is what aligns label_k with input_k.
        self.lst_label = sorted(f for f in filenames if f.startswith('label'))
        self.lst_input = sorted(f for f in filenames if f.startswith('input'))

    def __len__(self):
        # Length follows the label files; every index must also have an input file.
        return len(self.lst_label)

    def __getitem__(self, index):
        """Load sample ``index`` as ``{'input': ndarray, 'label': ndarray}``.

        The input array is divided by 255. Any 2-D array gets a trailing
        axis appended so both arrays are 3-D. If a transform was given, the
        dict is passed through it before being returned.
        """
        label = np.load(os.path.join(self.data_dir, self.lst_label[index]))
        inp = np.load(os.path.join(self.data_dir, self.lst_input[index]))

        # NOTE(review): assumes input pixels are in the 0..255 range — confirm.
        inp = inp / 255

        # Give 2-D arrays a trailing axis (ToTensor treats axis 2 as channels).
        if label.ndim == 2:
            label = label[:, :, np.newaxis]
        if inp.ndim == 2:
            inp = inp[:, :, np.newaxis]

        data = {'input': inp, 'label': label}
        if self.transform is not None:
            data = self.transform(data)
        return data
## Transform implementations
class ToTensor(object):
    """Convert the 'label' and 'input' arrays of a sample to torch tensors.

    Each array is transposed with axes (2, 0, 1), so its last axis moves to
    the front, and is cast to float32 before wrapping with
    ``torch.from_numpy``. The returned dict holds only 'label' and 'input'.
    """

    def __call__(self, data):
        def as_channel_first(array):
            # astype() returns a fresh contiguous copy, so from_numpy is safe
            # even for negatively-strided views (e.g. from np.fliplr).
            moved = array.transpose((2, 0, 1))
            return torch.from_numpy(moved.astype(np.float32))

        label_tensor = as_channel_first(data['label'])
        input_tensor = as_channel_first(data['input'])
        return {'label': label_tensor, 'input': input_tensor}
class Normalization(object):
    """Standardize the sample's 'input' as ``(input - mean) / std``.

    The 'label' entry is passed through unchanged; the returned dict holds
    only 'label' and 'input'.

    Args:
        mean: Value subtracted from the input (default 0.5).
        std: Divisor applied after subtraction (default 0.5).
    """

    def __init__(self, mean=0.5, std=0.5):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        centered = data['input'] - self.mean
        return {'label': data['label'], 'input': centered / self.std}
class RandomFlip(object):
    """Randomly mirror 'label' and 'input' together.

    Two independent draws from ``np.random.rand()`` are made on every call:
    the first decides a left-right flip (``np.fliplr``), the second an
    up-down flip (``np.flipud``). Each flip happens when its draw exceeds
    0.5 and is applied identically to both arrays. The returned dict holds
    only 'label' and 'input'.
    """

    def __call__(self, data):
        lab, img = data['label'], data['input']
        # Draw order (left-right first, then up-down) matches the flip order.
        for mirror in (np.fliplr, np.flipud):
            if np.random.rand() > 0.5:
                lab, img = mirror(lab), mirror(img)
        return {'label': lab, 'input': img}