k-fold CV #1574

Open · wants to merge 5 commits into base: master
29 changes: 29 additions & 0 deletions README.md
@@ -1,6 +1,35 @@
![RecBole Logo](asset/logo.png)

--------------------------------------------------------------------------------
# K-fold CV
This branch contains an implementation of K-fold cross-validation (CV) in RecBole. The implementation relies on
RecBole's `Dataset` class and its methods `split_by_ratio` and `build` (`recbole/data/dataset/dataset.py`), on the function `data_preparation`
(`recbole/data/utils.py`), and on the general configuration file `overall.yaml` (`recbole/properties`).

Running RecBole calls the `quick_start.py` script (`recbole/quick_start`), which splits the data with `data_preparation(config, dataset)`.
K-fold CV happens in this function. First, `dataset.build()` is called with the parameters stored in the
`config` dictionary and splits the dataset into folds whose sizes are defined in `overall.yaml`: the fraction of interactions assigned to
each fold is given there as a list (e.g., `[0.2, 0.2, 0.2, 0.2, 0.2]` for 5 folds). The resulting fold datasets are stored as a list in the variable `folds`.

The variable `k` (the `fold` key in `eval_args`) selects which rotation of the folds is used in the current experiment. For example, for a 5-fold CV and `k = 0`,
the first three folds are assigned to the train set, the fourth to the validation set, and the last one to the test set.
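
A minimal sketch of this rotation (the `rotate_folds` helper below is only illustrative; the patch performs the slicing inline in `data_preparation`):

```python
# Illustrative only: how the fold rotation determines the train/valid/test assignment.
def rotate_folds(folds, k):
    """Rotate the list of folds so that fold k comes first."""
    return folds[k:] + folds[:k]

folds = ["fold0", "fold1", "fold2", "fold3", "fold4"]  # placeholders for Dataset objects
rotated = rotate_folds(folds, k=0)
train_folds = rotated[:-2]   # ['fold0', 'fold1', 'fold2']
valid_fold = rotated[-2]     # 'fold3'
test_fold = rotated[-1]      # 'fold4'
```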

The function then initializes `train_dataset` as an empty `Dataset(config)` instance and fills it with the interactions of the
corresponding training folds, concatenated via `cat_interactions`. The remaining two folds are assigned to `valid_dataset` and `test_dataset`.
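
A sketch of this assembly step, assuming `config` is the RecBole `Config` object and `folds` is the rotated list of `Dataset` objects; it mirrors the patched `data_preparation` in `recbole/data/utils.py`:

```python
from recbole.data.dataset.dataset import Dataset
from recbole.data.interaction import cat_interactions

n_folds = len(folds)

# Fresh Dataset whose interaction table is the concatenation of the training folds.
train_dataset = Dataset(config)
train_dataset.inter_feat = cat_interactions(
    [fold.inter_feat for fold in folds[: n_folds - 2]]
)

# The two remaining folds become the validation and test sets.
valid_dataset = folds[n_folds - 2]
test_dataset = folds[n_folds - 1]
```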

To run an experiment on the first fold of a 5-fold CV, set `eval_args` in `overall.yaml` to:
```yaml
# Evaluation Settings
eval_args:                                    # (dict) 5 keys: group_by, order, split, fold, and mode
  split: {'KF': [0.2, 0.2, 0.2, 0.2, 0.2]}    # (dict) 5 folds, each holding 20% of the interactions
  fold: 0                                     # (int) The fold rotation to use for this run
  group_by: none                              # (str) The grouping strategy ranging in ['user', 'none'].
  order: RO                                   # (str) The ordering strategy ranging in ['RO', 'TO'].
  mode: full                                  # (str) The evaluation mode ranging in ['full','unixxx','popxxx','labeled'].
```
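
The same settings can also be supplied programmatically, so a full 5-fold experiment is just a loop over the `fold` index. A minimal sketch, assuming this branch is installed; the model name, dataset name, and result collection are illustrative and not part of the patch:

```python
from recbole.quick_start import run_recbole

n_folds = 5
test_results = []
for k in range(n_folds):
    # Override eval_args per run; everything else is taken from overall.yaml.
    config_dict = {
        "eval_args": {
            "split": {"KF": [1.0 / n_folds] * n_folds},
            "fold": k,
            "group_by": "none",
            "order": "RO",
            "mode": "full",
        }
    }
    result = run_recbole(model="BPR", dataset="ml-100k", config_dict=config_dict)
    test_results.append(result["test_result"])
```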



# RecBole (伯乐)

18 changes: 18 additions & 0 deletions recbole/data/dataset/dataset.py
@@ -1796,6 +1796,24 @@ def build(self):
raise NotImplementedError(
f"The grouping method [{group_by}] has not been implemented."
)

elif split_mode == "KF":
"""
Will return n_folds datasets
"""
if not isinstance(split_args["KF"], list):
raise ValueError(f'The value of "KF" [{split_args}] should be a list.')
if group_by is None or group_by.lower() == "none":
datasets = self.split_by_ratio(split_args["KF"], group_by=None)
elif group_by == "user":
datasets = self.split_by_ratio(
split_args["KF"], group_by=self.uid_field
)
else:
raise NotImplementedError(
f"The grouping method [{group_by}] has not been implemented."
)

elif split_mode == "LS":
datasets = self.leave_one_out(
group_by=self.uid_field, leave_one_mode=split_args["LS"]
23 changes: 21 additions & 2 deletions recbole/data/utils.py
@@ -22,6 +22,9 @@
from recbole.utils import ModelType, ensure_dir, get_local_time, set_color
from recbole.utils.argument_list import dataset_arguments

from recbole.data.interaction import cat_interactions
from recbole.data.dataset.dataset import Dataset


def create_dataset(config):
"""Create dataset according to :attr:`config['model']` and :attr:`config['MODEL_TYPE']`.
@@ -160,9 +163,25 @@ def data_preparation(config, dataset):
train_data, valid_data, test_data = dataloaders
else:
model_type = config["MODEL_TYPE"]
built_datasets = dataset.build()
if list(config["eval_args"]["split"].keys())[0] == 'KF':
# K-fold CV: dataset.build() returns one Dataset per fold
folds = dataset.build()
n_folds = len(folds)
# rotate the folds so that rotation k is used for this run
k = config["eval_args"]["fold"]
folds = folds[k:] + folds[:k]

# concatenate the interactions of the first n_folds - 2 folds into the training set
train_dataset = Dataset(config)
train_dataset.inter_feat = cat_interactions([fold.inter_feat for fold in folds[: n_folds - 2]])
valid_dataset = folds[n_folds - 2]
test_dataset = folds[n_folds - 1]

built_datasets = [train_dataset, valid_dataset, test_dataset]

else:
built_datasets = dataset.build()
train_dataset, valid_dataset, test_dataset = built_datasets

train_dataset, valid_dataset, test_dataset = built_datasets
train_sampler, valid_sampler, test_sampler = create_samplers(
config, dataset, built_datasets
)
27 changes: 15 additions & 12 deletions recbole/properties/overall.yaml
@@ -5,20 +5,20 @@ use_gpu: True                 # (bool) Whether or not to use GPU.
seed: 2020 # (int) Random seed.
state: INFO # (str) Logging level.
reproducibility: True # (bool) Whether or not to make results reproducible.
data_path: 'dataset/' # (str) The path of input dataset.
data_path: /home/marta/jku/fairinterplay/dataset # (str) The path of input dataset.
checkpoint_dir: 'saved' # (str) The path to save checkpoint file.
show_progress: True # (bool) Whether or not to show the progress bar of every epoch.
save_dataset: False # (bool) Whether or not to save filtered dataset.
show_progress: True # (bool) Whether or not to show the progress bar of every epoch.
save_dataset: True # (bool) Whether or not to save filtered dataset.
dataset_save_path: ~ # (str) The path of saved dataset.
save_dataloaders: False # (bool) Whether or not save split dataloaders.
save_dataloaders: True # (bool) Whether or not save split dataloaders.
dataloaders_save_path: ~ # (str) The path of saved dataloaders.
log_wandb: False # (bool) Whether or not to use Weights & Biases(W&B).
log_wandb: True # (bool) Whether or not to use Weights & Biases(W&B).
wandb_project: 'recbole' # (str) The project to conduct experiments in W&B.
shuffle: True # (bool) Whether or not to shuffle the training data before each epoch.

# Training Settings
epochs: 300 # (int) The number of training epochs.
train_batch_size: 2048 # (int) The training batch size.
train_batch_size: 1024 # (int) The training batch size.
learner: adam # (str) The name of used optimizer.
learning_rate: 0.001 # (float) Learning rate.
train_neg_sample_args: # (dict) Negative sampling configuration for model training.
@@ -29,7 +29,7 @@ train_neg_sample_args:          # (dict) Negative sampling configuration for mod
candidate_num: 0 # (int) The number of candidate negative items when dynamic negative sampling.
eval_step: 1 # (int) The number of training epochs before an evaluation on the valid dataset.
stopping_step: 10 # (int) The threshold for validation-based early stopping.
clip_grad_norm: ~ # (dict) The args of clip_grad_norm_ which will clip gradient norm of model.
clip_grad_norm: ~ # (dict) The args of clip_grad_norm_ which will clip gradient norm of model.
weight_decay: 0.0 # (float) The weight decay value (L2 penalty) for optimizers.
loss_decimal_place: 4 # (int) The decimal place of training loss.
require_pow: False # (bool) Whether or not to perform power operation in EmbLoss.
@@ -39,14 +39,17 @@ transform: ~                    # (str) The transform operation for batch data p

# Evaluation Settings
eval_args: # (dict) 4 keys: group_by, order, split, and mode
split: {'RS':[0.8,0.1,0.1]} # (dict) The splitting strategy ranging in ['RS','LS'].
group_by: user # (str) The grouping strategy ranging in ['user', 'none'].
# split: {'LS': 'valid_and_test'} # (dict) The splitting strategy ranging in ['RS','LS'].
# split: { 'RS': [0.6, 0.2, 0.2] } # (dict) The splitting strategy ranging in ['RS','LS'].
split: {'KF': [0.2, 0.2, 0.2, 0.2, 0.2]}
fold: 2
group_by: none # (str) The grouping strategy ranging in ['user', 'none'].
order: RO # (str) The ordering strategy ranging in ['RO', 'TO'].
mode: full # (str) The evaluation mode ranging in ['full','unixxx','popxxx','labeled'].
repeatable: False # (bool) Whether to evaluate results with a repeatable recommendation scene.
repeatable: False # (bool) Whether to evaluate results with a repeatable recommendation scene.
metrics: ["Recall","MRR","NDCG","Hit","Precision"] # (list or str) Evaluation metrics.
topk: [10] # (list or int or None) The value of k for topk evaluation metrics.
valid_metric: MRR@10 # (str) The evaluation metric for early stopping.
valid_metric: 'NDCG@10'
valid_metric_bigger: True # (bool) Whether to take a bigger valid metric value as a better result.
eval_batch_size: 4096 # (int) The evaluation batch size.
eval_batch_size: 1024 # (int) The evaluation batch size.
metric_decimal_place: 4 # (int) The decimal place of metric scores.