Merge pull request #1946 from jackmedda/dropout_bug
Fixed dropout instantiation in NGCF and GRU4RecKG forward. Moved dropout_prob in config for SimpleX.
zhengbw0324 authored Aug 28, 2024
2 parents 2b6e209 + f2c2052 commit 14430b3
Showing 5 changed files with 10 additions and 6 deletions.
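
Why the fix matters: an nn.Dropout constructed inside forward() is a fresh, unregistered module on every call, so model.eval() never reaches it and dropout keeps firing at evaluation time. A minimal standalone sketch of the failure mode (illustrative code, not part of this commit):

import torch
import torch.nn as nn

class Buggy(nn.Module):
    def forward(self, x):
        # A new Dropout is built per call; it starts with training=True and
        # is not a registered submodule, so eval() on the parent cannot disable it.
        return nn.Dropout(0.5)(x)

class Fixed(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(0.5)  # registered submodule; follows train/eval mode

    def forward(self, x):
        return self.dropout(x)

x = torch.ones(8)
print(Buggy().eval()(x))  # still zeroes and rescales elements despite eval()
print(Fixed().eval()(x))  # returns x unchanged: dropout is a no-op in eval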
3 changes: 2 additions & 1 deletion recbole/model/general_recommender/ngcf.py
@@ -57,6 +57,7 @@ def __init__(self, config, dataset):
         self.sparse_dropout = SparseDropout(self.node_dropout)
         self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
         self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
+        self.emb_dropout = nn.Dropout(self.message_dropout)
         self.GNNlayers = torch.nn.ModuleList()
         for idx, (input_size, output_size) in enumerate(
             zip(self.hidden_size_list[:-1], self.hidden_size_list[1:])
@@ -157,7 +158,7 @@ def forward(self):
         for gnn in self.GNNlayers:
             all_embeddings = gnn(A_hat, self.eye_matrix, all_embeddings)
             all_embeddings = nn.LeakyReLU(negative_slope=0.2)(all_embeddings)
-            all_embeddings = nn.Dropout(self.message_dropout)(all_embeddings)
+            all_embeddings = self.emb_dropout(all_embeddings)
             all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
             embeddings_list += [
                 all_embeddings
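
With the dropout module registered in __init__, NGCF's evaluation output becomes deterministic. A quick hypothetical check (not part of the commit; assumes forward() returns the final user and item embedding tensors, as in RecBole's NGCF):

model.eval()
with torch.no_grad():
    user_e1, item_e1 = model.forward()
    user_e2, item_e2 = model.forward()
assert torch.equal(user_e1, user_e2) and torch.equal(item_e1, item_e2)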
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/simplex.py
@@ -74,7 +74,7 @@ def __init__(self, config, dataset):
         if self.aggregator == "self_attention":
             self.W_q = nn.Linear(self.embedding_size, 1, bias=False)
         # dropout
-        self.dropout = nn.Dropout(0.1)
+        self.dropout = nn.Dropout(config["dropout_prob"])
         self.require_pow = config["require_pow"]
         # l2 regularization loss
         self.reg_loss = EmbLoss()
2 changes: 1 addition & 1 deletion recbole/model/knowledge_aware_recommender/mkr.py
@@ -73,7 +73,7 @@ def __init__(self, config, dataset):
         self.kge_pred_mlp = MLPLayers(
             [self.embedding_size * 2, self.embedding_size], self.dropout_prob, "sigmoid"
         )
-        if self.use_inner_product == False:
+        if not self.use_inner_product:
             self.rs_pred_mlp = MLPLayers(
                 [self.embedding_size * 2, 1], self.dropout_prob, "sigmoid"
             )
6 changes: 4 additions & 2 deletions recbole/model/sequential_recommender/gru4reckg.py
@@ -47,6 +47,8 @@ def __init__(self, config, dataset):
         self.entity_embedding = nn.Embedding(
             self.n_items, self.embedding_size, padding_idx=0
         )
+        self.item_emb_dropout = nn.Dropout(self.dropout)
+        self.entity_emb_dropout = nn.Dropout(self.dropout)
         self.entity_embedding.weight.requires_grad = not self.freeze_kg
         self.item_gru_layers = nn.GRU(
             input_size=self.embedding_size,
@@ -79,8 +81,8 @@ def __init__(self, config, dataset):
     def forward(self, item_seq, item_seq_len):
         item_emb = self.item_embedding(item_seq)
         entity_emb = self.entity_embedding(item_seq)
-        item_emb = nn.Dropout(self.dropout)(item_emb)
-        entity_emb = nn.Dropout(self.dropout)(entity_emb)
+        item_emb = self.item_emb_dropout(item_emb)
+        entity_emb = self.entity_emb_dropout(entity_emb)
 
         item_gru_output, _ = self.item_gru_layers(item_emb)  # [B Len H]
         entity_gru_output, _ = self.entity_gru_layers(entity_emb)
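
Design note: GRU4RecKG now holds two separate dropout modules, one per stream. A single shared nn.Dropout would behave identically here, since a fresh mask is sampled on each call anyway; the two named instances simply keep the item and entity paths explicit.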
3 changes: 2 additions & 1 deletion recbole/properties/model/SimpleX.yaml
@@ -4,4 +4,5 @@ negative_weight: 10 # (int) Weight to balance between positive-sampl
 gamma: 0.5 # (float) Weight for fusion of user' and interacted items' representations.
 aggregator: 'mean' # (str) The item aggregator ranging in ['mean', 'user_attention', 'self_attention'].
 history_len: 50 # (int) The length of the user's historical interaction items.
-reg_weight: 1e-05 # (float) The L2 regularization weights.
+reg_weight: 1e-05 # (float) The L2 regularization weights.
+dropout_prob: 0.1 # (float) Dropout probability for fusion of user' and interacted items' representations.
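
With dropout_prob exposed in SimpleX.yaml, the probability can be overridden like any other hyperparameter instead of staying hardcoded at 0.1. A hedged sketch using RecBole's quick-start entry point (dataset name is illustrative):

from recbole.quick_start import run_recbole

run_recbole(
    model="SimpleX",
    dataset="ml-100k",  # illustrative dataset
    config_dict={"dropout_prob": 0.2},  # overrides the 0.1 default above
)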
