Skip to content

Commit

Permalink
restore num_workers in predict
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez committed Nov 8, 2024
1 parent 9e86685 commit 0ab3839
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions nbs/common.base_recurrent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@
" datamodule = TimeSeriesDataModule(\n",
" dataset=dataset,\n",
" valid_batch_size=self.valid_batch_size,\n",
" num_workers=self.num_workers_loader,\n",
" **data_module_kwargs\n",
" )\n",
" fcsts = trainer.predict(self, datamodule=datamodule)\n",
Expand Down
1 change: 1 addition & 0 deletions neuralforecast/common/_base_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ def predict(self, dataset, step_size=1, random_seed=None, **data_module_kwargs):
datamodule = TimeSeriesDataModule(
dataset=dataset,
valid_batch_size=self.valid_batch_size,
num_workers=self.num_workers_loader,
**data_module_kwargs,
)
fcsts = trainer.predict(self, datamodule=datamodule)
Expand Down

0 comments on commit 0ab3839

Please sign in to comment.