Skip to content

SASRecF predict的计算方式 #1261

Answered by Sherry-XLL
hlyu9 asked this question in Q&A
Discussion options

You must be logged in to vote

@hlyu9 您好,可能您对 RecBole 的训练框架有一点误解,forward 并不是一般 PyTorch 的前向传播函数,也不只被用于训练过程中。在 RecBole 的 Trainer 中,训练时计算损失的函数实际上是 calculate_loss

self.model.train()
loss_func = loss_func or self.model.calculate_loss

模型框架里的 forward 你可以把它理解为对 item_seq 进行编码的函数,通过 seq_output = self.forward(item_seq, item_seq_len) 就得到了融合特征的序列输出。

SASRecF 结合 self.feature_embed_layer 的方式是通过 forward 加入 feature embedding,从而对用户的历史交互序列,也就是 item_seq 结合特征进行编码,编码之后用得到的 seq_outputitem_emedding 的内积作为物品的得分。

无论是训练过程还是测试过程,和 seq_output 相乘的都是 self.item_embedding 的权重。也就是说,predictfull_sort_predict 中的 test_items_emb = self.item_embedding.weight 对应的是训练过程中计…

Replies: 1 comment 1 reply

Comment options

You must be logged in to vote
1 reply
@hlyu9
Comment options

Answer selected by hlyu9
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
question Further information is requested
2 participants