Skip to content

Commit

Permalink
tutorial_multiworker: copy tweaks, rm UserWarnings
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-williams committed Oct 15, 2024
1 parent 54aea2e commit c3eb457
Showing 1 changed file with 8 additions and 47 deletions.
55 changes: 8 additions & 47 deletions notebooks/tutorial_multiworker.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,25 @@
"# Multi-process training\n",
"\n",
"Multi-process usage of `tiledbsoma_ml.ExperimentAxisQueryIterDataset` includes both:\n",
"* using the `torch.utils.data.DataLoader` with 1 or more workers (i.e., with an argument of `n_workers=1` or greater)\n",
"* using a multi-process training configuration, such as `DistributedDataParallel`\n",
"* using the [`torch.utils.data.DataLoader`] with 1 or more workers (i.e., with an argument of `n_workers=1` or greater)\n",
"* using a multi-process training configuration, such as [`DistributedDataParallel`]\n",
"\n",
"In these configurations, `ExperimentAxisQueryIterDataset` will automatically partition data across workers. However, when using `shuffle=True`, there are several things to keep in mind:\n",
"\n",
"1. All worker processes must share the same random number generator `seed`, ensuring that all workers shuffle and partition the data in the same way.\n",
"2. To ensure that each epoch returns a _different_ shuffle, the caller must set the epoch, using the `set_epoch` API. This is identical to the behavior of `torch.utils.data.distributed.DistributedSampler`.\n",
"2. To ensure that each epoch returns a _different_ shuffle, the caller must set the epoch, using the `set_epoch` API. This is identical to the behavior of [`torch.utils.data.distributed.DistributedSampler`].\n",
"\n",
"\n"
"[DataLoader]: https://pytorch.org/docs/stable/data.html\n",
"[`torch.utils.data.DataLoader`]: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader\n",
"[`torch.utils.data.distributed.DistributedSampler`]: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler\n",
"[`DistributedDataParallel`]: https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
"################################################################################\n",
"WARNING!\n",
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
"to learn more and leave feedback.\n",
"################################################################################\n",
"\n",
" deprecation_warning()\n"
]
}
],
"outputs": [],
"source": [
"import tiledbsoma as soma\n",
"import torch\n",
Expand Down Expand Up @@ -76,7 +63,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"class LogisticRegression(torch.nn.Module):\n",
" def __init__(self, input_dim, output_dim):\n",
" super(LogisticRegression, self).__init__() # noqa: UP008\n",
Expand Down Expand Up @@ -138,31 +124,6 @@
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"switching torch multiprocessing start method from \"fork\" to \"spawn\"\n",
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
"################################################################################\n",
"WARNING!\n",
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
"to learn more and leave feedback.\n",
"################################################################################\n",
"\n",
" deprecation_warning()\n",
"/home/bruce/miniforge3/envs/toymodel/lib/python3.11/site-packages/torchdata/datapipes/__init__.py:18: UserWarning: \n",
"################################################################################\n",
"WARNING!\n",
"The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n",
"future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n",
"to learn more and leave feedback.\n",
"################################################################################\n",
"\n",
" deprecation_warning()\n"
]
},
{
"name": "stdout",
"output_type": "stream",
Expand Down

0 comments on commit c3eb457

Please sign in to comment.