Skip to content

Commit

Permalink
add cl
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhjerry committed Sep 11, 2023
1 parent f6f1870 commit eea8a8b
Show file tree
Hide file tree
Showing 7 changed files with 347 additions and 0 deletions.
77 changes: 77 additions & 0 deletions easy_rec/python/input/augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
from easy_rec.python.utils.shape_utils import get_shape_list

# Under TensorFlow 2.x, rebind `tf` to `tf.compat.v1` so that the `tf.*`
# calls in the rest of this module resolve against the TF1-style API.
# NOTE(review): this compares version strings lexicographically, not
# numerically -- confirm that is acceptable for the supported TF versions.
if tf.__version__ >= '2.0':
    tf = tf.compat.v1


def assign(input_tensor, position=None, value=None):
    """Return a copy of ``input_tensor`` with the selected entries overwritten.

    Args:
        input_tensor: NumPy array to update. It is left untouched.
        position: sequence of index arrays, one per indexed axis. It is
            converted to a tuple and used as a NumPy index.
        value: value(s) written at ``position``; broadcast according to the
            usual NumPy assignment rules.

    Returns:
        A new array with the same shape and dtype as ``input_tensor``.
    """
    # Work on a copy: the caller's buffer must not be modified in place, and
    # it may not even be writeable (e.g. a read-only array handed over by the
    # runtime that invokes this function).
    updated = input_tensor.copy()
    updated[tuple(position)] = value
    return updated


def item_mask(aug_data, length, weights, mask_param):
    """Replace a random subset of the valid items with the mask vector.

    The masking is done with TensorFlow ops only (no ``tf.py_func``), so the
    result keeps the shape of ``aug_data`` and gradients can flow back into
    ``weights``.

    Args:
        aug_data: 2-D tensor holding one sequence; rows are items
            (assumed [max_seq_len, embed_size] -- TODO confirm with caller).
        length: scalar int32 tensor, number of valid (non-padded) items.
        weights: mask vector written into every masked row; it must be
            broadcastable against ``aug_data``.
        mask_param: fraction of the valid items to mask.

    Returns:
        A tuple ``(masked_sequence, length)``; ``length`` is passed through
        unchanged.
    """
    valid_len = tf.cast(length, dtype=tf.float32)
    num_mask = tf.cast(tf.math.floor(valid_len * mask_param), dtype=tf.int32)
    # Distinct random positions among the first `length` rows.
    mask_index = tf.random.shuffle(tf.range(length, dtype=tf.int32))[:num_mask]

    # hits[i] == 1 when row i was selected, 0 otherwise.
    hits = tf.scatter_nd(
        tf.expand_dims(mask_index, axis=1),
        tf.ones_like(mask_index),
        tf.shape(aug_data)[:1])
    # Expand the per-row flag to the full shape so that the select below does
    # not rely on any broadcasting rule of `tf.where`.
    is_masked = tf.tile(
        tf.expand_dims(hits > 0, axis=1), [1, tf.shape(aug_data)[1]])
    mask_rows = tf.zeros_like(aug_data) + tf.cast(weights, aug_data.dtype)
    masked_item_seq = tf.where(is_masked, mask_rows, aug_data)
    return masked_item_seq, length


def item_crop(aug_data, length, crop_param):
    """Keep a random contiguous window of the valid items, zero-pad the rest.

    Args:
        aug_data: 2-D tensor holding one sequence; rows are items
            (assumed [max_seq_len, embed_size] -- TODO confirm with caller).
        length: scalar int32 tensor, number of valid (non-padded) items.
        crop_param: fraction of the valid items that is kept.

    Returns:
        A tuple ``(cropped_sequence, num_left)`` where ``cropped_sequence``
        has as many rows as ``aug_data`` (the kept window first, zeros after)
        and ``num_left`` is the number of kept items.
    """
    valid_len = tf.cast(length, dtype=tf.float32)
    seq_shape = get_shape_list(aug_data)
    max_length = tf.cast(seq_shape[0], dtype=tf.int32)
    embedding_size = seq_shape[1]

    num_left = tf.cast(tf.math.floor(valid_len * crop_param), dtype=tf.int32)
    # An integer `tf.random.uniform` needs minval < maxval. Without the guard
    # an empty sequence (length == 0) or crop_param == 1 gives maxval == 0 and
    # the op fails at run time; with it the window simply starts at 0.
    num_starts = tf.maximum(length - num_left, 1)
    crop_begin = tf.random.uniform(
        [1], minval=0, maxval=num_starts, dtype=tf.int32)[0]

    # The window always ends within the sequence, so one concat of the window
    # and the zero padding restores the original number of rows.
    padding = tf.zeros(
        [max_length - num_left, embedding_size], dtype=aug_data.dtype)
    cropped_item_seq = tf.concat(
        [aug_data[crop_begin:crop_begin + num_left], padding], axis=0)
    return cropped_item_seq, num_left


