Skip to content

Commit

Permalink
format code
Browse files Browse the repository at this point in the history
  • Loading branch information
ir2718 committed Oct 26, 2024
1 parent e86eab0 commit b629fff
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 102 deletions.
30 changes: 21 additions & 9 deletions src/pytorch_metric_learning/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from PIL import Image
from torch.utils.data import Dataset
import os
from abc import ABC, abstractmethod

from PIL import Image
from torch.utils.data import Dataset


class BaseDataset(ABC, Dataset):

def __init__(self, root, split="train+test", transform=None, target_transform=None, download=False):
def __init__(
self,
root,
split="train+test",
transform=None,
target_transform=None,
download=False,
):
self.root = root

if download:
Expand All @@ -18,24 +27,27 @@ def __init__(self, root, split="train+test", transform=None, target_transform=No
# The given directory does not exist so the user should be aware of downloading it
# Otherwise proceed as usual
if not os.path.isdir(self.root):
raise ValueError("The given path does not exist. "
raise ValueError(
"The given path does not exist. "
"You should probably initialize the dataset with download=True."
)

self.transform = transform
self.target_transform = target_transform

if split not in self.get_available_splits():
raise ValueError(f"Supported splits are: {', '.join(self.get_available_splits())}")

raise ValueError(
f"Supported splits are: {', '.join(self.get_available_splits())}"
)

self.split = split

self.generate_split()

@abstractmethod
def generate_split(self):
    """Populate the samples for ``self.split``; subclasses must override.

    ``__init__`` calls this as ``self.generate_split()`` right after it has
    validated and stored ``self.split``. ``__len__`` reads ``self.labels``
    and ``__getitem__`` reads ``self.paths`` and ``self.labels``, so an
    implementation is expected to fill those two attributes.

    Raises:
        NotImplementedError: if this base implementation is invoked directly.
    """
    # The signature previously had no ``self``; it is an instance method and
    # is always invoked on an instance, so the parameter must be declared.
    raise NotImplementedError

@abstractmethod
def download_and_remove(self):
    """Fetch the dataset into ``self.root`` and delete the downloaded archives.

    Subclasses must override this. The visible subclasses (e.g. ``Cars196``
    and ``StanfordOnlineProducts``) download an archive into ``self.root``,
    extract it there and then remove the archive file.

    NOTE(review): presumably invoked from ``__init__`` when ``download=True``;
    that branch is not fully visible here -- confirm.

    Raises:
        NotImplementedError: if this base implementation is invoked directly.
    """
    # The signature previously had no ``self``; every override takes it, so
    # the abstract declaration must match.
    raise NotImplementedError
Expand All @@ -45,7 +57,7 @@ def get_available_splits(self):

# Number of samples in the dataset: the length of self.labels, which the
# visible subclasses (e.g. Cars196, CUB) fill inside generate_split().
def __len__(self):
return len(self.labels)

def __getitem__(self, idx):
img = Image.open(self.paths[idx])
label = self.labels[idx]
Expand All @@ -56,4 +68,4 @@ def __getitem__(self, idx):
if self.target_transform is not None:
label = self.target_transform(label)

return (img, label)
return (img, label)
26 changes: 17 additions & 9 deletions src/pytorch_metric_learning/datasets/cars196.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from ..datasets.base_dataset import BaseDataset
from ..utils.common_functions import _urlretrieve
import os
import zipfile

from ..datasets.base_dataset import BaseDataset
from ..utils.common_functions import _urlretrieve


class Cars196(BaseDataset):

DOWNLOAD_URL = "https://www.kaggle.com/api/v1/datasets/download/jutrera/stanford-car-dataset-by-classes-folder"
Expand All @@ -15,7 +17,7 @@ def generate_split(self):
classes = set(range(99, 197))
else:
classes = set(range(1, 197))

with open(os.path.join(self.root, "names.csv"), "r") as f:
names = [x.strip() for x in f.readlines()]

Expand All @@ -28,13 +30,12 @@ def generate_split(self):
paths = paths_train + paths_test
labels = labels_train + labels_test


self.paths, self.labels = [], []
for p, l in zip(paths, labels):
if l in classes:
self.paths.append(p)
self.labels.append(l)

def _load_csv(self, path, names, split):
all_paths, all_labels = [], []
with open(path, "r") as f:
Expand All @@ -44,16 +45,23 @@ def _load_csv(self, path, names, split):
curr_label = path_annos[-1]
all_paths.append(
os.path.join(
self.root, "car_data", "car_data", split, names[int(curr_label) - 1].replace("/","-"), curr_path
self.root,
"car_data",
"car_data",
split,
names[int(curr_label) - 1].replace("/", "-"),
curr_path,
)
)
all_labels.append(int(curr_label))
return all_paths, all_labels

# Download the Cars196 zip archive into self.root, unpack it there and then
# delete the archive so only the extracted files remain.
def download_and_remove(self):
# Create the destination directory up front; exist_ok avoids failing on re-runs.
os.makedirs(self.root, exist_ok=True)
# The archive is stored under self.root, named after the last URL path segment.
download_folder_path = os.path.join(self.root, Cars196.DOWNLOAD_URL.split('/')[-1])
download_folder_path = os.path.join(
self.root, Cars196.DOWNLOAD_URL.split("/")[-1]
)
_urlretrieve(url=Cars196.DOWNLOAD_URL, filename=download_folder_path)
# Extract everything next to the archive, i.e. directly into self.root.
with zipfile.ZipFile(download_folder_path, 'r') as zip_ref:
with zipfile.ZipFile(download_folder_path, "r") as zip_ref:
zip_ref.extractall(self.root)
# The archive itself is no longer needed once extracted.
os.remove(download_folder_path)
os.remove(download_folder_path)
18 changes: 11 additions & 7 deletions src/pytorch_metric_learning/datasets/cub.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from ..datasets.base_dataset import BaseDataset
from ..utils.common_functions import _urlretrieve
import os
import tarfile

from ..datasets.base_dataset import BaseDataset
from ..utils.common_functions import _urlretrieve


class CUB(BaseDataset):

DOWNLOAD_URL = "https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz"

def generate_split(self):
dir_name = CUB.DOWNLOAD_URL.split('/')[-1].replace(".tgz", "")
dir_name = CUB.DOWNLOAD_URL.split("/")[-1].replace(".tgz", "")

# Training split is first 100 classes, other 100 is test
if self.split == "train":
Expand All @@ -17,14 +19,14 @@ def generate_split(self):
classes = set(range(101, 201))
else:
classes = set(range(1, 201))

# Find ids which correspond to the classes in the split
self.paths, self.labels = [], []
with open(os.path.join(self.root, dir_name, "image_class_labels.txt")) as f1:
with open(os.path.join(self.root, dir_name, "images.txt")) as f2:
for l1, l2 in zip(f1, f2):
img_idx1, class_idx = list(map(int, l1.split()))

if class_idx not in classes:
continue

Expand All @@ -33,12 +35,14 @@ def generate_split(self):

# If the image ids correspond it's a match
if img_idx1 == img_idx2:
self.paths.append(os.path.join(self.root, dir_name, "images", img_path))
self.paths.append(
os.path.join(self.root, dir_name, "images", img_path)
)
self.labels.append(class_idx)

def download_and_remove(self):
os.makedirs(self.root, exist_ok=True)
download_folder_path = os.path.join(self.root, CUB.DOWNLOAD_URL.split('/')[-1])
download_folder_path = os.path.join(self.root, CUB.DOWNLOAD_URL.split("/")[-1])
_urlretrieve(url=CUB.DOWNLOAD_URL, filename=download_folder_path)
with tarfile.open(download_folder_path, "r:gz") as tar:
tar.extractall(self.root)
Expand Down
50 changes: 34 additions & 16 deletions src/pytorch_metric_learning/datasets/inaturalist2018.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from ..datasets.base_dataset import BaseDataset
from ..utils.common_functions import _urlretrieve
import json
import os
import tarfile
import zipfile
import json

from ..datasets.base_dataset import BaseDataset
from ..utils.common_functions import _urlretrieve


class INaturalist2018(BaseDataset):

IMG_DOWNLOAD_URL = "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz"
TRAIN_ANN_URL = "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train2018.json.tar.gz"
VAL_ANN_URL = "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/val2018.json.tar.gz"
VAL_ANN_URL = (
"https://ml-inat-competition-datasets.s3.amazonaws.com/2018/val2018.json.tar.gz"
)
SPLITS_URL = "https://drive.google.com/uc?id=1sXfkBTFDrRU3__-NUs1qBP3sf_0uMB98"

def generate_split(self):
Expand All @@ -24,8 +28,8 @@ def generate_split(self):

imgs, anns = val_imgs + train_imgs, val_anns + train_anns

path2id = {x["file_name"]:x["id"] for x in imgs}
id2label = {x["image_id"]:x["category_id"] for x in anns}
path2id = {x["file_name"]: x["id"] for x in imgs}
id2label = {x["image_id"]: x["category_id"] for x in anns}

if self.split in ["train", "test"]:
paths = self._load_split_txt(self.split)
Expand All @@ -36,7 +40,7 @@ def generate_split(self):
train_paths = self._load_split_txt("train")
train_ids = [path2id[p] for p in train_paths]
train_labels = [id2label[i] for i in train_ids]

test_paths = self._load_split_txt("test")
test_ids = [path2id[p] for p in test_paths]
test_labels = [id2label[i] for i in test_ids]
Expand All @@ -49,32 +53,46 @@ def generate_split(self):

def _load_split_txt(self, split):
    """Read the image paths listed for one split.

    Args:
        split: Split name interpolated into the file name
            ``Inaturalist_{split}_set1.txt`` inside
            ``<root>/Inat_dataset_splits``.

    Returns:
        list: One whitespace-stripped entry per non-blank line, in file order.
    """
    split_file = os.path.join(
        self.root, "Inat_dataset_splits", f"Inaturalist_{split}_set1.txt"
    )
    with open(split_file) as f:
        stripped_lines = [line.strip() for line in f]
    # Blank lines (e.g. a trailing newline at the end of the file) used to be
    # returned as empty paths, which then fail the ``path2id[p]`` lookup in
    # ``generate_split``; drop them here instead.
    return [path for path in stripped_lines if path]

def download_and_remove(self):
    """Download and unpack every iNaturalist 2018 archive into ``self.root``.

    Three gzip-compressed tarballs (images, train annotations, validation
    annotations) and one zip file (the train/test split lists) are fetched.
    Each archive is extracted into ``self.root`` and deleted afterwards.
    """
    # The sibling datasets create the target directory before downloading;
    # without this the first download fails when ``self.root`` does not exist.
    os.makedirs(self.root, exist_ok=True)

    tar_urls = (
        INaturalist2018.IMG_DOWNLOAD_URL,
        INaturalist2018.TRAIN_ANN_URL,
        INaturalist2018.VAL_ANN_URL,
    )
    for url in tar_urls:
        # These URLs end in a real file name, so reuse it for the local copy.
        archive_path = os.path.join(self.root, url.split("/")[-1])
        _urlretrieve(url=url, filename=archive_path)
        # NOTE(review): extractall trusts the member paths stored in the
        # archive; consider tarfile's extraction filters where the supported
        # Python versions provide them.
        with tarfile.open(archive_path, "r:gz") as tar:
            tar.extractall(self.root)
        os.remove(archive_path)

    # SPLITS_URL ends in a query string ("uc?id=..."), which is not a usable
    # file name on every platform, so store the temporary archive under a
    # fixed name instead. The file is deleted right after extraction.
    archive_path = os.path.join(self.root, "Inat_dataset_splits.zip")
    _urlretrieve(url=INaturalist2018.SPLITS_URL, filename=archive_path)
    with zipfile.ZipFile(archive_path, "r") as zip_ref:
        zip_ref.extractall(self.root)
    os.remove(archive_path)
26 changes: 17 additions & 9 deletions src/pytorch_metric_learning/datasets/sop.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from ..datasets.base_dataset import BaseDataset
from ..utils.common_functions import _urlretrieve
import os
import zipfile

from ..datasets.base_dataset import BaseDataset
from ..utils.common_functions import _urlretrieve


class StanfordOnlineProducts(BaseDataset):

DOWNLOAD_URL = "https://drive.usercontent.google.com/download?id=1TclrpQOF_ullUP99wk_gjGN8pKvtErG8&export=download&authuser=0&confirm=t"
Expand All @@ -21,26 +23,32 @@ def generate_split(self):

def _load_split_txt(self, split):
    """Parse ``Ebay_{split}.txt`` into parallel lists of image paths and labels.

    Args:
        split: Split name interpolated into the file name ``Ebay_{split}.txt``
            inside ``<root>/Stanford_Online_Products``.

    Returns:
        tuple: ``(paths, labels)`` where each path is the fourth
        whitespace-separated field joined onto the dataset directory and each
        label is the second field converted to ``int``.
    """
    dataset_dir = os.path.join(self.root, "Stanford_Online_Products")
    paths, labels = [], []
    with open(os.path.join(dataset_dir, f"Ebay_{split}.txt")) as f:
        # The first line is always skipped -- presumably a column header;
        # confirm against the data file.
        next(f, None)
        for line in f:
            fields = line.split()
            # A blank line (e.g. at the end of the file) has no fields and
            # used to raise IndexError; skip it.
            if not fields:
                continue
            label, path = int(fields[1]), fields[3]
            paths.append(os.path.join(dataset_dir, path))
            labels.append(label)
    return paths, labels

def download_and_remove(self):
    """Download the Stanford Online Products zip, unpack it, delete the zip.

    The archive is written into ``self.root``, extracted there and removed
    afterwards so only the extracted files remain.
    """
    os.makedirs(self.root, exist_ok=True)
    # DOWNLOAD_URL ends in a query string ("download?id=...&export=..."),
    # which is not a usable file name on every platform, so store the
    # temporary archive under a fixed name instead. The file is deleted
    # right after extraction.
    archive_path = os.path.join(self.root, "Stanford_Online_Products.zip")
    _urlretrieve(url=StanfordOnlineProducts.DOWNLOAD_URL, filename=archive_path)
    with zipfile.ZipFile(archive_path, "r") as zip_ref:
        zip_ref.extractall(self.root)
    os.remove(archive_path)



# if __name__ == "__main__":
# train_dataset = StanfordOnlineProducts(root="data_sop", split="train", download=True)
# train_dataset = StanfordOnlineProducts(root="data_sop", split="test", download=True)
# train_dataset = StanfordOnlineProducts(root="data_sop", split="train+test", download=True)

2 changes: 2 additions & 0 deletions src/pytorch_metric_learning/samplers/m_per_class_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ def __iter__(self):
curr_label_set = self.labels
else:
curr_label_set = self.labels[: self.batch_size // self.m_per_class]
print(curr_label_set)
for label in curr_label_set:
t = self.labels_to_indices[label]
idx_list[i : i + self.m_per_class] = c_f.safe_random_choice(
t, size=self.m_per_class
)
i += self.m_per_class

return iter(idx_list)

def calculate_num_iters(self):
Expand Down
6 changes: 5 additions & 1 deletion src/pytorch_metric_learning/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@
NUMPY_RANDOM = np.random
COLLECT_STATS = False


# taken from:
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/utils.py#L27
# Stream the response body of `url` into the file `filename`, reading
# `chunk_size` bytes at a time and advancing a tqdm progress bar as it goes.
def _urlretrieve(url, filename, chunk_size=1024 * 32):
with urllib.request.urlopen(urllib.request.Request(url)) as response:
# The output file is opened in binary write mode; the progress bar total is
# taken from response.length.
with open(filename, "wb") as fh, tqdm(total=response.length, unit="B", unit_scale=True) as pbar:
with open(filename, "wb") as fh, tqdm(
total=response.length, unit="B", unit_scale=True
) as pbar:
# read() returns an empty bytes object at end of stream, which ends the loop.
while chunk := response.read(chunk_size):
fh.write(chunk)
pbar.update(len(chunk))


def set_logger_name(name):
global LOGGER_NAME
global LOGGER
Expand Down
Loading

0 comments on commit b629fff

Please sign in to comment.