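Hi,
I'm trying to finetune for the Adyghe language. After adding the necessary dataset and code changes (based on Recipes/finetuning_example_multilingual.py), I used the following command line:
python3 run_training_pipeline.py --gpu_id 0 --resume_checkpoint Models/ToucanTTS_Meta/best.pt --resume --finetune --model_save_dir Models/adyge_Fiftylangmale_small/ fine_tuning_adyge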
The run crashes while finetuning the aligner model, apparently during back-propagation, with an error indicating that some tensors were detached from the computation graph.
I saw your comment in Modules/Aligner/autoaligner_train_loop.py, "extremely unfortunate that we have to do this over here...", which refers to the second line below it, where you explicitly move the mel spectrogram to the CPU. I don't know if this is related.
Here is the error:
...
Loaded an Aligner dataset with 1891 datapoints from Corpora/adyge_fiftylangmale_small.
/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:30: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
0%| | 0/59 [00:01<?, ?it/s]
Traceback (most recent call last):
File "/home/haroon/git_repos/IMS-Toucan_August_2024/run_training_pipeline.py", line 110, in <module>
pipeline_dict[args.pipeline](gpu_id=args.gpu_id,
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Recipes/finetuning_example_multilingual_adyge.py", line 81, in run
adyge_datasets.append(prepare_tts_corpus(transcript_dict=build_path_to_transcript_dict_adyge_fiftylangmale_small(),
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Utility/corpus_preparation.py", line 52, in prepare_tts_corpus
train_aligner(train_dataset=aligner_datapoints,
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/Aligner/autoaligner_train_loop.py", line 194, in train_loop
loss.backward()
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
torch.autograd.backward(
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
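For reference, PyTorch raises this error when `backward()` is called on a loss that is not connected to the autograd graph. The minimal sketch below uses toy tensors: `build_loss` and `backward_ok` are hypothetical helpers, not Toucan code, and `mel` merely stands in for a spectrogram. It suggests that an explicit move to the CPU should not by itself cut the graph, whereas `detach()`, a NumPy round trip, or computing under `torch.no_grad()` would:

```python
import torch


def build_loss(path: str) -> torch.Tensor:
    """Toy loss whose intermediate 'mel' tensor is routed through `path` (hypothetical)."""
    weight = torch.randn(4, requires_grad=True)  # stand-in for model parameters
    if path == "no_grad":
        with torch.no_grad():  # the graph is never recorded here
            mel = weight * 2.0
    else:
        mel = weight * 2.0  # stand-in for a predicted spectrogram
    if path == "cpu":
        mel = mel.cpu()  # device copy is differentiable (a no-op if already on CPU)
    elif path == "detach":
        mel = mel.detach()  # explicitly cuts the graph
    elif path == "numpy":
        mel = torch.from_numpy(mel.detach().numpy())  # NumPy round trip also cuts it
    return mel.pow(2).mean()


def backward_ok(path: str) -> bool:
    """Return True if backward() succeeds, printing the RuntimeError otherwise."""
    loss = build_loss(path)
    try:
        loss.backward()
        return True
    except RuntimeError as err:
        print(f"{path}: {err}")
        return False


for p in ("cpu", "detach", "numpy", "no_grad"):
    print(p, backward_ok(p))
```

If the real loss reports `requires_grad == False` just before `loss.backward()`, one of these graph-cutting patterns, or a `no_grad`/`inference_mode` context that is still active, would be a plausible culprit.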
After this error, I reran the same command without deleting any cached files; I assume the original (non-finetuned) aligner model was used this time.
This time it started training the main model, but it crashed after several epochs, seemingly around the point where it runs inference for evaluation.
This is the error from the second run:
...
loading model Models/adyge_Fiftylangmale_small/checkpoint_936.pt
averaging...
saving model...
...done!
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 156/156 [01:16<00:00, 2.04it/s]
EPOCH COMPLETE
0%| | 0/156 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/haroon/git_repos/IMS-Toucan_August_2024/run_training_pipeline.py", line 110, in <module>
pipeline_dict[args.pipeline](gpu_id=args.gpu_id,
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Recipes/finetuning_example_multilingual_adyge.py", line 118, in run
train_loop(net=model,
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/toucantts_train_loop_arbiter.py", line 55, in train_loop
mono_language_loop(net=net,
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/toucantts_train_loop.py", line 133, in train_loop
regression_loss, stochastic_loss, duration_loss, pitch_loss, energy_loss = net(
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/ToucanTTS.py", line 299, in forward
energy_loss = self._forward(text_tensors=text_tensors,
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/ToucanTTS.py", line 425, in _forward
stochastic_loss, _ = self.flow_matching_decoder.compute_loss(x1=gold_speech.transpose(1, 2),
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/flow_matching.py", line 120, in compute_loss
loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), c),
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/dit_wrapper.py", line 173, in forward
x = block(x, c, t, mask)
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/dit_wrapper.py", line 33, in forward
x = self.block(x, c, x_mask)
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/dit.py", line 128, in forward
x = x + self.attn(self.norm1(x.transpose(1, 2)).transpose(1, 2), attn_mask)
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/dit.py", line 73, in forward
x = self.attention(q, k, v, mask=attn_mask)
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/dit.py", line 84, in attention
query = self.query_rotary_pe(query) # [b, n_head, t, c // n_head]
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/haroon/python_virtual_envs/IMS-Toucan_August_2024/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/haroon/git_repos/IMS-Toucan_August_2024/Modules/ToucanTTS/dit.py", line 210, in forward
x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
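The error message itself hints at a workaround. This minimal sketch uses toy tensors (`cos_cached` only mimics the attribute named in the traceback; `apply_cache` and `try_apply` are hypothetical helpers). It reproduces the failure by building a cache under `torch.inference_mode()` and then multiplying it with a tensor that requires grad, and shows that a `clone()` taken outside inference mode yields a normal tensor that autograd accepts:

```python
import torch

weight = torch.randn(8, requires_grad=True)  # trainable input, like x_rope in the traceback

# A cache built while inference mode is active (e.g. during an evaluation pass).
with torch.inference_mode():
    cos_cached = torch.cos(torch.arange(8, dtype=torch.float32))


def apply_cache(x: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    # Multiplication saves `cache` for the backward pass of `x`,
    # which autograd refuses to do for inference tensors.
    return x * cache


def try_apply(cache: torch.Tensor) -> bool:
    """Return True if a forward/backward pass through the cache succeeds."""
    try:
        apply_cache(weight, cache).sum().backward()
        return True
    except RuntimeError as err:
        print(err)
        return False


print("cache is inference tensor:", cos_cached.is_inference())
print("training step with inference cache:", try_apply(cos_cached))

fixed = cos_cached.clone()  # cloned outside inference mode, so it is a normal tensor
print("clone is inference tensor:", fixed.is_inference())
print("training step with cloned cache:", try_apply(fixed))
```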
By the way, I successfully finetuned on the same dataset before your major changes of the past 3-4 months, i.e. with the previous major version.
I also tried to run on the CPU by not providing --gpu_id, but that does not work, since several places in the code use the GPU explicitly; I got several errors related to that.
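A common way to make such code device-agnostic is to resolve the device once and, elsewhere, derive it from existing tensors instead of hard-coding "cuda". The sketch below is hypothetical: `resolve_device` is not part of IMS-Toucan, and it assumes `gpu_id` arrives as a string (or `None` when the flag is omitted). It only illustrates the pattern:

```python
from typing import Optional

import torch


def resolve_device(gpu_id: Optional[str]) -> torch.device:
    """Hypothetical helper: map a --gpu_id style argument to a torch.device.

    Falls back to the CPU when no id is given, when "cpu" is passed, or when
    CUDA is unavailable, instead of hard-coding "cuda" at individual call sites.
    """
    if gpu_id is None or gpu_id == "cpu" or not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(f"cuda:{int(gpu_id)}")


device = resolve_device(None)
model = torch.nn.Linear(4, 2).to(device)  # parameters live on the chosen device
batch = torch.randn(3, 4, device=device)  # new tensors are created on it too
mask = torch.ones(batch.shape[0], device=batch.device)  # derive the device from inputs
print(device, model(batch).device, mask.device)
```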
A couple of last comments:
Minor issue: Recipes/finetuning_example_multilingual.py uses torch, but "import torch" is missing.
torchvision is missing from requirements.txt; I had to add:
torchvision~=0.16.2
Best regards!
I looked through the Aligner and tried to reproduce the problem, or at least figure out what could possibly cause it, but I didn't find anything. I don't think it's related to the CPU hack; there must have been something in that tensor that has no grad function, but I don't understand how it could have gotten there.
But at least the regular aligner worked, so you didn't necessarily need the finetuned one, although finetuning it would probably improve quality a lot. The second error is also a very weird one. When a tensor is created while torch.inference_mode is active, it does not have the hooks for the autograd engine and cannot be used for training. So it's important to make sure that inference mode is only ever active when you're actually doing inference. It seems that the cached cos/sin tensors of the rotary positional embedding in the diffusion transformer's attention (cos_cached and sin_cached in Modules/ToucanTTS/dit.py) were created in inference mode. I don't see any place in the code where this could have happened.
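One defensive pattern for a lazily built cache is to rebuild it whenever it turns out to be an inference tensor while training. The class below is a hypothetical, simplified rotary-embedding cache (rotate-half formulation on a `[seq_len, dim]` input), not the actual code in dit.py; it only sketches the guard under those assumptions:

```python
import torch


class RotaryCacheSketch(torch.nn.Module):
    """Hypothetical, simplified rotary-embedding cache (not the code in dit.py)."""

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.base = base
        self.cos_cached = None  # built lazily on first use
        self.sin_cached = None

    def _build(self, seq_len: int, device: torch.device) -> None:
        exponents = torch.arange(0, self.dim, 2, device=device).float() / self.dim
        inv_freq = 1.0 / (self.base ** exponents)                 # [dim // 2]
        positions = torch.arange(seq_len, device=device).float()  # [seq_len]
        angles = torch.outer(positions, inv_freq)                 # [seq_len, dim // 2]
        angles = torch.cat([angles, angles], dim=-1)              # [seq_len, dim]
        self.cos_cached = angles.cos()
        self.sin_cached = angles.sin()

    def _is_stale(self, x: torch.Tensor) -> bool:
        if self.cos_cached is None:
            return True
        if self.cos_cached.shape[0] < x.shape[0] or self.cos_cached.device != x.device:
            return True
        # A cache built under torch.inference_mode() is an inference tensor and
        # would be rejected by autograd in a later training step, so rebuild it.
        return self.cos_cached.is_inference() and not torch.is_inference_mode_enabled()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [seq_len, dim]
        if self._is_stale(x):
            self._build(x.shape[0], x.device)
        seq_len, half = x.shape[0], self.dim // 2
        rotated = torch.cat([-x[..., half:], x[..., :half]], dim=-1)
        return x * self.cos_cached[:seq_len] + rotated * self.sin_cached[:seq_len]
```

Building the tables once in `__init__` (e.g. as registered buffers), or running the evaluation pass under `torch.no_grad()` rather than `torch.inference_mode()`, would avoid the problem as well: tensors created under `no_grad` are ordinary tensors and can be reused for training.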
I'm not sure how to find this problem, since I cannot reproduce it.
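To narrow down the first error on a setup where it does reproduce, a check like the following could be placed right before `loss.backward()` (line 194 of autoaligner_train_loop.py in the traceback above). `check_loss_before_backward` is a hypothetical debugging helper, not part of IMS-Toucan; it only inspects global grad/inference state and the loss tensor itself:

```python
import torch


def check_loss_before_backward(loss: torch.Tensor) -> None:
    """Hypothetical debugging helper: explain why loss.backward() would fail."""
    if loss.requires_grad:
        return  # graph is intact, backward() can proceed
    reasons = []
    if not torch.is_grad_enabled():
        reasons.append("grad mode is disabled (a torch.no_grad block still active?)")
    if torch.is_inference_mode_enabled():
        reasons.append("inference mode is active")
    if loss.is_inference():
        reasons.append("loss is an inference tensor")
    if not reasons:
        reasons.append("some input to the loss was detached or never required grad")
    raise RuntimeError("loss is not attached to the autograd graph: " + "; ".join(reasons))
```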