def item_reorder(aug_data, length, reorder_param):
    """Randomly permute a contiguous window of the valid items.

    Args:
        aug_data: 2-D tensor holding one sequence; rows are items
            (assumed [max_seq_len, embed_size] -- TODO confirm with caller).
        length: scalar int32 tensor, number of valid (non-padded) items.
        reorder_param: fraction of the valid items inside the shuffled window.

    Returns:
        A tuple ``(reordered_sequence, length)``; ``length`` is passed through
        unchanged.
    """
    valid_len = tf.cast(length, dtype=tf.float32)
    num_reorder = tf.cast(
        tf.math.floor(valid_len * reorder_param), dtype=tf.int32)
    # An integer `tf.random.uniform` needs minval < maxval. Without the guard
    # an empty sequence (length == 0) or reorder_param == 1 gives maxval == 0
    # and the op fails at run time; with it the window simply starts at 0.
    num_starts = tf.maximum(length - num_reorder, 1)
    reorder_begin = tf.random.uniform(
        [1], minval=0, maxval=num_starts, dtype=tf.int32)[0]

    window = tf.random.shuffle(
        tf.range(reorder_begin, reorder_begin + num_reorder))
    positions = tf.range(get_shape_list(aug_data)[0])
    left = tf.slice(positions, [0], [reorder_begin])
    right = tf.slice(positions, [reorder_begin + num_reorder], [-1])
    # Row i of the input is written to row target_index[i] of the output;
    # rows outside the window keep their position.
    target_index = tf.concat([left, window, right], axis=0)
    reordered_item_seq = tf.scatter_nd(
        tf.expand_dims(target_index, axis=1), aug_data, tf.shape(aug_data))
    return reordered_item_seq, length


def augment(x, cl_param, weights):
    """Apply one randomly chosen augmentation (crop, mask or reorder).

    Args:
        x: pair ``(seq, length)`` describing a single sequence.
        cl_param: object exposing ``crop_param``, ``mask_param`` and
            ``reorder_param``.
        weights: mask vector forwarded to ``item_mask``.

    Returns:
        A list ``[aug_seq, aug_len]`` produced by the selected augmentation.
    """
    seq, length = x
    # First element of a shuffled [0, 1, 2]: 0 -> crop, 1 -> mask, 2 -> reorder.
    choice = tf.random.shuffle(tf.range(3, dtype=tf.int32))[:1][0]

    def _crop():
        return item_crop(seq, length, cl_param.crop_param)

    def _mask():
        return item_mask(seq, length, weights, cl_param.mask_param)

    def _reorder():
        return item_reorder(seq, length, cl_param.reorder_param)

    def _mask_or_reorder():
        return tf.cond(tf.equal(choice, 1), _mask, _reorder)

    aug_seq, aug_len = tf.cond(tf.equal(choice, 0), _crop, _mask_or_reorder)

    return [aug_seq, aug_len]


