Showing 7 changed files with 347 additions and 0 deletions.
@@ -0,0 +1,77 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf

from easy_rec.python.utils.shape_utils import get_shape_list

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


def assign(input_tensor, position=None, value=None):
  # Runs inside tf.py_func, so input_tensor is a numpy array here.
  input_tensor[tuple(position)] = value
  return input_tensor


def item_mask(aug_data, length, weights, mask_param):
  # Randomly pick floor(length * mask_param) positions and overwrite them
  # with the mask embedding `weights`.
  length1 = tf.cast(length, dtype=tf.float32)
  num_mask = tf.cast(tf.math.floor(length1 * mask_param), dtype=tf.int32)
  seq = tf.range(length, dtype=tf.int32)
  mask_index = tf.random.shuffle(seq)[:num_mask]
  masked_item_seq = aug_data
  masked_item_seq = tf.py_func(
      assign,
      inp=[masked_item_seq, [mask_index], weights],
      Tout=masked_item_seq.dtype)
  return masked_item_seq, length


def item_crop(aug_data, length, crop_param):
  # Keep a random contiguous window of floor(length * crop_param) items and
  # zero-pad the rest of the sequence.
  length1 = tf.cast(length, dtype=tf.float32)
  max_length = tf.cast(get_shape_list(aug_data)[0], dtype=tf.int32)
  embedding_size = get_shape_list(aug_data)[1]

  num_left = tf.cast(tf.math.floor(length1 * crop_param), dtype=tf.int32)
  crop_begin = tf.random.uniform(
      [1], minval=0, maxval=length - num_left, dtype=tf.int32)[0]
  cropped_item_seq = tf.zeros([get_shape_list(aug_data)[0], embedding_size])
  cropped_item_seq = tf.where(
      crop_begin + num_left < max_length,
      tf.concat([aug_data[crop_begin:crop_begin + num_left],
                 cropped_item_seq[:max_length - num_left]], axis=0),
      tf.concat([aug_data[crop_begin:], cropped_item_seq[:crop_begin]], axis=0))
  return cropped_item_seq, num_left


def item_reorder(aug_data, length, reorder_param):
  # Shuffle a random contiguous span of floor(length * reorder_param) items.
  length1 = tf.cast(length, dtype=tf.float32)
  num_reorder = tf.cast(tf.math.floor(length1 * reorder_param), dtype=tf.int32)
  reorder_begin = tf.random.uniform(
      [1], minval=0, maxval=length - num_reorder, dtype=tf.int32)[0]
  shuffle_index = tf.range(reorder_begin, reorder_begin + num_reorder)
  shuffle_index = tf.random.shuffle(shuffle_index)
  x = tf.range(get_shape_list(aug_data)[0])
  left = tf.slice(x, [0], [reorder_begin])
  right = tf.slice(x, [reorder_begin + num_reorder], [-1])
  reordered_item_index = tf.concat([left, shuffle_index, right], axis=0)
  reordered_item_seq = tf.scatter_nd(
      tf.expand_dims(reordered_item_index, axis=1), aug_data,
      tf.shape(aug_data))
  return reordered_item_seq, length


def augment(x, cl_param, weights):
  # Randomly apply one of crop / mask / reorder to a single sequence.
  seq, length = x
  flag = tf.range(3, dtype=tf.int32)
  flag1 = tf.random.shuffle(flag)[:1][0]
  aug_seq, aug_len = tf.cond(
      tf.equal(flag1, 0),
      lambda: item_crop(seq, length, cl_param.crop_param),
      lambda: tf.cond(
          tf.equal(flag1, 1),
          lambda: item_mask(seq, length, weights, cl_param.mask_param),
          lambda: item_reorder(seq, length, cl_param.reorder_param)))

  return [aug_seq, aug_len]


def input_aug_data(original_data, seq_len, weights, cl_param):
  # Generate two independently augmented views of each sequence in the batch.
  lengths = tf.cast(seq_len, dtype=tf.int32)
  aug_seq1, aug_len1 = tf.map_fn(
      lambda elems: augment(elems, cl_param, weights),
      elems=(original_data, lengths),
      dtype=[tf.float32, tf.int32])
  aug_seq2, aug_len2 = tf.map_fn(
      lambda elems: augment(elems, cl_param, weights),
      elems=(original_data, lengths),
      dtype=[tf.float32, tf.int32])
  aug_seq1 = tf.reshape(aug_seq1, tf.shape(original_data))
  aug_seq2 = tf.reshape(aug_seq2, tf.shape(original_data))
  return aug_seq1, aug_seq2, aug_len1, aug_len2
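Below is a minimal usage sketch of `input_aug_data` (not part of the commit). It assumes TF 1.x graph mode and uses a simple namedtuple standing in for the `cl_param` protobuf message; `crop_param`, `mask_param` and `reorder_param` are the only fields these functions read, and the module path mirrors the import used in the BSTSEQ file further down.

# Hypothetical usage sketch, not part of the commit.
import collections

import tensorflow as tf

from easy_rec.python.input.augment import input_aug_data

# Stand-in for the cl_param protobuf config (assumption).
CLParam = collections.namedtuple('CLParam',
                                 ['crop_param', 'mask_param', 'reorder_param'])
cl_param = CLParam(crop_param=0.6, mask_param=0.3, reorder_param=0.3)

seq_input = tf.random.normal([8, 20, 32])  # [batch_size, max_seq_len, embed_size]
seq_len = tf.fill([8], 20)                 # true length of each sequence
mask_weights = tf.zeros([32])              # embedding written into masked positions

aug_seq1, aug_seq2, aug_len1, aug_len2 = input_aug_data(
    seq_input, seq_len, mask_weights, cl_param)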
@@ -0,0 +1,57 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
from tensorflow.python.keras.layers import Layer

from easy_rec.python.layers import multihead_cross_attention
from easy_rec.python.utils.activation import get_activation


class BSTCTR(Layer):

  def __init__(self, params, name='bst_cl4ctr', l2_reg=None, **kwargs):
    super(BSTCTR, self).__init__(name=name, **kwargs)
    self.l2_reg = l2_reg
    self.config = params.get_pb_config()

  def encode(self, fea_input, max_position):
    fea_input = multihead_cross_attention.embedding_postprocessor(
        fea_input,
        position_embedding_name=self.name + '/position_embeddings',
        max_position_embeddings=max_position,
        reuse_position_embedding=tf.AUTO_REUSE)

    n = tf.count_nonzero(fea_input, axis=-1)
    seq_mask = tf.cast(n > 0, tf.int32)

    attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
        from_tensor=fea_input, to_mask=seq_mask)

    hidden_act = get_activation(self.config.hidden_act)
    attention_fea = multihead_cross_attention.transformer_encoder(
        fea_input,
        hidden_size=self.config.hidden_size,
        num_hidden_layers=self.config.num_hidden_layers,
        num_attention_heads=self.config.num_attention_heads,
        attention_mask=attention_mask,
        intermediate_size=self.config.intermediate_size,
        intermediate_act_fn=hidden_act,
        hidden_dropout_prob=self.config.hidden_dropout_prob,
        attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
        initializer_range=self.config.initializer_range,
        name=self.name + '/transformer',
        reuse=tf.AUTO_REUSE)
    # attention_fea shape: [batch_size, num_features, hidden_size]
    # out_fea shape: [batch_size * num_features, hidden_size]
    out_fea = tf.reshape(attention_fea, [-1, self.config.hidden_size])
    print('bst_cl4ctr output shape:', out_fea.shape)
    return out_fea

  def call(self, inputs, training=None, **kwargs):
    # inputs: [batch_size, num_features, embed_size]
    if not training:
      self.config.hidden_dropout_prob = 0.0
      self.config.attention_probs_dropout_prob = 0.0
    max_position = self.config.max_position_embeddings

    return self.encode(inputs, max_position)
@@ -0,0 +1,102 @@
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf

from easy_rec.python.layers import multihead_cross_attention
from easy_rec.python.utils.activation import get_activation
from easy_rec.python.utils.shape_utils import get_shape_list
from easy_rec.python.loss.nce_loss import nce_loss
from easy_rec.python.input.augment import input_aug_data

# from tensorflow.python.keras.layers import Layer


class BSTSEQ(object):

  def __init__(self, params, name='bstseq', l2_reg=None, **kwargs):
    # super(BSTSEQ, self).__init__(name=name, **kwargs)
    self.name = name
    self.l2_reg = l2_reg
    self.config = params.get_pb_config()

  def encode(self, seq_input, max_position):
    seq_fea = multihead_cross_attention.embedding_postprocessor(
        seq_input,
        position_embedding_name=self.name + '/position_embeddings',
        max_position_embeddings=max_position,
        reuse_position_embedding=tf.AUTO_REUSE)

    n = tf.count_nonzero(seq_fea, axis=-1)
    seq_mask = tf.cast(n > 0, tf.int32)

    attention_mask = multihead_cross_attention.create_attention_mask_from_input_mask(
        from_tensor=seq_fea, to_mask=seq_mask)

    hidden_act = get_activation(self.config.hidden_act)
    attention_fea = multihead_cross_attention.transformer_encoder(
        seq_fea,
        hidden_size=self.config.hidden_size,
        num_hidden_layers=self.config.num_hidden_layers,
        num_attention_heads=self.config.num_attention_heads,
        attention_mask=attention_mask,
        intermediate_size=self.config.intermediate_size,
        intermediate_act_fn=hidden_act,
        hidden_dropout_prob=self.config.hidden_dropout_prob,
        attention_probs_dropout_prob=self.config.attention_probs_dropout_prob,
        initializer_range=self.config.initializer_range,
        name=self.name + '/bstseq',
        reuse=tf.AUTO_REUSE)
    # attention_fea shape: [batch_size, seq_length, hidden_size]
    out_fea = attention_fea[:, 0, :]  # target feature
    return out_fea

  def __call__(self, inputs, training=None, **kwargs):
    seq_features, _ = inputs
    assert len(seq_features) > 0, '[%s] sequence feature is empty' % self.name
    if not training:
      self.config.hidden_dropout_prob = 0.0
      self.config.attention_probs_dropout_prob = 0.0

    seq_embeds = [seq_fea for seq_fea, _ in seq_features]

    max_position = self.config.max_position_embeddings
    # max_seq_len: the max sequence length in the current mini-batch;
    # all sequences are padded to this length.
    batch_size, max_seq_len, _ = get_shape_list(seq_features[0][0], 3)
    valid_len = tf.assert_less_equal(
        max_seq_len,
        max_position,
        message='sequence length is greater than `max_position_embeddings`:' +
        str(max_position) + ' in feature group:' + self.name)
    with tf.control_dependencies([valid_len]):
      # seq_input: [batch_size, seq_len, embed_size]
      seq_input = tf.concat(seq_embeds, axis=-1)

    seq_embed_size = seq_input.shape.as_list()[-1]
    if seq_embed_size != self.config.hidden_size:
      seq_input = tf.layers.dense(
          seq_input,
          self.config.hidden_size,
          activation=tf.nn.relu,
          kernel_regularizer=self.l2_reg)

    seq_len = seq_features[0][1]
    if self.config.need_contrastive_learning:
      with tf.variable_scope('cl_mask', reuse=tf.AUTO_REUSE):
        weights = tf.get_variable(
            'mask_tensor',
            shape=[seq_input.shape.as_list()[-1]],
            trainable=True,
            initializer=tf.truncated_normal_initializer(stddev=0.02))
      cl_loss = self.contrastive_loss(seq_input, seq_len, max_position,
                                      weights, self.config.cl_param)
      cl_loss *= self.config.contrastive_loss_weight
      assert 'loss_dict' in kwargs, "no `loss_dict` in kwargs of bst layer: %s" % self.name
      loss_dict = kwargs['loss_dict']
      loss_dict['%s_contrastive_loss' % self.name] = cl_loss
    return self.encode(seq_input, max_position)

  def contrastive_loss(self, seq_input, seq_len, max_position, weights, cl_param):
    # Augment the input sequence twice, encode both views and apply the NCE loss.
    aug_seq1, aug_seq2, aug_len1, aug_len2 = input_aug_data(
        seq_input, seq_len, weights, cl_param)
    seq_output1 = self.encode(aug_seq1, max_position)
    seq_output2 = self.encode(aug_seq2, max_position)
    loss = nce_loss(seq_output1, seq_output2)
    return loss
@@ -0,0 +1,53 @@
import tensorflow as tf
from tensorflow.python.keras.layers import Layer

from easy_rec.python.utils.shape_utils import get_shape_list


def contrastive(fea_cl1, fea_cl2):
  # Mean squared Euclidean distance between the two augmented views.
  distance_sq = tf.reduce_sum(tf.square(tf.subtract(fea_cl1, fea_cl2)), axis=1)
  loss = tf.reduce_mean(distance_sq)
  return loss


def compute_uniformity_loss(fea_emb):
  # Mean pairwise cosine similarity between the feature embeddings of each example.
  frac = tf.matmul(fea_emb, tf.transpose(fea_emb, perm=[0, 2, 1]))
  norm = tf.norm(fea_emb, axis=2, keepdims=True)
  denom = tf.matmul(norm, tf.transpose(norm, perm=[0, 2, 1]))
  res = tf.div_no_nan(frac, denom)
  uniformity_loss = tf.reduce_mean(res)
  return uniformity_loss


def compute_alignment_loss(fea_emb):
  # Mean squared distance over all unordered pairs of examples in the batch.
  batch_size = get_shape_list(fea_emb)[0]
  indices = tf.where(tf.ones([tf.reduce_sum(batch_size), tf.reduce_sum(batch_size)]))
  row = tf.gather(tf.reshape(indices[:, 0], [-1]), tf.where(indices[:, 0] < indices[:, 1]))
  col = tf.gather(tf.reshape(indices[:, 1], [-1]), tf.where(indices[:, 0] < indices[:, 1]))
  row = tf.squeeze(row)
  col = tf.squeeze(col)
  x_row = tf.gather(fea_emb, row)
  x_col = tf.gather(fea_emb, col)
  distance_sq = tf.reduce_sum(tf.square(tf.subtract(x_row, x_col)), axis=2)
  alignment_loss = tf.reduce_mean(distance_sq)
  return alignment_loss


class LOSSCTR(Layer):

  def __init__(self, params, name='loss_ctr', **kwargs):
    super(LOSSCTR, self).__init__(name=name, **kwargs)
    self.cl_weight = params.get_or_default('cl_weight', 1)
    self.au_weight = params.get_or_default('au_weight', 0.01)

  def call(self, inputs, training=None, **kwargs):
    if training:
      # fea_cl1, fea_cl2, fea_emd = inputs
      fea_cl1, fea_cl2 = inputs
      # cl_align_loss = compute_alignment_loss(fea_emd)
      # cl_uniform_loss = compute_uniformity_loss(fea_emd)
      cl_loss = contrastive(fea_cl1, fea_cl2)
      # loss = cl_loss * self.cl_weight + (cl_align_loss + cl_uniform_loss) * self.au_weight
      loss_dict = kwargs['loss_dict']
      loss_dict['%s_cl_loss' % self.name] = cl_loss * self.cl_weight
      # loss_dict['%s_align_loss' % self.name] = cl_align_loss * self.au_weight
      # loss_dict['%s_uniform_loss' % self.name] = cl_uniform_loss * self.au_weight
    return 0
@@ -0,0 +1,39 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import tensorflow as tf

from easy_rec.python.utils.shape_utils import get_shape_list


def mask_samples(batch_size):
  # Boolean mask of shape [2 * batch_size, 2 * batch_size] that is False on the
  # diagonal and on the positive-pair positions, True everywhere else.
  part = tf.ones((batch_size, batch_size), bool)
  diag_part = tf.linalg.diag_part(part)
  diag_part = tf.fill(tf.shape(diag_part), False)
  part = tf.linalg.set_diag(part, diag_part)
  part_half = tf.concat([part, part], axis=1)
  part_total = tf.concat([part_half, part_half], axis=0)
  return part_total


def nce_loss(z_i, z_j, temp=1):
  # NT-Xent style contrastive loss: the two augmented views of the same example
  # are positives, all other examples in the batch are negatives.
  batch_size = get_shape_list(z_i)[0]
  N = 2 * batch_size
  z = tf.concat((z_i, z_j), axis=0)
  sim = tf.matmul(z, tf.transpose(z)) / temp
  sim_i_j = tf.matrix_diag_part(
      tf.slice(sim, [batch_size, 0], [batch_size, batch_size]))
  sim_j_i = tf.matrix_diag_part(
      tf.slice(sim, [0, batch_size], [batch_size, batch_size]))
  positive_samples = tf.reshape(tf.concat((sim_i_j, sim_j_i), axis=0), (N, 1))
  mask = mask_samples(batch_size)
  negative_samples = tf.reshape(tf.boolean_mask(sim, mask), (N, -1))

  labels = tf.zeros(N, dtype=tf.int32)
  logits = tf.concat((positive_samples, negative_samples), axis=1)

  loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits))

  return loss
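A quick sketch (not part of the commit) of calling `nce_loss` on two encoded views, matching how `BSTSEQ.contrastive_loss` uses it above; it assumes TF 1.x graph mode, and the module path mirrors the import in the BSTSEQ file.

# Hypothetical sketch, not part of the commit.
import tensorflow as tf

from easy_rec.python.loss.nce_loss import nce_loss

z_i = tf.random.normal([8, 64])  # encoder output for augmented view 1
z_j = tf.random.normal([8, 64])  # encoder output for augmented view 2

loss = nce_loss(z_i, z_j, temp=0.1)

with tf.Session() as sess:  # TF 1.x graph-mode evaluation (assumption)
  print(sess.run(loss))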