Skip to content

Commit

Permalink
Merge pull request #276 from ZuowenWang0000/develop
Browse files Browse the repository at this point in the history
Add a DVS eye tracking dataset from the paper 3ET
  • Loading branch information
biphasic authored Dec 22, 2023
2 parents 52de696 + 83e82ae commit debdd2e
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 1 deletion.
10 changes: 9 additions & 1 deletion docs/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ Star tracking

EBSSA

Eye tracking
-------------------
.. autosummary::
:toctree: generated/
:template: class_dataset.rst

ThreeET_Eyetracking

.. currentmodule:: tonic.prototype.datasets

Prototype iterable datasets
Expand All @@ -65,4 +73,4 @@ Prototype iterable datasets
NCARS
STMNIST
Gen1AutomotiveDetection
Gen4AutomotiveDetectionMini
Gen4AutomotiveDetectionMini
31 changes: 31 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,37 @@ def inject_fake_data(self, tmpdir):
return {"n_samples": 1}


class ThreeET_EyetrackingTestCase(dataset_utils.DatasetTestCase):
DATASET_CLASS = datasets.ThreeET_Eyetracking
FEATURE_TYPES = (datasets.ThreeET_Eyetracking.dtype,)
TARGET_TYPES = (np.ndarray,)
KWARGS = {"split": "train"}

def inject_fake_data(self, tmpdir):
testfolder = os.path.join(tmpdir, "ThreeET_Eyetracking")
os.makedirs(testfolder, exist_ok=True)
os.makedirs(os.path.join(testfolder, "data"), exist_ok=True)
os.makedirs(os.path.join(testfolder, "labels"), exist_ok=True)
# write one line of file name into train_files.txt under testfolder
os.system("echo testcase > " + os.path.join(testfolder, "train_files.txt"))
filename = "testcase"

# download test h5 file
download_url(
url=base_url + "4aiA4BAqz5km4Gc/download/" + filename + ".h5",
root=os.path.join(testfolder, "data"),
filename=filename + ".h5",
)
# # download test labels
download_url(
url=base_url + "G6ejNmXNnB2sKyc/download/" + filename + ".txt",
root=os.path.join(testfolder, "labels"),
filename=filename + ".txt",
)

return {"n_samples": 1}


class NCaltech101TestCase(dataset_utils.DatasetTestCase):
DATASET_CLASS = datasets.NCALTECH101
FEATURE_TYPES = (datasets.NCALTECH101.dtype,)
Expand Down
2 changes: 2 additions & 0 deletions tonic/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .nmnist import NMNIST
from .pokerdvs import POKERDVS
from .s_mnist import SMNIST
from .threeET_eyetracking import ThreeET_Eyetracking
from .tum_vie import TUMVIE
from .visual_place_recognition import VPR

Expand All @@ -28,6 +29,7 @@
"SHD",
"SMNIST",
"SSC",
"ThreeET_Eyetracking",
"TUMVIE",
"VPR",
"DVSLip",
Expand Down
113 changes: 113 additions & 0 deletions tonic/datasets/threeET_eyetracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import os
from typing import Any, Callable, Optional, Tuple

import h5py
import numpy as np

from tonic.dataset import Dataset
from tonic.io import make_structured_array


class ThreeET_Eyetracking(Dataset):
"""3ET DVS eye tracking `3ET <https://github.com/qinche106/cb-convlstm-eyetracking>`_
::
@article{chen20233et,
title={3ET: Efficient Event-based Eye Tracking using a Change-Based ConvLSTM Network},
author={Chen, Qinyu and Wang, Zuowen and Liu, Shih-Chii and Gao, Chang},
journal={arXiv preprint arXiv:2308.11771},
year={2023}
}
Parameters:
save_to (string): Location to save files to on disk.
transform (callable, optional): A callable of transforms to apply to the data.
split (string, optional): The dataset split to use, ``train`` or ``val``.
target_transform (callable, optional): A callable of transforms to apply to the targets/labels.
transforms (callable, optional): A callable of transforms that is applied to both data and
labels at the same time.
Returns:
A dataset object that can be indexed or iterated over.
One sample returns a tuple of (events, targets).
"""

url = "https://dl.dropboxusercontent.com/s/1hyer8egd8843t9/ThreeET_Eyetracking.zip?dl=0"
filename = "ThreeET_Eyetracking.zip"
file_md5 = "b6c652b06fdfd85721f39e2dbe12f4e8"

sensor_size = (240, 180, 2)
dtype = np.dtype([("t", int), ("x", int), ("y", int), ("p", int)])
ordering = dtype.names

def __init__(
self,
save_to: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
transforms: Optional[Callable] = None,
):
super().__init__(
save_to,
transform=transform,
target_transform=target_transform,
transforms=transforms,
)

# if not exist, download from url
if not self._check_exists():
self.download()

data_dir = os.path.join(save_to, "ThreeET_Eyetracking")
# Load filenames from the provided lists
if split == "train":
filenames = self.load_filenames(os.path.join(data_dir, "train_files.txt"))
elif split == "val":
filenames = self.load_filenames(os.path.join(data_dir, "val_files.txt"))
else:
raise ValueError("Invalid split name")

# Get the data file paths and target file paths
self.data = [os.path.join(data_dir, "data", f + ".h5") for f in filenames]
self.targets = [os.path.join(data_dir, "labels", f + ".txt") for f in filenames]

def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Returns:
(events, target) where target is index of the target class.
"""
# get events from .h5 file
with h5py.File(self.data[index], "r") as f:
events = f["events"][:]
# load the sparse labels
with open(self.targets[index], "r") as f:
target = np.array(
[line.strip().split() for line in f.readlines()], np.float64
)

events = make_structured_array(
events[:, 0], # time in us
events[:, 1], # x
events[:, 2], # y
events[:, 3], # polarity in 1 or 0
dtype=self.dtype,
)

if self.transform is not None:
events = self.transform(events)
if self.target_transform is not None:
target = self.target_transform(target)
if self.transforms is not None:
events, target = self.transforms(events, target)
return events, target

def __len__(self):
return len(self.data)

def _check_exists(self):
return self._is_file_present()

def load_filenames(self, path):
with open(path, "r") as f:
return [line.strip() for line in f.readlines()]

0 comments on commit debdd2e

Please sign in to comment.