def input_aug_data(original_data, seq_len, weights, cl_param):
    """Build two independently augmented views of a batch of sequences.

    Args:
        original_data: batch of sequences; ``augment`` is mapped over its
            first axis.
        seq_len: per-sequence lengths, cast to int32 before use.
        weights: mask vector forwarded to ``augment``.
        cl_param: augmentation parameters forwarded to ``augment``.

    Returns:
        ``(aug_seq1, aug_seq2, aug_len1, aug_len2)``; both augmented batches
        are reshaped to the dynamic shape of ``original_data``.
    """
    lengths = tf.cast(seq_len, dtype=tf.int32)

    def _augment_batch():
        # One pass of per-sequence random augmentation over the whole batch.
        return tf.map_fn(
            lambda elems: augment(elems, cl_param, weights),
            elems=(original_data, lengths),
            dtype=[tf.float32, tf.int32])

    aug_seq1, aug_len1 = _augment_batch()
    aug_seq2, aug_len2 = _augment_batch()

    batch_shape = tf.shape(original_data)
    aug_seq1 = tf.reshape(aug_seq1, batch_shape)
    aug_seq2 = tf.reshape(aug_seq2, batch_shape)
    return aug_seq1, aug_seq2, aug_len1, aug_len2
3 changes: 3 additions & 0 deletions easy_rec/python/layers/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@
from .multi_task import MMoE
from .numerical_embedding import AutoDisEmbedding
from .numerical_embedding import PeriodicEmbedding
from .bst_cl4ctr import BSTCTR
from .loss_for_cl4ctr import LOSSCTR
from .bst_forseq import BSTSEQ
57 changes: 57 additions & 0 deletions easy_rec/python/layers/keras/bst_cl4ctr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
from tensorflow.python.keras.layers import Layer

from easy_rec.python.layers import multihead_cross_attention
from easy_rec.python.utils.activation import get_activation


class BSTCTR(Layer):
    """Transformer encoder over feature embeddings, used by the CL4CTR blocks.

    The encoder hyper-parameters are read from the protobuf config returned by
    ``params.get_pb_config()``.
    """

    def __init__(self, params, name='bst_cl4ctr', l2_reg=None, **kwargs):
        super(BSTCTR, self).__init__(name=name, **kwargs)
        self.l2_reg = l2_reg
        self.config = params.get_pb_config()

    def encode(self,
               fea_input,
               max_position,
               hidden_dropout_prob=None,
               attention_probs_dropout_prob=None):
        """Run the transformer encoder and flatten its output.

        Args:
            fea_input: input embeddings (per the ``call`` comment:
                [batch_size, num_features, embed_size]).
            max_position: value passed as ``max_position_embeddings``.
            hidden_dropout_prob: dropout rate of the hidden layers; defaults
                to the value stored in the config.
            attention_probs_dropout_prob: dropout rate of the attention
                probabilities; defaults to the value stored in the config.

        Returns:
            The encoder output reshaped to [-1, hidden_size].
        """
        import logging

        if hidden_dropout_prob is None:
            hidden_dropout_prob = self.config.hidden_dropout_prob
        if attention_probs_dropout_prob is None:
            attention_probs_dropout_prob = (
                self.config.attention_probs_dropout_prob)

        fea_input = multihead_cross_attention.embedding_postprocessor(
            fea_input,
            position_embedding_name=self.name + '/position_embeddings',
            max_position_embeddings=max_position,
            reuse_position_embedding=tf.AUTO_REUSE)

        # NOTE(review): the mask is derived from the tensor returned by
        # `embedding_postprocessor`, not from the raw input. If that helper
        # adds position embeddings to all-zero rows they no longer count as
        # empty here -- confirm this is intended.
        n = tf.count_nonzero(fea_input, axis=-1)
        seq_mask = tf.cast(n > 0, tf.int32)

        attention_mask = (
            multihead_cross_attention.create_attention_mask_from_input_mask(
                from_tensor=fea_input, to_mask=seq_mask))

        hidden_act = get_activation(self.config.hidden_act)
        attention_fea = multihead_cross_attention.transformer_encoder(
            fea_input,
            hidden_size=self.config.hidden_size,
            num_hidden_layers=self.config.num_hidden_layers,
            num_attention_heads=self.config.num_attention_heads,
            attention_mask=attention_mask,
            intermediate_size=self.config.intermediate_size,
            intermediate_act_fn=hidden_act,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            initializer_range=self.config.initializer_range,
            name=self.name + '/transformer',
            reuse=tf.AUTO_REUSE)
        # attention_fea shape: [batch_size, num_features, hidden_size]
        # out_fea shape: [batch_size * num_features, hidden_size]
        out_fea = tf.reshape(attention_fea, [-1, self.config.hidden_size])
        # Library code should not write to stdout; report through logging.
        logging.info('bst_cl4ctr output shape: %s', out_fea.shape)
        return out_fea

    def call(self, inputs, training=None, **kwargs):
        """Encode ``inputs``; dropout is only active when ``training`` is truthy.

        The dropout rates are chosen per call and handed to ``encode``. The
        config proto itself is never modified, so a call with
        ``training=False`` (or the default ``None``) cannot disable dropout
        for later training calls or for other users of the same config.
        """
        # inputs: [batch_size, num_features, embed_size]
        if training:
            hidden_dropout_prob = self.config.hidden_dropout_prob
            attention_probs_dropout_prob = (
                self.config.attention_probs_dropout_prob)
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
        max_position = self.config.max_position_embeddings

        return self.encode(
            inputs,
            max_position,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob)
