From eea8a8bc80dff61ee77ed124ead20d87856486ec Mon Sep 17 00:00:00 2001
From: yzhjerry <978409300@qq.com>
Date: Mon, 11 Sep 2023 15:15:27 +0800
Subject: [PATCH] add contrastive learning (cl) support

---
 easy_rec/python/input/augment.py              |  77 ++++++++++++++
 easy_rec/python/layers/keras/__init__.py      |   3 +
 easy_rec/python/layers/keras/bst_cl4ctr.py    |  57 +++++++++++
 easy_rec/python/layers/keras/bst_forseq.py    | 102 ++++++++++++++++++++
 .../python/layers/keras/loss_for_cl4ctr.py    |  53 ++++++++++
 easy_rec/python/loss/nce_loss.py              |  39 +++++++
 easy_rec/python/protos/seq_encoder.proto      |  16 +++
 7 files changed, 347 insertions(+)
 create mode 100644 easy_rec/python/input/augment.py
 create mode 100644 easy_rec/python/layers/keras/bst_cl4ctr.py
 create mode 100644 easy_rec/python/layers/keras/bst_forseq.py
 create mode 100644 easy_rec/python/layers/keras/loss_for_cl4ctr.py
 create mode 100644 easy_rec/python/loss/nce_loss.py

diff --git a/easy_rec/python/input/augment.py b/easy_rec/python/input/augment.py
new file mode 100644
index 000000000..c59f60330
--- /dev/null
+++ b/easy_rec/python/input/augment.py
@@ -0,0 +1,77 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+def assign(input_tensor, position=None, value=None):
+  # Runs inside tf.py_func, so `input_tensor` is a numpy array here;
+  # the rows selected by `position` are overwritten with `value`.
+  input_tensor[tuple(position)] = value
+  return input_tensor
+
+
+def item_mask(aug_data, length, weights, mask_param):
+  """Masks a random `mask_param` fraction of the valid items with `weights`."""
+  length1 = tf.cast(length, dtype=tf.float32)
+  num_mask = tf.cast(tf.math.floor(length1 * mask_param), dtype=tf.int32)
+  seq = tf.range(length, dtype=tf.int32)
+  mask_index = tf.random.shuffle(seq)[:num_mask]
+  masked_item_seq = tf.py_func(
+      assign, inp=[aug_data, [mask_index], weights], Tout=aug_data.dtype)
+  return masked_item_seq, length
+
+
+def item_crop(aug_data, length, crop_param):
+  """Keeps a random contiguous `crop_param` fraction of the valid items."""
+  length1 = tf.cast(length, dtype=tf.float32)
+  max_length = tf.cast(get_shape_list(aug_data)[0], dtype=tf.int32)
+  embedding_size = get_shape_list(aug_data)[1]
+
+  num_left = tf.cast(tf.math.floor(length1 * crop_param), dtype=tf.int32)
+  # guard against maxval == 0, which is invalid for an integer random_uniform
+  crop_begin = tf.random.uniform(
+      [1], minval=0, maxval=tf.maximum(length - num_left, 1),
+      dtype=tf.int32)[0]
+  zero_pad = tf.zeros([max_length, embedding_size])
+  # the condition is a scalar, so tf.cond is used rather than tf.where
+  cropped_item_seq = tf.cond(
+      crop_begin + num_left < max_length,
+      lambda: tf.concat([aug_data[crop_begin:crop_begin + num_left],
+                         zero_pad[:max_length - num_left]], axis=0),
+      lambda: tf.concat([aug_data[crop_begin:], zero_pad[:crop_begin]],
+                        axis=0))
+  return cropped_item_seq, num_left
+
+
+def item_reorder(aug_data, length, reorder_param):
+  """Shuffles a random contiguous `reorder_param` fraction of the items."""
+  length1 = tf.cast(length, dtype=tf.float32)
+  num_reorder = tf.cast(tf.math.floor(length1 * reorder_param), dtype=tf.int32)
+  reorder_begin = tf.random.uniform(
+      [1], minval=0, maxval=tf.maximum(length - num_reorder, 1),
+      dtype=tf.int32)[0]
+  shuffle_index = tf.range(reorder_begin, reorder_begin + num_reorder)
+  shuffle_index = tf.random.shuffle(shuffle_index)
+  x = tf.range(get_shape_list(aug_data)[0])
+  left = tf.slice(x, [0], [reorder_begin])
+  right = tf.slice(x, [reorder_begin + num_reorder], [-1])
+  reordered_item_index = tf.concat([left, shuffle_index, right], axis=0)
+  # scatter_nd writes row i of aug_data to position reordered_item_index[i]
+  reordered_item_seq = tf.scatter_nd(
+      tf.expand_dims(reordered_item_index, axis=1), aug_data,
+      tf.shape(aug_data))
+  return reordered_item_seq, length
+
+
+def augment(x, cl_param, weights):
+  seq, length = x
+  # randomly pick one of the three augmentations: crop, mask or reorder
+  flag = tf.random.shuffle(tf.range(3, dtype=tf.int32))[0]
+  aug_seq, aug_len = tf.cond(
+      tf.equal(flag, 0),
+      lambda: item_crop(seq, length, cl_param.crop_param),
+      lambda: tf.cond(
+          tf.equal(flag, 1),
+          lambda: item_mask(seq, length, weights, cl_param.mask_param),
+          lambda: item_reorder(seq, length, cl_param.reorder_param)))
+  return [aug_seq, aug_len]
+
+
+def input_aug_data(original_data, seq_len, weights, cl_param):
+  """Generates two independently augmented views of `original_data`."""
+  lengths = tf.cast(seq_len, dtype=tf.int32)
+  aug_seq1, aug_len1 = tf.map_fn(
+      lambda elems: augment(elems, cl_param, weights),
+      elems=(original_data, lengths),
+      dtype=[tf.float32, tf.int32])
+  aug_seq2, aug_len2 = tf.map_fn(
+      lambda elems: augment(elems, cl_param, weights),
+      elems=(original_data, lengths),
+      dtype=[tf.float32, tf.int32])
+  aug_seq1 = tf.reshape(aug_seq1, tf.shape(original_data))
+  aug_seq2 = tf.reshape(aug_seq2, tf.shape(original_data))
+  return aug_seq1, aug_seq2, aug_len1, aug_len2
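+
+# Illustrative usage (a minimal sketch; `seq_embed` [batch, max_len, dim],
+# `lengths`, the learnable `weights` vector and the ContrastLearning pb
+# config `cl_param` are assumed to come from the caller, see bst_forseq.py):
+#   aug1, aug2, len1, len2 = input_aug_data(seq_embed, lengths, weights,
+#                                           cl_param)
+# The two views feed nce_loss(encode(aug1), encode(aug2)).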
diff --git a/easy_rec/python/layers/keras/__init__.py b/easy_rec/python/layers/keras/__init__.py
index cd1c5bff3..3c5b5ddbb 100644
--- a/easy_rec/python/layers/keras/__init__.py
+++ b/easy_rec/python/layers/keras/__init__.py
@@ -14,3 +14,6 @@
 from .multi_task import MMoE
 from .numerical_embedding import AutoDisEmbedding
 from .numerical_embedding import PeriodicEmbedding
+from .bst_cl4ctr import BSTCTR
+from .bst_forseq import BSTSEQ
+from .loss_for_cl4ctr import LOSSCTR
diff --git a/easy_rec/python/layers/keras/bst_cl4ctr.py b/easy_rec/python/layers/keras/bst_cl4ctr.py
new file mode 100644
index 000000000..2949a48f0
--- /dev/null
+++ b/easy_rec/python/layers/keras/bst_cl4ctr.py
@@ -0,0 +1,57 @@
+# -*- encoding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.layers import multihead_cross_attention
+from easy_rec.python.utils.activation import get_activation
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class BSTCTR(Layer):
+  """Transformer encoder over the per-feature embeddings (CL4CTR-style)."""
+
+  def __init__(self, params, name='bst_cl4ctr', l2_reg=None, **kwargs):
+    super(BSTCTR, self).__init__(name=name, **kwargs)
+    self.l2_reg = l2_reg
+    self.config = params.get_pb_config()
+
+  def encode(self, fea_input, max_position):
+    fea_input = multihead_cross_attention.embedding_postprocessor(
+        fea_input,
+        position_embedding_name=self.name + '/position_embeddings',
+        max_position_embeddings=max_position,
+        reuse_position_embedding=tf.AUTO_REUSE)
+
+    # a "token" (feature) is valid if its embedding is not all zeros
+    n = tf.count_nonzero(fea_input, axis=-1)
+    seq_mask = tf.cast(n > 0, tf.int32)
+
+    attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
+        from_tensor=fea_input, to_mask=seq_mask)
+
+    hidden_act = get_activation(self.config.hidden_act)
+    attention_fea = multihead_cross_attention.transformer_encoder(
+        fea_input,
+        hidden_size=self.config.hidden_size,
+        num_hidden_layers=self.config.num_hidden_layers,
+        num_attention_heads=self.config.num_attention_heads,
+        attention_mask=attention_mask,
+        intermediate_size=self.config.intermediate_size,
+        intermediate_act_fn=hidden_act,
+        hidden_dropout_prob=self.config.hidden_dropout_prob,
+        attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
+        initializer_range=self.config.initializer_range,
+        name=self.name + '/transformer',
+        reuse=tf.AUTO_REUSE)
+    # attention_fea shape: [batch_size, num_features, hidden_size]
+    # out_fea shape: [batch_size * num_features, hidden_size]
+    out_fea = tf.reshape(attention_fea, [-1, self.config.hidden_size])
+    tf.logging.info('bst_cl4ctr output shape: %s' % str(out_fea.shape))
+    return out_fea
+
+  def call(self, inputs, training=None, **kwargs):
+    # inputs: [batch_size, num_features, embed_size]
+    if not training:
+      # note: this mutates the shared pb config, so dropout stays disabled
+      # for all subsequent calls of this layer instance
+      self.config.hidden_dropout_prob = 0.0
+      self.config.attention_probs_dropout_prob = 0.0
+    max_position = self.config.max_position_embeddings
+    return self.encode(inputs, max_position)
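+
+# Illustrative wiring with the LOSSCTR layer (a minimal sketch; how the two
+# perturbed feature views `fea_v1` / `fea_v2` are produced is assumed, e.g.
+# by random feature masking of the same inputs):
+#   encoder = BSTCTR(params)                     # shared weights
+#   h1 = encoder(fea_v1, training=True)          # [batch * num_fea, hidden]
+#   h2 = encoder(fea_v2, training=True)
+#   LOSSCTR(loss_params)([h1, h2], training=True, loss_dict=loss_dict)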
diff --git a/easy_rec/python/layers/keras/bst_forseq.py b/easy_rec/python/layers/keras/bst_forseq.py
new file mode 100644
index 000000000..53d56b09e
--- /dev/null
+++ b/easy_rec/python/layers/keras/bst_forseq.py
@@ -0,0 +1,102 @@
+# -*- encoding: utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import tensorflow as tf
+
+from easy_rec.python.input.augment import input_aug_data
+from easy_rec.python.layers import multihead_cross_attention
+from easy_rec.python.loss.nce_loss import nce_loss
+from easy_rec.python.utils.activation import get_activation
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class BSTSEQ(object):
+  """BST sequence encoder with optional contrastive learning.
+
+  Implemented as a plain callable rather than a Keras Layer so the mask
+  variable can be created with tf.get_variable under a reusable scope.
+  """
+
+  def __init__(self, params, name='bstseq', l2_reg=None, **kwargs):
+    self.name = name
+    self.l2_reg = l2_reg
+    self.config = params.get_pb_config()
+
+  def encode(self, seq_input, max_position):
+    seq_fea = multihead_cross_attention.embedding_postprocessor(
+        seq_input,
+        position_embedding_name=self.name + '/position_embeddings',
+        max_position_embeddings=max_position,
+        reuse_position_embedding=tf.AUTO_REUSE)
+
+    # a position is valid if its embedding is not all zeros
+    n = tf.count_nonzero(seq_fea, axis=-1)
+    seq_mask = tf.cast(n > 0, tf.int32)
+
+    attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
+        from_tensor=seq_fea, to_mask=seq_mask)
+
+    hidden_act = get_activation(self.config.hidden_act)
+    attention_fea = multihead_cross_attention.transformer_encoder(
+        seq_fea,
+        hidden_size=self.config.hidden_size,
+        num_hidden_layers=self.config.num_hidden_layers,
+        num_attention_heads=self.config.num_attention_heads,
+        attention_mask=attention_mask,
+        intermediate_size=self.config.intermediate_size,
+        intermediate_act_fn=hidden_act,
+        hidden_dropout_prob=self.config.hidden_dropout_prob,
+        attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
+        initializer_range=self.config.initializer_range,
+        name=self.name + '/bstseq',
+        reuse=tf.AUTO_REUSE)
+    # attention_fea shape: [batch_size, seq_length, hidden_size]
+    out_fea = attention_fea[:, 0, :]  # target feature
+    return out_fea
+
+  def __call__(self, inputs, training=None, **kwargs):
+    seq_features, _ = inputs
+    assert len(seq_features) > 0, '[%s] sequence feature is empty' % self.name
+    if not training:
+      # note: this mutates the shared pb config, so dropout stays disabled
+      # for all subsequent calls of this encoder instance
+      self.config.hidden_dropout_prob = 0.0
+      self.config.attention_probs_dropout_prob = 0.0
+
+    seq_embeds = [seq_fea for seq_fea, _ in seq_features]
+
+    max_position = self.config.max_position_embeddings
+    # max_seq_len: the max sequence length in the current mini-batch;
+    # all sequences are padded to this length
+    batch_size, max_seq_len, _ = get_shape_list(seq_features[0][0], 3)
+    valid_len = tf.assert_less_equal(
+        max_seq_len,
+        max_position,
+        message='sequence length is greater than `max_position_embeddings`:' +
+        str(max_position) + ' in feature group:' + self.name)
+    with tf.control_dependencies([valid_len]):
+      # seq_input: [batch_size, seq_len, embed_size]
+      seq_input = tf.concat(seq_embeds, axis=-1)
+
+    seq_embed_size = seq_input.shape.as_list()[-1]
+    if seq_embed_size != self.config.hidden_size:
+      seq_input = tf.layers.dense(
+          seq_input,
+          self.config.hidden_size,
+          activation=tf.nn.relu,
+          kernel_regularizer=self.l2_reg)
+
+    seq_len = seq_features[0][1]
+    if self.config.need_contrastive_learning:
+      with tf.variable_scope('cl_mask', reuse=tf.AUTO_REUSE):
+        # learnable embedding that replaces masked items in item_mask
+        weights = tf.get_variable(
+            'mask_tensor',
+            shape=[seq_input.shape.as_list()[-1]],
+            trainable=True,
+            initializer=tf.truncated_normal_initializer(stddev=0.02))
+      cl_loss = self.contrastive_loss(seq_input, seq_len, max_position,
+                                      weights, self.config.cl_param)
+      cl_loss *= self.config.contrastive_loss_weight
+      assert 'loss_dict' in kwargs, 'no `loss_dict` in kwargs of bst layer: %s' % self.name
+      loss_dict = kwargs['loss_dict']
+      loss_dict['%s_contrastive_loss' % self.name] = cl_loss
+    return self.encode(seq_input, max_position)
+
+  def contrastive_loss(self, seq_input, seq_len, max_position, weights,
+                       cl_param):
+    # two augmented views of the same batch, encoded with shared weights
+    aug_seq1, aug_seq2, aug_len1, aug_len2 = input_aug_data(
+        seq_input, seq_len, weights, cl_param)
+    seq_output1 = self.encode(aug_seq1, max_position)
+    seq_output2 = self.encode(aug_seq2, max_position)
+    loss = nce_loss(seq_output1, seq_output2)
+    return loss
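+
+# Data flow when `need_contrastive_learning` is enabled (a sketch of the
+# code above, no new names):
+#   seq_input --input_aug_data--> (aug_seq1, aug_seq2)
+#   encode(aug_seq1), encode(aug_seq2) --nce_loss--> loss_dict entry
+# The main tower output encode(seq_input) is unchanged by the auxiliary loss.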
diff --git a/easy_rec/python/layers/keras/loss_for_cl4ctr.py b/easy_rec/python/layers/keras/loss_for_cl4ctr.py
new file mode 100644
index 000000000..ef0e747a9
--- /dev/null
+++ b/easy_rec/python/layers/keras/loss_for_cl4ctr.py
@@ -0,0 +1,53 @@
+# -*- encoding: utf-8 -*-
+import tensorflow as tf
+from tensorflow.python.keras.layers import Layer
+
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+
+def contrastive(fea_cl1, fea_cl2):
+  """Mean squared L2 distance between two views of the same features."""
+  distance_sq = tf.reduce_sum(tf.square(tf.subtract(fea_cl1, fea_cl2)), axis=1)
+  loss = tf.reduce_mean(distance_sq)
+  return loss
+
+
+def compute_uniformity_loss(fea_emb):
+  """Mean pairwise cosine similarity of the feature embeddings."""
+  frac = tf.matmul(fea_emb, tf.transpose(fea_emb, perm=[0, 2, 1]))
+  norm = tf.norm(fea_emb, axis=2, keepdims=True)
+  denom = tf.matmul(norm, tf.transpose(norm, perm=[0, 2, 1]))
+  res = tf.div_no_nan(frac, denom)
+  uniformity_loss = tf.reduce_mean(res)
+  return uniformity_loss
+
+
+def compute_alignment_loss(fea_emb):
+  """Mean squared distance between all distinct pairs of rows in the batch."""
+  batch_size = get_shape_list(fea_emb)[0]
+  indices = tf.where(tf.ones([batch_size, batch_size]))
+  # keep only the upper triangle (i < j) so each pair is counted once
+  row = tf.gather(tf.reshape(indices[:, 0], [-1]),
+                  tf.where(indices[:, 0] < indices[:, 1]))
+  col = tf.gather(tf.reshape(indices[:, 1], [-1]),
+                  tf.where(indices[:, 0] < indices[:, 1]))
+  row = tf.squeeze(row)
+  col = tf.squeeze(col)
+  x_row = tf.gather(fea_emb, row)
+  x_col = tf.gather(fea_emb, col)
+  distance_sq = tf.reduce_sum(tf.square(tf.subtract(x_row, x_col)), axis=2)
+  alignment_loss = tf.reduce_mean(distance_sq)
+  return alignment_loss
+
+
+class LOSSCTR(Layer):
+
+  def __init__(self, params, name='loss_ctr', **kwargs):
+    super(LOSSCTR, self).__init__(name=name, **kwargs)
+    self.cl_weight = params.get_or_default('cl_weight', 1.0)
+    self.au_weight = params.get_or_default('au_weight', 0.01)
+
+  def call(self, inputs, training=None, **kwargs):
+    if training:
+      fea_cl1, fea_cl2 = inputs
+      cl_loss = contrastive(fea_cl1, fea_cl2)
+      loss_dict = kwargs['loss_dict']
+      loss_dict['%s_cl_loss' % self.name] = cl_loss * self.cl_weight
+      # to also use the CL4CTR alignment / uniformity terms, pass the raw
+      # feature embeddings as a third input:
+      #   fea_cl1, fea_cl2, fea_emb = inputs
+      #   loss_dict['%s_align_loss' % self.name] = \
+      #       compute_alignment_loss(fea_emb) * self.au_weight
+      #   loss_dict['%s_uniform_loss' % self.name] = \
+      #       compute_uniformity_loss(fea_emb) * self.au_weight
+    # this layer only registers side losses; its output is a dummy constant
+    return 0
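+
+# Worked example for contrastive(): with fea_cl1 = [[1., 0.]] and
+# fea_cl2 = [[0., 1.]], distance_sq = [(1-0)^2 + (0-1)^2] = [2.], loss = 2.0.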
diff --git a/easy_rec/python/loss/nce_loss.py b/easy_rec/python/loss/nce_loss.py
new file mode 100644
index 000000000..3be312835
--- /dev/null
+++ b/easy_rec/python/loss/nce_loss.py
@@ -0,0 +1,39 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import tensorflow as tf
+
+from easy_rec.python.utils.shape_utils import get_shape_list
+
+
+def mask_samples(batch_size):
+  """Builds a [2B, 2B] bool mask that is False on the diagonal (self
+  similarity) and on the positive pairs (i, i + B), True elsewhere."""
+  part = tf.ones((batch_size, batch_size), bool)
+  diag_part = tf.linalg.diag_part(part)
+  diag_part = tf.fill(tf.shape(diag_part), False)
+  part = tf.linalg.set_diag(part, diag_part)
+  part_half = tf.concat([part, part], axis=1)
+  part_total = tf.concat([part_half, part_half], axis=0)
+  return part_total
+
+
+def nce_loss(z_i, z_j, temp=1.0):
+  """NT-Xent loss over two batches of views; z_i[k] and z_j[k] are positives."""
+  batch_size = get_shape_list(z_i)[0]
+  N = 2 * batch_size
+  z = tf.concat((z_i, z_j), axis=0)
+  # similarity is an unnormalized dot product scaled by the temperature
+  sim = tf.matmul(z, tf.transpose(z)) / temp
+  sim_i_j = tf.linalg.diag_part(
+      tf.slice(sim, [batch_size, 0], [batch_size, batch_size]))
+  sim_j_i = tf.linalg.diag_part(
+      tf.slice(sim, [0, batch_size], [batch_size, batch_size]))
+  positive_samples = tf.reshape(tf.concat((sim_i_j, sim_j_i), axis=0), (N, 1))
+  mask = mask_samples(batch_size)
+  negative_samples = tf.reshape(tf.boolean_mask(sim, mask), (N, -1))
+
+  # the positive logit is placed in column 0, so every label is 0
+  labels = tf.zeros(N, dtype=tf.int32)
+  logits = tf.concat((positive_samples, negative_samples), axis=1)
+
+  loss = tf.reduce_mean(
+      tf.nn.sparse_softmax_cross_entropy_with_logits(
+          labels=labels, logits=logits))
+
+  return loss
diff --git a/easy_rec/python/protos/seq_encoder.proto b/easy_rec/python/protos/seq_encoder.proto
index 2b845a429..a302437f7 100644
--- a/easy_rec/python/protos/seq_encoder.proto
+++ b/easy_rec/python/protos/seq_encoder.proto
@@ -25,6 +25,12 @@
   required bool use_position_embeddings = 9 [default = true];
   // The stddev of the truncated_normal_initializer for initializing all weight matrices
   required float initializer_range = 10 [default = 0.02];
+  // whether to add an auxiliary contrastive learning loss
+  required bool need_contrastive_learning = 11 [default = false];
+  // the weight of the contrastive learning loss
+  optional float contrastive_loss_weight = 12 [default = 1.0];
+  // augmentation params for sequence-feature contrastive learning
+  optional ContrastLearning cl_param = 13;
 }
 
 message DINEncoder {
@@ -35,3 +41,13 @@
   // option: softmax, sigmoid
   required string attention_normalizer = 3 [default = 'softmax'];
 }
+
+
+message ContrastLearning {
+  // fraction of the sequence whose items are masked by item_mask
+  required float mask_param = 1 [default = 0.6];
+  // fraction of the sequence that is kept by item_crop
+  required float crop_param = 2 [default = 0.2];
+  // fraction of the sequence that is shuffled by item_reorder
+  required float reorder_param = 3 [default = 0.6];
+}
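+
+// Illustrative pbtxt snippet enabling contrastive learning (the enclosing
+// message that carries BSTEncoder is assumed; the fields are defined above):
+//   need_contrastive_learning: true
+//   contrastive_loss_weight: 0.1
+//   cl_param {
+//     mask_param: 0.6
+//     crop_param: 0.2
+//     reorder_param: 0.6
+//   }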