SASRecF predict的计算方式 #1261
-
大佬们好,想请教一下为什么在SASRecF的
只用了 |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
@hlyu9 您好,可能您对 RecBole 的训练框架有一点误解, RecBole/recbole/trainer/trainer.py Lines 167 to 168 in 896155c 模型框架里的 SASRecF 结合 无论是训练过程还是测试过程,和 RecBole/recbole/model/sequential_recommender/sasrecf.py Lines 137 to 138 in 896155c |
Beta Was this translation helpful? Give feedback.
@hlyu9 您好,可能您对 RecBole 的训练框架有一点误解,
forward
并不是一般PyTorch
的前向传播函数,也不只被用于训练过程中。在 RecBole 的Trainer
中,训练时计算损失的函数实际上是calculate_loss
:RecBole/recbole/trainer/trainer.py
Lines 167 to 168 in 896155c
模型框架里的
forward
你可以把它理解为对item_seq
进行编码的函数,通过seq_output = self.forward(item_seq, item_seq_len)
就得到了融合特征的序列输出。SASRecF 结合
self.feature_embed_layer
的方式是通过forward
加入feature embedding
,从而对用户的历史交互序列,也就是item_seq
结合特征进行编码,编码之后用得到的seq_output
与item_emedding
的内积作为物品的得分。无论是训练过程还是测试过程,和
seq_output
相乘的都是self.item_embedding
的权重。也就是说,predict
和full_sort_predict
中的test_items_emb = self.item_embedding.weight
对应的是训练过程中计…