102 changes: 102 additions & 0 deletions easy_rec/python/layers/keras/bst_forseq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf

from easy_rec.python.layers import multihead_cross_attention
from easy_rec.python.utils.activation import get_activation
from easy_rec.python.utils.shape_utils import get_shape_list
from easy_rec.python.loss.nce_loss import nce_loss
from easy_rec.python.input.augment import input_aug_data


# from tensorflow.python.keras.layers import Layer


class BSTSEQ(object):
    """BST-style sequence encoder with an optional contrastive-learning loss.

    The encoder hyper-parameters are read from the protobuf config returned by
    ``params.get_pb_config()``.
    """

    def __init__(self, params, name='bstseq', l2_reg=None, **kwargs):
        # super(BSTSEQ, self).__init__(name=name, **kwargs)
        self.name = name
        self.l2_reg = l2_reg
        self.config = params.get_pb_config()

    def encode(self,
               seq_input,
               max_position,
               hidden_dropout_prob=None,
               attention_probs_dropout_prob=None):
        """Run the transformer encoder and return the first position's output.

        Args:
            seq_input: sequence embeddings, [batch_size, seq_len, embed_size].
            max_position: value passed as ``max_position_embeddings``.
            hidden_dropout_prob: dropout rate of the hidden layers; defaults
                to the value stored in the config.
            attention_probs_dropout_prob: dropout rate of the attention
                probabilities; defaults to the value stored in the config.

        Returns:
            ``attention_fea[:, 0, :]``, the encoder output at position 0.
        """
        if hidden_dropout_prob is None:
            hidden_dropout_prob = self.config.hidden_dropout_prob
        if attention_probs_dropout_prob is None:
            attention_probs_dropout_prob = (
                self.config.attention_probs_dropout_prob)

        seq_fea = multihead_cross_attention.embedding_postprocessor(
            seq_input,
            position_embedding_name=self.name + '/position_embeddings',
            max_position_embeddings=max_position,
            reuse_position_embedding=tf.AUTO_REUSE)

        # NOTE(review): the mask is derived from the tensor returned by
        # `embedding_postprocessor`, not from `seq_input`. If that helper adds
        # position embeddings to padded rows they no longer count as empty
        # here -- confirm this is intended.
        n = tf.count_nonzero(seq_fea, axis=-1)
        seq_mask = tf.cast(n > 0, tf.int32)

        attention_mask = (
            multihead_cross_attention.create_attention_mask_from_input_mask(
                from_tensor=seq_fea, to_mask=seq_mask))

        hidden_act = get_activation(self.config.hidden_act)
        attention_fea = multihead_cross_attention.transformer_encoder(
            seq_fea,
            hidden_size=self.config.hidden_size,
            num_hidden_layers=self.config.num_hidden_layers,
            num_attention_heads=self.config.num_attention_heads,
            attention_mask=attention_mask,
            intermediate_size=self.config.intermediate_size,
            intermediate_act_fn=hidden_act,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            initializer_range=self.config.initializer_range,
            name=self.name + '/bstseq',
            reuse=tf.AUTO_REUSE)
        # attention_fea shape: [batch_size, seq_length, hidden_size]
        out_fea = attention_fea[:, 0, :]  # target feature
        return out_fea

    def __call__(self, inputs, training=None, **kwargs):
        """Encode the sequence features and optionally register the CL loss.

        Args:
            inputs: pair whose first element is a list of
                ``(sequence_embedding, sequence_length)`` tuples.
            training: dropout is only active when this is truthy.
            **kwargs: must contain ``loss_dict`` when contrastive learning is
                enabled in the config; the weighted loss is stored there.

        Returns:
            The output of ``encode`` for the concatenated sequence features.
        """
        seq_features, _ = inputs
        assert len(seq_features) > 0, '[%s] sequence feature is empty' % self.name
        # Choose the dropout rates for this call only. The config proto is not
        # modified, so a non-training call cannot switch dropout off for later
        # training calls or for other users of the same config object.
        if training:
            hidden_dropout_prob = self.config.hidden_dropout_prob
            attention_probs_dropout_prob = (
                self.config.attention_probs_dropout_prob)
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0

        seq_embeds = [seq_fea for seq_fea, _ in seq_features]

        max_position = self.config.max_position_embeddings
        # max_seq_len: the max sequence length in current mini-batch, all
        # sequences are padded to this length
        _, max_seq_len, _ = get_shape_list(seq_features[0][0], 3)
        valid_len = tf.assert_less_equal(
            max_seq_len,
            max_position,
            message='sequence length is greater than `max_position_embeddings`:' +
            str(max_position) + ' in feature group:' + self.name)
        with tf.control_dependencies([valid_len]):
            # seq_input: [batch_size, seq_len, embed_size]
            seq_input = tf.concat(seq_embeds, axis=-1)

        seq_embed_size = seq_input.shape.as_list()[-1]
        if seq_embed_size != self.config.hidden_size:
            # Project to the width expected by the encoder.
            seq_input = tf.layers.dense(
                seq_input,
                self.config.hidden_size,
                activation=tf.nn.relu,
                kernel_regularizer=self.l2_reg)

        seq_len = seq_features[0][1]
        if self.config.need_contrastive_learning:
            with tf.variable_scope('cl_mask', reuse=tf.AUTO_REUSE):
                weights = tf.get_variable(
                    'mask_tensor',
                    shape=[seq_input.shape.as_list()[-1]],
                    trainable=True,
                    initializer=tf.truncated_normal_initializer(stddev=0.02))
            # The encoder calls stay outside the `cl_mask` scope so that the
            # variables they create or reuse are named exactly like those of
            # the main `encode` call below.
            cl_loss = self.contrastive_loss(
                seq_input,
                seq_len,
                max_position,
                weights,
                self.config.cl_param,
                hidden_dropout_prob=hidden_dropout_prob,
                attention_probs_dropout_prob=attention_probs_dropout_prob)
            cl_loss *= self.config.contrastive_loss_weight
            assert 'loss_dict' in kwargs, "no `loss_dict` in kwargs of bst layer: %s" % self.name
            loss_dict = kwargs['loss_dict']
            loss_dict['%s_contrastive_loss' % self.name] = cl_loss
        return self.encode(
            seq_input,
            max_position,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob)

    def contrastive_loss(self,
                         seq_input,
                         seq_len,
                         max_position,
                         weights,
                         cl_param,
                         hidden_dropout_prob=None,
                         attention_probs_dropout_prob=None):
        """Return the NCE loss between two random augmentations of the batch.

        Both augmented views are encoded with ``encode`` and compared with
        ``nce_loss``. The dropout arguments are forwarded to ``encode`` and
        default to the config values.
        """
        aug_seq1, aug_seq2, aug_len1, aug_len2 = input_aug_data(
            seq_input, seq_len, weights, cl_param)
        seq_output1 = self.encode(
            aug_seq1,
            max_position,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob)
        seq_output2 = self.encode(
            aug_seq2,
            max_position,
            hidden_dropout_prob=hidden_dropout_prob,
            attention_probs_dropout_prob=attention_probs_dropout_prob)
        loss = nce_loss(seq_output1, seq_output2)
        return loss
