Commit: sasrec-plus
Fotiligner committed Mar 10, 2024
1 parent 5f3cb96 commit 2c748e9
Showing 6 changed files with 199 additions and 9 deletions.
Binary file added data.pkl
Binary file added data_reasoning.pkl
187 changes: 184 additions & 3 deletions recbole/model/sequential_recommender/sasrec.py
@@ -22,6 +22,9 @@
from recbole.model.layers import TransformerEncoder
from recbole.model.loss import BPRLoss

import pickle

init = nn.init.xavier_uniform_

class SASRec(SequentialRecommender):
r"""
@@ -50,7 +53,7 @@ def __init__(self, config, dataset):

self.initializer_range = config["initializer_range"]
self.loss_type = config["loss_type"]

# define layers and loss
self.item_embedding = nn.Embedding(
self.n_items, self.hidden_size, padding_idx=0
@@ -70,6 +73,24 @@ def __init__(self, config, dataset):
self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.dropout = nn.Dropout(self.hidden_dropout_prob)

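# knowledge-distillation setup: kd_weight scales the added InfoNCE term and is reused below as its temperature;
# data.pkl ships pre-computed item representations that the MLP projects into the model's hidden space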
self.kd_weight = 0.5

with open('./data.pkl', 'rb') as file:
self.origin_embeds = pickle.load(file)

self.itmprf_embeds = torch.tensor(self.origin_embeds).float().cuda()
self.mlp = nn.Sequential(
nn.Linear(self.itmprf_embeds.shape[1], (self.itmprf_embeds.shape[1] + self.hidden_size) // 2),
nn.LeakyReLU(),
nn.Linear((self.itmprf_embeds.shape[1] + self.hidden_size) // 2, self.hidden_size)
)

# self.mlp2 = nn.Sequential(
# nn.Linear(self.hidden_size * 2, (self.hidden_size * 2 + self.hidden_size) // 2),
# nn.LeakyReLU(),
# nn.Linear((self.hidden_size * 2 + self.hidden_size) // 2, self.hidden_size)
# )

if self.loss_type == "BPR":
self.loss_fct = BPRLoss()
elif self.loss_type == "CE":
@@ -92,6 +113,14 @@ def _init_weights(self, module):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()

for m in self.mlp:
if isinstance(m, nn.Linear):
init(m.weight)

# for m in self.mlp2:
# if isinstance(m, nn.Linear):
# init(m.weight)

def forward(self, item_seq, item_seq_len):
position_ids = torch.arange(
item_seq.size(1), dtype=torch.long, device=item_seq.device
@@ -116,7 +145,7 @@ def forward(self, item_seq, item_seq_len):
def calculate_loss(self, interaction):
item_seq = interaction[self.ITEM_SEQ]
item_seq_len = interaction[self.ITEM_SEQ_LEN]
seq_output = self.forward(item_seq, item_seq_len)
seq_output_id = self.forward(item_seq, item_seq_len)
pos_items = interaction[self.POS_ITEM_ID]
if self.loss_type == "BPR":
neg_items = interaction[self.NEG_ITEM_ID]
@@ -128,8 +157,17 @@
return loss
else: # self.loss_type = 'CE'
test_item_emb = self.item_embedding.weight
logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
logits = torch.matmul(seq_output_id, test_item_emb.transpose(0, 1))
loss = self.loss_fct(logits, pos_items)

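# knowledge distillation: project the profile embeddings, gather the row of each positive item
# (item ids are 1-indexed, 0 is padding), and pull the sequence output toward it with an InfoNCE term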
itmprf_emb = self.mlp(self.itmprf_embeds)
index_items = pos_items - 1
posprf_embeds = itmprf_emb[index_items]
kd_loss = self.cal_infonce_loss(embeds1=seq_output_id, embeds2=posprf_embeds, all_embeds2=itmprf_emb, temp=self.kd_weight)
kd_loss /= seq_output_id.shape[0]
kd_loss *= self.kd_weight
loss += kd_loss

return loss

def predict(self, interaction):
@@ -148,3 +186,146 @@ def full_sort_predict(self, interaction):
test_items_emb = self.item_embedding.weight
scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B n_items]
return scores

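# InfoNCE between sequence outputs (embeds1) and their positive profile embeddings (embeds2),
# with every projected profile embedding (all_embeds2) as a candidate; returns the sum over the batch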
def cal_infonce_loss(self, embeds1, embeds2, all_embeds2, temp):
normed_embeds1 = embeds1 / torch.sqrt(1e-8 + embeds1.square().sum(-1, keepdim=True))
normed_embeds2 = embeds2 / torch.sqrt(1e-8 + embeds2.square().sum(-1, keepdim=True))
normed_all_embeds2 = all_embeds2 / torch.sqrt(1e-8 + all_embeds2.square().sum(-1, keepdim=True))
nume_term = -(normed_embeds1 * normed_embeds2 / temp).sum(-1)
deno_term = torch.log(torch.sum(torch.exp(normed_embeds1 @ normed_all_embeds2.T / temp), dim=-1))
cl_loss = (nume_term + deno_term).sum()
return cl_loss









# class SASRec(SequentialRecommender):
# r"""
# SASRec is the first sequential recommender based on self-attentive mechanism.

# NOTE:
# In the author's implementation, the Point-Wise Feed-Forward Network (PFFN) is implemented
# by CNN with 1x1 kernel. In this implementation, we follow the original BERT implementation
# using Fully Connected Layer to implement the PFFN.
# """

# def __init__(self, config, dataset):
# super(SASRec, self).__init__(config, dataset)

# # load parameters info
# self.n_layers = config["n_layers"]
# self.n_heads = config["n_heads"]
# self.hidden_size = config["hidden_size"] # same as embedding_size
# self.inner_size = config[
# "inner_size"
# ] # the dimensionality in feed-forward layer
# self.hidden_dropout_prob = config["hidden_dropout_prob"]
# self.attn_dropout_prob = config["attn_dropout_prob"]
# self.hidden_act = config["hidden_act"]
# self.layer_norm_eps = config["layer_norm_eps"]

# self.initializer_range = config["initializer_range"]
# self.loss_type = config["loss_type"]

# # define layers and loss
# self.item_embedding = nn.Embedding(
# self.n_items, self.hidden_size, padding_idx=0
# )
# self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size)
# self.trm_encoder = TransformerEncoder(
# n_layers=self.n_layers,
# n_heads=self.n_heads,
# hidden_size=self.hidden_size,
# inner_size=self.inner_size,
# hidden_dropout_prob=self.hidden_dropout_prob,
# attn_dropout_prob=self.attn_dropout_prob,
# hidden_act=self.hidden_act,
# layer_norm_eps=self.layer_norm_eps,
# )

# self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
# self.dropout = nn.Dropout(self.hidden_dropout_prob)

# if self.loss_type == "BPR":
# self.loss_fct = BPRLoss()
# elif self.loss_type == "CE":
# self.loss_fct = nn.CrossEntropyLoss()
# else:
# raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

# # parameters initialization
# self.apply(self._init_weights)

# def _init_weights(self, module):
# """Initialize the weights"""
# if isinstance(module, (nn.Linear, nn.Embedding)):
# # Slightly different from the TF version which uses truncated_normal for initialization
# # cf https://github.com/pytorch/pytorch/pull/5617
# module.weight.data.normal_(mean=0.0, std=self.initializer_range)
# elif isinstance(module, nn.LayerNorm):
# module.bias.data.zero_()
# module.weight.data.fill_(1.0)
# if isinstance(module, nn.Linear) and module.bias is not None:
# module.bias.data.zero_()

# def forward(self, item_seq, item_seq_len):
# position_ids = torch.arange(
# item_seq.size(1), dtype=torch.long, device=item_seq.device
# )
# position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
# position_embedding = self.position_embedding(position_ids)

# item_emb = self.item_embedding(item_seq)
# input_emb = item_emb + position_embedding
# input_emb = self.LayerNorm(input_emb)
# input_emb = self.dropout(input_emb)

# extended_attention_mask = self.get_attention_mask(item_seq)

# trm_output = self.trm_encoder(
# input_emb, extended_attention_mask, output_all_encoded_layers=True
# )
# output = trm_output[-1]
# output = self.gather_indexes(output, item_seq_len - 1)
# return output # [B H]

# def calculate_loss(self, interaction):
# item_seq = interaction[self.ITEM_SEQ]
# item_seq_len = interaction[self.ITEM_SEQ_LEN]
# seq_output = self.forward(item_seq, item_seq_len)
# pos_items = interaction[self.POS_ITEM_ID]
# if self.loss_type == "BPR":
# neg_items = interaction[self.NEG_ITEM_ID]
# pos_items_emb = self.item_embedding(pos_items)
# neg_items_emb = self.item_embedding(neg_items)
# pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B]
# neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B]
# loss = self.loss_fct(pos_score, neg_score)
# return loss
# else: # self.loss_type = 'CE'
# test_item_emb = self.item_embedding.weight
# logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
# loss = self.loss_fct(logits, pos_items)
# return loss

# def predict(self, interaction):
# item_seq = interaction[self.ITEM_SEQ]
# item_seq_len = interaction[self.ITEM_SEQ_LEN]
# test_item = interaction[self.ITEM_ID]
# seq_output = self.forward(item_seq, item_seq_len)
# test_item_emb = self.item_embedding(test_item)
# scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B]
# return scores

# def full_sort_predict(self, interaction):
# item_seq = interaction[self.ITEM_SEQ]
# item_seq_len = interaction[self.ITEM_SEQ_LEN]
# seq_output = self.forward(item_seq, item_seq_len)
# test_items_emb = self.item_embedding.weight
# scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1)) # [B n_items]
# return scores
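
Taken on its own, the added cal_infonce_loss is a standard InfoNCE objective: each embedding is L2-normalized, the positive pair contributes its negative scaled cosine similarity, and the log-partition term sums scaled similarities over every profile embedding. The following standalone sketch is not part of the commit; the toy tensor shapes and the use of torch.nn.functional.normalize and torch.logsumexp are illustrative assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

def infonce(queries, positives, candidates, temp):
    # L2-normalize everything, matching the manual normalization in cal_infonce_loss
    q = F.normalize(queries, dim=-1)
    p = F.normalize(positives, dim=-1)
    c = F.normalize(candidates, dim=-1)
    nume = -(q * p / temp).sum(-1)                   # -sim(query, positive) / temp
    deno = torch.logsumexp(q @ c.T / temp, dim=-1)   # log-sum-exp over all candidates
    return (nume + deno).sum()

# toy setup: 4 sequence outputs of size 64, 100 items with 1536-d pre-computed profiles
seq_output = torch.randn(4, 64)
profiles = torch.randn(100, 1536)
pos_items = torch.tensor([3, 7, 42, 99])             # 1-indexed item ids (0 is the padding id)
proj = nn.Sequential(
    nn.Linear(1536, (1536 + 64) // 2), nn.LeakyReLU(), nn.Linear((1536 + 64) // 2, 64)
)

itmprf = proj(profiles)                              # project profiles into the hidden space
kd_loss = infonce(seq_output, itmprf[pos_items - 1], itmprf, temp=0.5)
kd_loss = kd_loss / seq_output.shape[0] * 0.5        # batch-average, then scale by kd_weight

Compared with the commit's implementation, torch.logsumexp replaces the explicit log(sum(exp(...))) but computes the same quantity.
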
12 changes: 6 additions & 6 deletions recbole/properties/overall.yaml
@@ -21,12 +21,12 @@ epochs: 300 # (int) The number of training epochs.
train_batch_size: 2048 # (int) The training batch size.
learner: adam # (str) The name of used optimizer.
learning_rate: 0.001 # (float) Learning rate.
train_neg_sample_args: # (dict) Negative sampling configuration for model training.
distribution: uniform # (str) The distribution of negative items.
sample_num: 1 # (int) The sampled num of negative items.
alpha: 1.0 # (float) The power of sampling probability for popularity distribution.
dynamic: False # (bool) Whether to use dynamic negative sampling.
candidate_num: 0 # (int) The number of candidate negative items when dynamic negative sampling.
# train_neg_sample_args: # (dict) Negative sampling configuration for model training.
# distribution: uniform # (str) The distribution of negative items.
# sample_num: 1 # (int) The sampled num of negative items.
# alpha: 1.0 # (float) The power of sampling probability for popularity distribution.
# dynamic: False # (bool) Whether to use dynamic negative sampling.
# candidate_num: 0 # (int) The number of candidate negative items when dynamic negative sampling.
eval_step: 1 # (int) The number of training epochs before an evaluation on the valid dataset.
stopping_step: 10 # (int) The threshold for validation-based early stopping.
clip_grad_norm: ~ # (dict) The args of clip_grad_norm_ which will clip gradient norm of model.
4 changes: 4 additions & 0 deletions run_sasrec.sh
@@ -0,0 +1,4 @@
python run_recbole.py \
--model SASRec \
--train_neg_sample_args None \
--loss_type CE
5 changes: 5 additions & 0 deletions run_sasrec_plus.sh
@@ -0,0 +1,5 @@
python run_recbole.py \
--model SASRec \
--train_neg_sample_args None \
--loss_type CE \
--kd_weight 0.5

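For reference, the two launch scripts differ only by the extra --kd_weight override; in this diff, however, SASRec.__init__ still hard-codes self.kd_weight = 0.5 rather than reading it from the config. Below is a hedged programmatic sketch of the same run, assuming RecBole's quick_start.run_recbole and its config_dict argument (dataset left at RecBole's default, as in the scripts):

from recbole.quick_start import run_recbole

run_recbole(
    model="SASRec",
    config_dict={
        "train_neg_sample_args": None,  # disable negative sampling; CE loss trains on the full softmax
        "loss_type": "CE",
        "kd_weight": 0.5,               # would take effect once __init__ reads it from config; the diff hard-codes 0.5
    },
)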