How to save train, test data and model predictions? #762
Answered by chenyushuo
johnny12150 asked this question in Q&A
Is there a quick way to save the train and test data after splitting them?
Answered by chenyushuo on Mar 9, 2021
Replies: 1 comment, 6 replies
For the first question, we suggest using `pickle` to dump the split data, like this:

```python
import pickle

# train_data and test_data are the dataloaders produced by the split
with open('split_data.pth', 'wb') as f:
    pickle.dump((train_data, test_data), f)
```

We also recently added a save function for split data in #760. In later versions, you can use `save_split_dataloaders` and `load_split_dataloaders` to save and load the split data.

For the second question, see the following code (mainly based on #506):

```python
import numpy as np
import torch
from recbole.data.dataloader.general_dataloader import GeneralFullDataLoader
from recbole.data.dataloader.sequential_dataloader import SequentialFullDataLoader

# Internal ids of the users to score
uid_series = np.array([1, 2])

# We assume you have already loaded test_data and model
uid_field = test_data.dataset.uid_field
dataset = test_data.dataset
model.eval()

# Build the input interaction for the selected users
if isinstance(test_data, GeneralFullDataLoader):
    index = np.isin(test_data.user_df[uid_field].numpy(), uid_series)
    input_interaction = test_data.user_df[index]
elif isinstance(test_data, SequentialFullDataLoader):
    index = np.isin(test_data.uid_list, uid_series)
    input_interaction = test_data.augmentation(
        test_data.item_list_index[index],
        test_data.target_index[index],
        test_data.item_list_length[index],
    )
else:
    raise NotImplementedError

# Get scores of all items
try:
    scores = model.full_sort_predict(input_interaction)
except NotImplementedError:
    # Fall back to scoring every (user, item) pair explicitly
    input_interaction = input_interaction.repeat(dataset.item_num)
    input_interaction.update(test_data.get_item_feature().repeat(len(uid_series)))
    scores = model.predict(input_interaction)

# Scores of all items, one row per selected user
scores = scores.view(-1, dataset.item_num)

# Get ground truth interactions
index = np.isin(test_data.dataset.inter_feat[uid_field].numpy(), uid_series)
real_inter = test_data.dataset.inter_feat[index]  # the ground truth interactions
```

Answer selected by johnny12150
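The `pickle` approach in the answer shows only the dump step. As a minimal sketch of the full round trip, the snippet below saves the splits together with a score table and loads them back. The helper names `save_objects` and `load_objects`, the file name, and the plain Python lists standing in for the RecBole dataloaders and the `scores` tensor are all assumptions made for illustration; they are not part of RecBole.

```python
import os
import pickle
import tempfile


def save_objects(path, **objects):
    """Pickle the given named objects into a single file (hypothetical helper)."""
    with open(path, 'wb') as f:
        pickle.dump(objects, f)


def load_objects(path):
    """Load the dictionary of objects written by save_objects (hypothetical helper)."""
    with open(path, 'rb') as f:
        return pickle.load(f)


# Stand-ins for the real objects: in practice these would be the
# train/test dataloaders and the score tensor computed by the model.
train_data = [(1, 10), (1, 11), (2, 12)]
test_data = [(1, 13), (2, 14)]
scores = [[0.1, 0.9, 0.3], [0.7, 0.2, 0.5]]

# Hypothetical file name; any writable path works.
path = os.path.join(tempfile.gettempdir(), 'split_data_and_scores.pth')

save_objects(path, train_data=train_data, test_data=test_data, scores=scores)
restored = load_objects(path)
```

Two caveats apply when swapping in the real objects. Unpickling needs the same class definitions to be importable, so RecBole must be installed in the environment that loads the file. A score tensor that lives on a GPU is usually moved to the CPU (for example with `scores.cpu()`) before dumping, so that the file can be read on a machine without that device.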