53 changes: 53 additions & 0 deletions easy_rec/python/layers/keras/loss_for_cl4ctr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import tensorflow as tf
from tensorflow.python.keras.layers import Layer
from easy_rec.python.utils.shape_utils import get_shape_list


def contrastive(fea_cl1, fea_cl2):
    """Return the mean over rows of the squared distance between two views.

    The squared differences are summed along axis 1 and the resulting values
    are averaged into a scalar.
    """
    squared_diff = tf.square(tf.subtract(fea_cl1, fea_cl2))
    per_row_distance = tf.reduce_sum(squared_diff, axis=1)
    return tf.reduce_mean(per_row_distance)


def compute_uniformity_loss(fea_emb):
    """Return the mean pairwise cosine similarity within each sample.

    ``fea_emb`` is treated as a rank-3 tensor: dot products and norms are
    taken along the last axis, for every pair of entries along axis 1.
    Pairs with a zero norm product contribute 0 (``tf.div_no_nan``).
    """
    swap_last_two = [0, 2, 1]
    dot_products = tf.matmul(fea_emb, tf.transpose(fea_emb, perm=swap_last_two))
    lengths = tf.norm(fea_emb, axis=2, keepdims=True)
    length_products = tf.matmul(lengths, tf.transpose(lengths, perm=swap_last_two))
    cosine = tf.div_no_nan(dot_products, length_products)
    return tf.reduce_mean(cosine)


