[BUG] (DataLoader) sanity check fails due to Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) #20456

Open
MathiasBaumgartinger opened this issue Nov 27, 2024 · 0 comments
Labels
bug Something isn't working needs triage Waiting to be triaged by maintainers ver: 2.4.x

MathiasBaumgartinger commented Nov 27, 2024

Bug description

Hi there! I have created my first LightningDataModule, more specifically a NonGeoDataModule that inherits from it (see torchgeo-fork). When I try to run this module, I get RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor. Even more interesting, if I override transfer_batch_to_device like this:

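# Method of the data module; needs `from typing import Any` and `import torch`.
def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
    # Let Lightning move the batch, then print the device index of each entry.
    batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    print("----------------------------------------")
    for k in batch.keys():
        # get_device() returns the CUDA device index, or -1 for a CPU tensor.
        print(k, batch[k][0].get_device())
    print("----------------------------------------")

    return batch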

I get the following output, i.e. both tensors report CUDA device index 0:

image 0
mask 0

The error is raised during the validation step of the sanity check (lightning/pytorch/strategies/strategy.py, line 411).
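To narrow down where a tensor ends up on a different device than the model, the per-key check above could be wrapped in a small reusable helper. This is only a minimal sketch: `batch_devices` and `keys_off_device` are hypothetical names, and the code assumes the batch is a dict whose tensor values expose a `.device` attribute (as torch tensors do), so it needs no imports beyond the standard library.

```python
from typing import Any, Mapping


def batch_devices(batch: Mapping[str, Any]) -> dict[str, str]:
    """Map each batch key to the device of its value ("n/a" if it has none)."""
    return {key: str(getattr(value, "device", "n/a")) for key, value in batch.items()}


def keys_off_device(batch: Mapping[str, Any], expected: Any) -> list[str]:
    """Return the keys whose values report a device other than `expected`."""
    devices = batch_devices(batch)
    # Values without a device (strings, ints, ...) are ignored.
    return [key for key, dev in devices.items() if dev not in ("n/a", str(expected))]
```

One possible use is to call `keys_off_device(batch, next(self.parameters()).device)` at the start of validation_step, or to print `batch_devices(batch)` from the data module's on_after_batch_transfer hook, and compare the result with what transfer_batch_to_device printed. The comparison is string-based, so `expected` should carry an explicit index (for example a device taken from a model parameter, which prints as `cuda:0`) rather than a bare `cuda`.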

What version are you seeing the problem on?

v2.4

How to reproduce the bug

def train(
    config: dict, 
    data_dir: str=default_data_dir, 
    root_dir: str=default_root_dir,
    min_epochs: int=1,
    max_epochs: int=25) -> None:
    
    tune_metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    
    module = FL(
        num_workers=config["num_workers"], 
        batch_size=config["batch_size"], 
        patch_size=config["patch_size"],
        val_split_pct=0.25,
        use_toy=True,
        #augs=transforms,
        root=data_dir, 
    )
    task = SemanticSegmentationTask(
        model="unet",
        backbone="resnet50",
        ignore_index=255,
        in_channels=5,  # would be (5 + 3) with appended indices
        num_classes=13,
        lr=config["lr"],
        patience=config["lr_patience"]
    )

    # Callbacks
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min")
    lr_monitor = LearningRateMonitor(logging_interval="step")
    tune_callback = TuneReportCheckpointCallback(
        {"loss": "val_loss", "accuracy": "val_accuracy"}, on="validation_end"
    )
    logger = TensorBoardLogger(save_dir=root_dir, name="FLAIR2logs")

    trainer = Trainer(
        accelerator=accelerator,
        num_nodes=1,
        callbacks=[checkpoint_callback, lr_monitor, tune_callback],
        log_every_n_steps=1,
        logger=logger,
        min_epochs=min_epochs,  # use the function arguments instead of hard-coded values
        max_epochs=max_epochs,
        precision=32,
    )

    trainer.fit(model=task, datamodule=module)

Error messages and logs

Traceback (most recent call last):
  File "//Dev/forks/torchgeo/train_simple.py", line 158, in <module>
    main()
  File "//Dev/forks/torchgeo/train_simple.py", line 154, in main
    train(config)
  File "//Dev/forks/torchgeo/train_simple.py", line 151, in train
    trainer.fit(model=task, datamodule=module)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1023, in _run_stage
    self._run_sanity_check()
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1052, in _run_sanity_check
    val_loop.run()
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
    return loop_run(self, *args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 411, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "//Dev/forks/torchgeo/torchgeo/trainers/segmentation.py", line 251, in validation_step
    y_hat = self(x)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//Dev/forks/torchgeo/torchgeo/trainers/base.py", line 81, in forward
    return self.model(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/segmentation_models_pytorch/base/model.py", line 38, in forward
    features = self.encoder(x)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/segmentation_models_pytorch/encoders/resnet.py", line 63, in forward
    x = stages[i](x)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/container.py", line 219, in forward
    input = module(input)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 458, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "//miniforge3/envs/torchgeo/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

Environment

Current environment
-----------------------------------------------------------
Python Version: 3.10.4
PyTorch Version: 2.4.1
Cuda is  available version: 12.4
Torch built with CUDA: True
cuDNN Version: 90100
cuDNN Enabled: True
cuDNN available: True
Device: cuda
Accelerator: gpu

lightning                 2.4.0             
lightning-utilities       0.11.9             
pytorch-lightning         2.4.0 

## conda env
name: torchgeo
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - python=3.10
  - pytorch-cuda=12.4
  - pytorch=2.4
  - torchgeo=0.6.0
  - tensorboard=2.17
-----------------------------------------------------------

More info

No response
