Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a DVS eye tracking dataset from the paper 3ET #276

Merged
merged 9 commits into from
Dec 22, 2023
2 changes: 2 additions & 0 deletions tonic/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .nmnist import NMNIST
from .pokerdvs import POKERDVS
from .s_mnist import SMNIST
from .threeET_eyetracking import ThreeET_Eyetracking
from .tum_vie import TUMVIE
from .visual_place_recognition import VPR

Expand All @@ -28,6 +29,7 @@
"SHD",
"SMNIST",
"SSC",
"ThreeET_Eyetracking",
"TUMVIE",
"VPR",
"DVSLip",
Expand Down
113 changes: 113 additions & 0 deletions tonic/datasets/threeET_eyetracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import os
from typing import Any, Callable, Optional, Tuple

import h5py
import numpy as np

from tonic.dataset import Dataset
from tonic.io import make_structured_array


class ThreeET_Eyetracking(Dataset):
"""3ET DVS eye tracking `3ET <https://github.com/qinche106/cb-convlstm-eyetracking>`_
::

@article{chen20233et,
title={3ET: Efficient Event-based Eye Tracking using a Change-Based ConvLSTM Network},
author={Chen, Qinyu and Wang, Zuowen and Liu, Shih-Chii and Gao, Chang},
journal={arXiv preprint arXiv:2308.11771},
year={2023}
}

Parameters:
save_to (string): Location to save files to on disk.
transform (callable, optional): A callable of transforms to apply to the data.
split (string, optional): The dataset split to use, ``train`` or ``val``.
target_transform (callable, optional): A callable of transforms to apply to the targets/labels.
transforms (callable, optional): A callable of transforms that is applied to both data and
labels at the same time.

Returns:
A dataset object that can be indexed or iterated over.
One sample returns a tuple of (events, targets).
"""

url = "https://dl.dropboxusercontent.com/s/1hyer8egd8843t9/ThreeET_Eyetracking.zip?dl=0"
filename = "ThreeET_Eyetracking.zip"
file_md5 = "b6c652b06fdfd85721f39e2dbe12f4e8"

sensor_size = (240, 180, 2)
dtype = np.dtype([("t", int), ("x", int), ("y", int), ("p", int)])
ordering = dtype.names

def __init__(
self,
save_to: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
):
super().__init__(
save_to,
transform=transform,
target_transform=target_transform,
transforms=transforms,
)

# if not exist, download from url
if not self._check_exists():
self.download()

data_dir = os.path.join(save_to, "ThreeET_Eyetracking")
# Load filenames from the provided lists
if split == "train":
filenames = self.load_filenames(os.path.join(data_dir, "train_files.txt"))
elif split == "val":
filenames = self.load_filenames(os.path.join(data_dir, "val_files.txt"))
else:
raise ValueError("Invalid split name")

# Get the data file paths and target file paths
self.data = [os.path.join(data_dir, "data", f + ".h5") for f in filenames]
self.targets = [os.path.join(data_dir, "labels", f + ".txt") for f in filenames]

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Returns:
(events, target) where target is index of the target class.
"""
# get events from .h5 file
with h5py.File(self.data[index], "r") as f:
events = f["events"][:]
# load the sparse labels
with open(self.targets[index], "r") as f:
target = np.array(
[line.strip().split() for line in f.readlines()], np.float64
)

events = make_structured_array(
events[:, 0], # time in us
events[:, 1], # x
events[:, 2], # y
events[:, 3], # polarity in 1 or 0
dtype=self.dtype,
)

if self.transform is not None:
events = self.transform(events)
if self.target_transform is not None:
target = self.target_transform(target)
if self.transforms is not None:
events, target = self.transforms(events, target)
return events, target

def __len__(self):
return len(self.data)

def _check_exists(self):
return self._is_file_present()

def load_filenames(self, path):
with open(path, "r") as f:
return [line.strip() for line in f.readlines()]