def compute_alignment_loss(fea_emb):
    """Return the mean squared distance over all distinct pairs of samples.

    Every unordered pair ``(i, j)`` with ``i < j`` along axis 0 of ``fea_emb``
    is compared; squared differences are summed along axis 2 and averaged.
    ``fea_emb`` is treated as a rank-3 tensor (assumed
    [batch_size, num_features, embed_size] -- TODO confirm with caller).

    Returns:
        A scalar tensor; 0 when there is no pair to compare (batch size < 2).
    """
    batch_size = get_shape_list(fea_emb)[0]
    sample_ids = tf.range(batch_size)
    row_grid, col_grid = tf.meshgrid(sample_ids, sample_ids, indexing='ij')
    # Strict upper triangle: each unordered pair exactly once, no self pairs.
    upper = row_grid < col_grid
    # `boolean_mask` always yields rank-1 index vectors, so a single pair
    # (batch size 2) is not collapsed to a scalar the way an axis-less
    # `tf.squeeze` would collapse it.
    row = tf.boolean_mask(row_grid, upper)
    col = tf.boolean_mask(col_grid, upper)
    x_row = tf.gather(fea_emb, row)
    x_col = tf.gather(fea_emb, col)
    distance_sq = tf.reduce_sum(tf.square(tf.subtract(x_row, x_col)), axis=2)
    # Mean that evaluates to 0 instead of NaN when no pair exists.
    num_terms = tf.cast(tf.size(distance_sq), distance_sq.dtype)
    alignment_loss = tf.div_no_nan(tf.reduce_sum(distance_sq), num_terms)
    return alignment_loss


class LOSSCTR(Layer):
    """Registers the CL4CTR contrastive loss in the caller's ``loss_dict``."""

    def __init__(self, params, name='loss_ctr', **kwargs):
        super(LOSSCTR, self).__init__(name=name, **kwargs)
        # Weight of the contrastive term.
        self.cl_weight = params.get_or_default('cl_weight', 1)
        # Weight reserved for the alignment / uniformity terms, which `call`
        # does not currently add.
        self.au_weight = params.get_or_default('au_weight', 0.01)

    def call(self, inputs, training=None, **kwargs):
        """Store the weighted contrastive loss when training; always return 0.

        Args:
            inputs: pair ``(fea_cl1, fea_cl2)`` of feature views.
            training: nothing is recorded unless this is truthy.
            **kwargs: must contain ``loss_dict`` when training.
        """
        if not training:
            return 0

        first_view, second_view = inputs
        cl_loss = contrastive(first_view, second_view)
        loss_dict = kwargs['loss_dict']
        loss_dict['%s_cl_loss' % self.name] = cl_loss * self.cl_weight
        return 0
