Skip to content

Commit

Permalink
update CL4SRec
Browse files Browse the repository at this point in the history
  • Loading branch information
Fotiligner committed Dec 8, 2023
1 parent 4121d5c commit 0aaaaaf
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 0 deletions.
84 changes: 84 additions & 0 deletions recbole/data/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def construct_transform(config):
"crop_itemseq": CropItemSequence,
"reorder_itemseq": ReorderItemSequence,
"user_defined": UserDefinedTransform,
"random_itemseq": RandomAugmentationSequence
}
if config["transform"] not in str2transform:
raise NotImplementedError(
Expand Down Expand Up @@ -221,6 +222,89 @@ def __call__(self, dataset, interaction):
interaction.update(Interaction(new_dict))
return interaction

class RandomAugmentationSequence:
    """Create two randomly augmented views of every item sequence in a batch.

    For each sequence longer than one item, two *different* operators are
    drawn from {crop, mask, reorder} and applied independently to the original
    sequence.  Sequences of length <= 1 are passed through unchanged for both
    views.  The results are stored in the interaction under the keys
    ``aug1`` / ``aug1_len`` and ``aug2`` / ``aug2_len``.

    Masked positions are filled with the id ``n_items`` (one past the largest
    item id), so the consuming model must allocate ``n_items + 1`` embedding
    rows.
    """

    def __init__(self, config):
        """Read the field names used by the transform from ``config``.

        Args:
            config: Mapping-like configuration providing ``ITEM_ID_FIELD``,
                ``LIST_SUFFIX`` and ``ITEM_LIST_LENGTH_FIELD``.  The key
                ``RANDOM_ITEM_SEQ`` is written back into it.
        """
        self.ITEM_SEQ = config["ITEM_ID_FIELD"] + config["LIST_SUFFIX"]
        self.RANDOM_ITEM_SEQ = "Random_" + self.ITEM_SEQ
        self.ITEM_SEQ_LEN = config["ITEM_LIST_LENGTH_FIELD"]
        self.ITEM_ID = config["ITEM_ID_FIELD"]
        config["RANDOM_ITEM_SEQ"] = self.RANDOM_ITEM_SEQ

    def __call__(self, dataset, interaction):
        """Attach two augmented views of the item sequences to ``interaction``.

        Args:
            dataset: Object whose ``num(field)`` returns the number of ids of
                the item field.
            interaction: Batch holding the item sequences and their lengths.

        Returns:
            The same ``interaction`` object, updated with ``aug1``,
            ``aug1_len``, ``aug2`` and ``aug2_len``.
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        n_items = dataset.num(self.ITEM_ID)

        view_seqs = ([], [])
        view_lens = ([], [])
        for seq, length in zip(item_seq, item_seq_len):
            if length > 1:
                # Two distinct operators, one per view.
                choices = random.sample(range(3), k=2)
            else:
                # Nothing meaningful to augment: keep the sequence as is.
                choices = (None, None)
            for view, choice in enumerate(choices):
                aug_seq, aug_len = self._augment(choice, seq, length, n_items)
                view_seqs[view].append(aug_seq)
                view_lens[view].append(aug_len)

        new_dict = {
            "aug1": torch.stack(view_seqs[0]),
            "aug1_len": torch.stack(view_lens[0]),
            "aug2": torch.stack(view_seqs[1]),
            "aug2_len": torch.stack(view_lens[1]),
        }
        interaction.update(Interaction(new_dict))
        return interaction

    def _augment(self, choice, seq, length, n_items):
        """Apply the operator selected by ``choice`` (0 crop, 1 mask, 2 reorder)."""
        if choice == 0:
            return self.item_crop(seq, length)
        if choice == 1:
            return self.item_mask(seq, n_items, length)
        if choice == 2:
            return self.item_reorder(seq, length)
        return seq, length

    def item_crop(self, item_seq, item_seq_len, eta=0.6):
        """Keep a random contiguous sub-sequence of ``floor(len * eta)`` items.

        Args:
            item_seq: 1-D tensor of item ids, zero-padded on the right.
            item_seq_len: Number of valid items in ``item_seq``.
            eta: Fraction of the valid items to keep.

        Returns:
            Tuple ``(cropped_seq, cropped_len)``: a new long tensor with the
            same shape as ``item_seq`` holding the kept items left-aligned and
            zero-padded, and a 0-d long tensor with the kept length.
        """
        seq_len = int(item_seq_len)
        num_left = math.floor(seq_len * eta)
        # crop_begin + num_left <= seq_len, so the slice never leaves the
        # valid part of the sequence.
        crop_begin = random.randint(0, seq_len - num_left)
        cropped_item_seq = torch.zeros(
            item_seq.shape[0], dtype=torch.long, device=item_seq.device
        )
        cropped_item_seq[:num_left] = item_seq.detach()[
            crop_begin:crop_begin + num_left
        ]
        cropped_len = torch.tensor(
            num_left, dtype=torch.long, device=item_seq.device
        )
        return cropped_item_seq, cropped_len

    def item_mask(self, item_seq, n_items, item_seq_len, gamma=0.3):
        """Replace ``floor(len * gamma)`` random valid positions by a mask id.

        Args:
            item_seq: 1-D tensor of item ids, zero-padded on the right.
            n_items: Number of item ids (padding included); used as mask id.
            item_seq_len: Number of valid items in ``item_seq``.
            gamma: Fraction of the valid items to mask.

        Returns:
            Tuple ``(masked_seq, item_seq_len)``: a new long tensor and the
            unchanged length that was passed in.
        """
        seq_len = int(item_seq_len)
        num_mask = math.floor(seq_len * gamma)
        mask_index = random.sample(range(seq_len), k=num_mask)
        masked_item_seq = item_seq.detach().clone().to(torch.long)
        if mask_index:
            # Id 0 is padding and ids up to n_items - 1 are real items, so the
            # dedicated mask id is n_items.
            masked_item_seq[mask_index] = n_items
        return masked_item_seq, item_seq_len

    def item_reorder(self, item_seq, item_seq_len, beta=0.6):
        """Shuffle a random contiguous window of ``floor(len * beta)`` items.

        Args:
            item_seq: 1-D tensor of item ids, zero-padded on the right.
            item_seq_len: Number of valid items in ``item_seq``.
            beta: Fraction of the valid items covered by the shuffled window.

        Returns:
            Tuple ``(reordered_seq, item_seq_len)``: a new long tensor and the
            unchanged length that was passed in.
        """
        seq_len = int(item_seq_len)
        num_reorder = math.floor(seq_len * beta)
        reorder_begin = random.randint(0, seq_len - num_reorder)
        shuffle_index = list(range(reorder_begin, reorder_begin + num_reorder))
        random.shuffle(shuffle_index)
        reordered_item_seq = item_seq.detach().clone().to(torch.long)
        if shuffle_index:
            # Read from the untouched input so the permutation is applied to
            # the original order.
            reordered_item_seq[reorder_begin:reorder_begin + num_reorder] = (
                item_seq.detach()[shuffle_index]
            )
        return reordered_item_seq, item_seq_len


class CropItemSequence:
"""
Expand Down
1 change: 1 addition & 0 deletions recbole/model/sequential_recommender/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from recbole.model.sequential_recommender.bert4rec import BERT4Rec
from recbole.model.sequential_recommender.caser import Caser
from recbole.model.sequential_recommender.core import CORE
from recbole.model.sequential_recommender.cl4srec import CL4SRec
from recbole.model.sequential_recommender.dien import DIEN
from recbole.model.sequential_recommender.din import DIN
from recbole.model.sequential_recommender.fdsa import FDSA
Expand Down
229 changes: 229 additions & 0 deletions recbole/model/sequential_recommender/cl4srec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# -*- coding: utf-8 -*-
# @Time : 2023/11/30
# @Author : Bingqian Li
# @Email : [email protected]

import math
import random
import numpy as np
import torch
from torch import nn
from recbole.model.abstract_recommender import SequentialRecommender
from recbole.model.layers import TransformerEncoder
from recbole.model.loss import BPRLoss


class CL4SRec(SequentialRecommender):
    """Transformer sequential recommender trained with a contrastive auxiliary loss.

    The model encodes an item sequence with a left-to-right Transformer and is
    trained on the sum of a recommendation loss (BPR or CE) and an InfoNCE
    loss computed between two augmented views of every sequence.  The views
    are read from the interaction under ``aug1`` / ``aug1_len`` and
    ``aug2`` / ``aug2_len``.

    The item embedding table has ``n_items + 1`` rows; the extra last row is
    reserved for the augmentation mask id and is excluded when scoring items.
    """

    def __init__(self, config, dataset):
        """Build the layers and losses from ``config``.

        Args:
            config: Configuration object; the keys read are listed below.
            dataset: Dataset handed to the ``SequentialRecommender`` base
                class.

        Raises:
            NotImplementedError: If ``loss_type`` is neither ``'BPR'`` nor
                ``'CE'``.
        """
        super().__init__(config, dataset)

        # load parameters info
        self.n_layers = config['n_layers']
        self.n_heads = config['n_heads']
        self.hidden_size = config['hidden_size']
        self.inner_size = config['inner_size']
        self.hidden_dropout_prob = config['hidden_dropout_prob']
        self.attn_dropout_prob = config['attn_dropout_prob']
        self.hidden_act = config['hidden_act']
        self.layer_norm_eps = config['layer_norm_eps']

        self.batch_size = config['train_batch_size']
        self.lmd = config['lmd']
        self.tau = config['tau']
        # Fall back to the dot product when the option is not configured.
        similarity = config['sim']
        self.sim = 'dot' if similarity is None else similarity

        self.initializer_range = config['initializer_range']
        self.loss_type = config['loss_type']

        # define layers and loss
        # One extra embedding row (index n_items) for the augmentation mask.
        self.item_embedding = nn.Embedding(
            self.n_items + 1, self.hidden_size, padding_idx=0
        )
        self.position_embedding = nn.Embedding(
            self.max_seq_length, self.hidden_size
        )
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps,
        )

        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        if self.loss_type == 'BPR':
            self.loss_fct = BPRLoss()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

        # Negative-pair mask for a full training batch, built once and reused.
        self.mask_default = self.mask_correlated_samples(batch_size=self.batch_size)
        self.nce_fct = nn.CrossEntropyLoss()

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize the weights of ``module`` (used with ``self.apply``)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def get_attention_mask(self, item_seq):
        """Generate left-to-right uni-directional attention mask for multi-head attention.

        Args:
            item_seq: Tensor of item ids of shape ``[B, L]``; id 0 is padding.

        Returns:
            Additive mask of shape ``[B, 1, L, L]`` holding ``0.0`` where
            attention is allowed and ``-10000.0`` where it is blocked.
        """
        attention_mask = (item_seq > 0).long()
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.int64
        # mask for left-to-right unidirectional
        max_len = attention_mask.size(-1)
        attn_shape = (1, max_len, max_len)
        # Built directly on the input's device to avoid a host-to-device copy.
        subsequent_mask = torch.triu(
            torch.ones(attn_shape, device=item_seq.device), diagonal=1
        )
        subsequent_mask = (subsequent_mask == 0).unsqueeze(1).long()

        extended_attention_mask = extended_attention_mask * subsequent_mask
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype
        )  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(self, item_seq, item_seq_len):
        """Encode item sequences into one representation per sequence.

        Args:
            item_seq: Tensor of item ids of shape ``[B, L]``.
            item_seq_len: Tensor with the number of valid items per sequence.

        Returns:
            Tensor of shape ``[B, H]``: the last-layer output gathered at
            position ``item_seq_len - 1`` of each sequence.
        """
        position_ids = torch.arange(
            item_seq.size(1), dtype=torch.long, device=item_seq.device
        )
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)

        item_emb = self.item_embedding(item_seq)
        input_emb = item_emb + position_embedding
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)

        extended_attention_mask = self.get_attention_mask(item_seq)

        trm_output = self.trm_encoder(
            input_emb, extended_attention_mask, output_all_encoded_layers=True
        )
        output = trm_output[-1]
        output = self.gather_indexes(output, item_seq_len - 1)
        return output  # [B H]

    def calculate_loss(self, interaction):
        """Compute the training loss and two monitoring statistics.

        Args:
            interaction: Batch providing the item sequence, its length, the
                positive (and, for BPR, negative) items and the augmented
                views ``aug1`` / ``aug1_len`` / ``aug2`` / ``aug2_len``.

        Returns:
            Tuple ``(loss, alignment, uniformity)`` where ``loss`` is the
            recommendation loss plus ``lmd`` times the InfoNCE loss, and
            ``alignment`` / ``uniformity`` are computed without gradients.
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            neg_items = interaction[self.NEG_ITEM_ID]
            pos_items_emb = self.item_embedding(pos_items)
            neg_items_emb = self.item_embedding(neg_items)
            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
            loss = self.loss_fct(pos_score, neg_score)
        else:  # self.loss_type = 'CE'
            test_item_emb = self.item_embedding.weight[:self.n_items]  # unpad the augmentation mask
            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
            loss = self.loss_fct(logits, pos_items)

        # Contrastive (InfoNCE) loss between the two augmented views.
        batch_size = item_seq_len.shape[0]
        seq_output1 = self.forward(interaction["aug1"], interaction["aug1_len"])
        seq_output2 = self.forward(interaction["aug2"], interaction["aug2_len"])

        nce_logits, nce_labels = self.info_nce(
            seq_output1,
            seq_output2,
            temp=self.tau,
            batch_size=batch_size,
            sim=self.sim,
        )
        nce_loss = self.nce_fct(nce_logits, nce_labels)

        with torch.no_grad():
            alignment, uniformity = self.decompose(
                seq_output1, seq_output2, seq_output, batch_size=batch_size
            )

        return loss + self.lmd * nce_loss, alignment, uniformity

    def decompose(self, z_i, z_j, origin_z, batch_size):
        """Compute alignment and uniformity statistics of the representations.

        Args:
            z_i: Representations of the first augmented view, ``[B, H]``.
            z_j: Representations of the second augmented view, ``[B, H]``.
            origin_z: Representations of the original sequences, ``[B, H]``.
            batch_size: Number of sequences ``B`` in the batch.

        Returns:
            Tuple ``(alignment, uniformity)``: the mean L2 distance between
            the two views of the same sequence, and
            ``log(mean(exp(-2 * d)))`` over the L2 distances ``d`` between
            all pairs of distinct original representations.
        """
        N = 2 * batch_size

        z = torch.cat((z_i, z_j), dim=0)

        # pairwise l2 distance between all augmented representations
        dist = torch.cdist(z, z, p=2)

        # Offsets +/- batch_size select the (view1, view2) pairs of a sequence.
        dist_i_j = torch.diag(dist, batch_size)
        dist_j_i = torch.diag(dist, -batch_size)

        positive_samples = torch.cat((dist_i_j, dist_j_i), dim=0).reshape(N, 1)
        alignment = positive_samples.mean()

        # pairwise l2 distance between the original representations
        origin_dist = torch.cdist(origin_z, origin_z, p=2)
        mask = torch.ones(
            (batch_size, batch_size), dtype=torch.bool, device=origin_dist.device
        )
        mask.fill_diagonal_(False)
        negative_samples = origin_dist[mask].reshape(batch_size, -1)
        uniformity = torch.log(torch.exp(-2 * negative_samples).mean())

        return alignment, uniformity

    def mask_correlated_samples(self, batch_size):
        """Build the boolean mask that selects negative pairs.

        Args:
            batch_size: Number of sequences ``B`` in the batch.

        Returns:
            Bool tensor of shape ``[2B, 2B]`` that is ``False`` on the
            diagonal and on the two positive-pair off-diagonals
            (``[i, B + i]`` and ``[B + i, i]``) and ``True`` elsewhere.
        """
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=torch.bool)
        mask.fill_diagonal_(False)
        first_view = torch.arange(batch_size)
        second_view = first_view + batch_size
        mask[first_view, second_view] = False
        mask[second_view, first_view] = False
        return mask

    def info_nce(self, z_i, z_j, temp, batch_size, sim='dot'):
        """Build InfoNCE logits and labels for two augmented views.

        We do not sample negative examples explicitly.
        Instead, given a positive pair, similar to (Chen et al., 2017), we treat
        the other 2(N - 1) augmented examples within a minibatch as negative examples.

        Args:
            z_i: Representations of the first view, ``[B, H]``.
            z_j: Representations of the second view, ``[B, H]``.
            temp: Temperature the similarities are divided by.
            batch_size: Number of sequences ``B`` in the batch.
            sim: Similarity function, ``'dot'`` or ``'cos'``.

        Returns:
            Tuple ``(logits, labels)``: ``logits`` has shape ``[2B, 2B - 1]``
            with the positive pair in column 0, and ``labels`` is a long
            tensor of ``2B`` zeros.

        Raises:
            ValueError: If ``sim`` is neither ``'dot'`` nor ``'cos'``.
        """
        N = 2 * batch_size

        z = torch.cat((z_i, z_j), dim=0)

        if sim == 'cos':
            sim_matrix = nn.functional.cosine_similarity(
                z.unsqueeze(1), z.unsqueeze(0), dim=2
            ) / temp
        elif sim == 'dot':
            sim_matrix = torch.mm(z, z.T) / temp
        else:
            raise ValueError(f"Make sure 'sim' in ['dot', 'cos'], got {sim!r}!")

        sim_i_j = torch.diag(sim_matrix, batch_size)
        sim_j_i = torch.diag(sim_matrix, -batch_size)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        # The cached mask only fits a full training batch (e.g. not the last,
        # smaller batch of an epoch).
        if batch_size != self.batch_size:
            mask = self.mask_correlated_samples(batch_size)
        else:
            mask = self.mask_default
        negative_samples = sim_matrix[mask.to(sim_matrix.device)].reshape(N, -1)

        labels = torch.zeros(N, dtype=torch.long, device=positive_samples.device)
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        return logits, labels

    def predict(self, interaction):
        """Score one candidate item per sequence.

        Returns:
            Tensor of shape ``[B]`` with the dot product between each sequence
            representation and the embedding of its candidate item.
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        test_item = interaction[self.ITEM_ID]
        seq_output = self.forward(item_seq, item_seq_len)
        test_item_emb = self.item_embedding(test_item)
        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
        return scores

    def full_sort_predict(self, interaction):
        """Score every item for each sequence.

        Returns:
            Tensor of shape ``[B, n_items]``; the extra mask-id embedding row
            is excluded.
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        test_items_emb = self.item_embedding.weight[:self.n_items]  # unpad the augmentation mask
        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B n_items]
        return scores
13 changes: 13 additions & 0 deletions recbole/properties/model/CL4SRec.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
n_layers: 2 # (int) The number of transformer layers in transformer encoder.
n_heads: 2 # (int) The number of attention heads for multi-head attention layer.
hidden_size: 64 # (int) The number of features in the hidden state.
inner_size: 256 # (int) The inner hidden size in feed-forward layer.
hidden_dropout_prob: 0.5 # (float) The probability of an element to be zeroed.
attn_dropout_prob: 0.5 # (float) The probability of an attention score to be zeroed.
hidden_act: 'gelu' # (str) The activation function in feed-forward layer.
layer_norm_eps: 1e-12 # (float) A value added to the denominator for numerical stability.
initializer_range: 0.02 # (float) The standard deviation for normal initialization.
loss_type: 'BPR' # (str) The type of loss function. Range in ['BPR', 'CE'].
transform: 'random_itemseq' # (str) The type of item sequence transformation.
lmd: 0.01 # (float) The weight of the contrastive loss.
tau: 5 # (float) The temperature of the contrastive loss.

0 comments on commit 0aaaaaf

Please sign in to comment.