39 changes: 39 additions & 0 deletions easy_rec/python/loss/nce_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import tensorflow as tf

from easy_rec.python.utils.shape_utils import get_shape_list


def mask_samples(batch_size):
    """Build the [2 * batch_size, 2 * batch_size] boolean mask of negatives.

    The mask is a 2 x 2 tiling of a [batch_size, batch_size] block that is
    True everywhere except on its diagonal.
    """
    block = tf.ones((batch_size, batch_size), bool)
    # A diagonal of the right length filled with False.
    false_diag = tf.fill(tf.shape(tf.linalg.diag_part(block)), False)
    block = tf.linalg.set_diag(block, false_diag)
    # Same layout as concatenating the block with itself along both axes.
    return tf.tile(block, [2, 2])


def nce_loss(z_i, z_j, temp=1):
    """Return the NCE loss between two batches of paired representations.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair. For
    each of the ``2 * batch_size`` rows, the entries selected by
    ``mask_samples`` in the same row of the similarity matrix are the
    negatives. Similarities are plain dot products divided by ``temp``.

    Args:
        z_i: 2-D tensor, first view, [batch_size, dim].
        z_j: 2-D tensor, second view, same shape as ``z_i``.
        temp: temperature dividing the similarities.

    Returns:
        Scalar tensor: mean softmax cross-entropy with the positive at index 0.
    """
    batch_size = get_shape_list(z_i)[0]
    N = 2 * batch_size
    z = tf.concat((z_i, z_j), axis=0)
    sim = tf.matmul(z, tf.transpose(z)) / temp
    # `tf.linalg.diag_part` replaces the `tf.matrix_diag_part` alias, matching
    # the `tf.linalg` calls already used by `mask_samples` in this module.
    sim_i_j = tf.linalg.diag_part(
        tf.slice(sim, [batch_size, 0], [batch_size, batch_size]))
    sim_j_i = tf.linalg.diag_part(
        tf.slice(sim, [0, batch_size], [batch_size, batch_size]))
    positive_samples = tf.reshape(
        tf.concat((sim_i_j, sim_j_i), axis=0), (N, 1))
    mask = mask_samples(batch_size)
    negative_samples = tf.reshape(tf.boolean_mask(sim, mask), (N, -1))

    # The positive logit sits in column 0, so every label is 0.
    labels = tf.zeros(N, dtype=tf.int32)
    logits = tf.concat((positive_samples, negative_samples), axis=1)

    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits))

    return loss
16 changes: 16 additions & 0 deletions easy_rec/python/protos/seq_encoder.proto
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ message BSTEncoder {
required bool use_position_embeddings = 9 [default = true];
// The stddev of the truncated_normal_initializer for initializing all weight matrices
required float initializer_range = 10 [default = 0.02];
// whether to add the sequence contrastive learning loss
required bool need_contrastive_learning = 11 [default = false];
// the weight of contrastive learning loss
optional float contrastive_loss_weight = 12 [default = 1.0];
// seq_fea contrastive learning params
optional ContrastLearning cl_param= 13 ;
}

message DINEncoder {
Expand All @@ -35,3 +41,13 @@ message DINEncoder {
// option: softmax, sigmoid
required string attention_normalizer = 3 [default = 'softmax'];
}


// Parameters of the random sequence augmentations (mask, crop, reorder)
// used to build the two views for contrastive learning.
message ContrastLearning {
  // mask: fraction of the valid sequence items replaced by the mask vector
  required float mask_param = 1 [default = 0.6];
  // crop: fraction of the valid sequence items kept after cropping
  required float crop_param = 2 [default = 0.2];
  // reorder: fraction of the valid sequence items inside the shuffled window
  required float reorder_param = 3 [default = 0.6];
}

0 comments on commit eea8a8b

Please sign in to comment.