
Predict fails on a model containing the Attention layer. #20429

Open
HGS-mbayer opened this issue Oct 30, 2024 · 5 comments
@HGS-mbayer

Running predict on a model containing an Attention layer raises a RuntimeError due to a dimension mismatch.

  • Keras 3.6.0 (issue occurs with 3.5.0 too)
  • Backend is Torch with GPU support (2.5.1+cu124)
  • Windows 11
  • Python 3.10.11

Example Code

Here is a dummy model to reproduce the issue.

import keras
import numpy as np

INPUT_SHAPE = (128, 128, 3)
NUM_CLASSES = 3


def create_model(dims: tuple[int, int, int], num_classes: int):
    width, height, bands = dims

    inputs = keras.layers.Input((width, height, bands))

    conv1 = keras.layers.Conv2D(8, (3, 3), padding='same')(inputs)
    bn = keras.layers.BatchNormalization()(conv1)
    act = keras.layers.Activation('relu')(bn)
    pool1 = keras.layers.MaxPooling2D((2, 3), strides=(2, 2))(act)
    attention = keras.layers.Attention(use_scale=False, score_mode='dot')(
        [pool1, pool1]
    )
    output = keras.layers.Conv2D(
        num_classes, (1, 1), padding='same', activation='softmax'
    )(attention)

    model = keras.models.Model(inputs=inputs, outputs=output)

    return model


model = create_model(INPUT_SHAPE, NUM_CLASSES)

data = np.random.rand(1, *INPUT_SHAPE)

output = model.predict(data)

Training also appears to fail:

model.compile(optimizer=keras.optimizers.Adam(), loss=keras.losses.MeanSquaredError())
model.fit(data, data)

Traceback

RuntimeError                              Traceback (most recent call last)
Cell In[19], line 33
     29 model = create_model(INPUT_SHAPE, NUM_CLASSES)
     31 data = np.random.rand(1, *INPUT_SHAPE)
---> 33 output = model.predict(data)

File ...\env\lib\site-packages\keras\src\utils\traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File ...\env\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ...\env\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ...\env\lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ...\env\lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

RuntimeError: Exception encountered when calling Attention.call().

permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 4 is not equal to len(dims) = 3

Arguments received by Attention.call():
  • inputs=['torch.Tensor(shape=torch.Size([1, 64, 63, 8]), dtype=float32)', 'torch.Tensor(shape=torch.Size([1, 64, 63, 8]), dtype=float32)']
  • mask=['None', 'None']
  • training=False
  • return_attention_scores=False
  • use_causal_mask=False
@mehtamansi29
Collaborator

Hi @HGS-mbayer -

When using an Attention layer together with Conv2D, here are the steps to consider when building the model:

  • Pass the image (2D/3D with channels), with height (h) and width (w), through the Conv2D and MaxPooling layers.
  • Reshape the MaxPooling output so the spatial dimensions collapse into a single sequence axis (channels x H.W).
  • Transpose that tensor, because the attention block expects 3D tensors of shape (batch, sequence, features).
  • Feed the transposed tensor into the attention block.
  • After attention, transpose and reshape back to the input layout and sum it with the residual input.

Here is a model using the Attention layer with a convolutional network.

import keras


def create_model(dims: tuple[int, int, int], num_classes: int):
    width, height, bands = dims
    inputs = keras.layers.Input(shape=(width, height, bands))
    conv2d = keras.layers.Conv2D(8, (3, 3), activation='relu')(inputs)
    maxpool = keras.layers.MaxPooling2D(pool_size=(2, 2))(conv2d)
    # Flatten the spatial dims into a single axis: batch x (H.W) x channels
    bs, h, w, c = keras.ops.shape(maxpool)
    ip_attr = keras.layers.Reshape((h * w, c))(maxpool)
    ip_attr = keras.layers.Permute((2, 1))(ip_attr)
    ip_attr = keras.layers.Normalization()(ip_attr)
    # att_out, att_map = keras.layers.MultiHeadAttention(num_heads=4, key_dim=bands)(ip_attr, ip_attr, return_attention_scores=True)
    att_out = keras.layers.Attention()([ip_attr, ip_attr])
    att_out = keras.layers.Add()([ip_attr, att_out])  # residual connection
    att_out = keras.layers.LayerNormalization()(att_out)
    flatten = keras.layers.Flatten()(att_out)
    dense = keras.layers.Dense(64, activation='relu')(flatten)
    outputs = keras.layers.Dense(num_classes, activation='softmax')(dense)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model
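
For example, the model above can be built and run like this (a minimal usage sketch, not part of the original comment):

import numpy as np

model = create_model((128, 128, 3), 3)
preds = model.predict(np.random.rand(1, 128, 128, 3))
print(preds.shape)  # (1, 3)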

Attached a gist for reference here.

@HGS-mbayer
Author

Your example works fine for me; however, I have models that were trained using Tensorflow 2.9.1 (tf-keras 2.6.0 for that version of Tensorflow, if I'm not mistaken). They worked fine until upgrading to Keras 3.6.0 with the torch backend. I went back and tried (I think) every version of Keras 3.0+, and they all fail with the following code snippet.

Here is another example that is more closely related to my architecture:

import keras
import numpy as np

def create_model(dims: tuple[int, int, int], num_classes: int):
    width, height, bands = dims
    ip = (width, height, bands)
    inputs = keras.layers.Input(shape=ip)
    conv2d_1 = keras.layers.Conv2D(8, (3, 3), activation='relu')(inputs)
    maxpool = keras.layers.MaxPooling2D(pool_size=(2, 2))(conv2d_1)
    att_out = keras.layers.Attention()([maxpool, maxpool])
    conv2d_2 = keras.layers.Conv2D(8, (3, 3), activation='relu')(att_out)
    output2 = keras.layers.Conv2D(num_classes, (1, 1), activation='softmax')(conv2d_2)
    model = keras.Model(inputs=inputs, outputs=output2)
    return model

model = create_model((464, 464, 4), 2)
data = np.random.randn(1, *model.input_shape[1:])
test = model.predict(data)

Using Tensorflow 2.15.0 (the final version of Tensorflow to use Keras 2), the same code snippet succeeds. If I upgrade to the latest Tensorflow, which also brings in Keras 3, then it fails as before.

import numpy as np
import tensorflow as tf


def create_model(dims: tuple[int, int, int], num_classes: int):
    width, height, bands = dims
    ip = (width, height, bands)
    inputs = tf.keras.layers.Input(shape=ip)
    conv2d_1 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu')(inputs)
    maxpool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv2d_1)
    att_out = tf.keras.layers.Attention()([maxpool, maxpool])
    conv2d_2 = tf.keras.layers.Conv2D(8, (3, 3), activation='relu')(att_out)
    output2 = tf.keras.layers.Conv2D(num_classes, (1, 1), activation='softmax')(
        conv2d_2
    )
    model = tf.keras.Model(inputs=inputs, outputs=output2)
    return model


model = create_model((464, 464, 4), 2)
data = np.random.randn(1, *model.input_shape[1:])
test = model.predict(data)

I guess what I'm trying to get to the bottom of is this:

  1. Is there actually a bug preventing these old models from working?
  2. Is the old model architecture wrong in some way and should have never worked?

Thanks again for your assistance!

@mehtamansi29 added the keras-team-review-pending (Pending review by a Keras team member) label on Nov 12, 2024
@divyashreepathihalli
Collaborator

The root of the error is that the shape of the tensor produced by the convolutional and pooling layers does not align with what the Attention layer expects. The issue arises when the convolutional output is not reshaped into a form compatible with the Attention layer, leading to dimension mismatch errors.

  • Reshape the MaxPooling output: the Reshape layer transforms the 4D output of the pooling layer to a 3D format with shape (batch_size, sequence_length, feature_dim).
  • Apply Attention: the Attention layer then works on the reshaped data.
  • Reshape back: the output from the Attention layer is reshaped back to 4D so that it can be fed into subsequent convolutional layers.

The architecture may have had a dimensional mismatch that was handled differently in previous versions; with Keras 3 the error is more apparent.
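
Applied to the model from the original report, that sequence looks roughly like this (an illustrative sketch based on the reproduction above, not an official fix; only the two Reshape layers are new):

import keras
import numpy as np

INPUT_SHAPE = (128, 128, 3)
NUM_CLASSES = 3


def create_model_reshaped(dims: tuple[int, int, int], num_classes: int):
    width, height, bands = dims
    inputs = keras.layers.Input((width, height, bands))

    conv1 = keras.layers.Conv2D(8, (3, 3), padding='same')(inputs)
    bn = keras.layers.BatchNormalization()(conv1)
    act = keras.layers.Activation('relu')(bn)
    pool1 = keras.layers.MaxPooling2D((2, 3), strides=(2, 2))(act)

    # 4D -> 3D: flatten the spatial dims into a sequence axis.
    _, h, w, c = pool1.shape
    seq = keras.layers.Reshape((h * w, c))(pool1)

    # Attention now receives (batch_size, sequence_length, feature_dim).
    att = keras.layers.Attention(use_scale=False, score_mode='dot')([seq, seq])

    # 3D -> 4D: restore the spatial layout for the following Conv2D.
    att = keras.layers.Reshape((h, w, c))(att)

    output = keras.layers.Conv2D(
        num_classes, (1, 1), padding='same', activation='softmax'
    )(att)
    return keras.models.Model(inputs=inputs, outputs=output)


model = create_model_reshaped(INPUT_SHAPE, NUM_CLASSES)
output = model.predict(np.random.rand(1, *INPUT_SHAPE))
print(output.shape)  # (1, 64, 63, 3)

With use_scale=False and score_mode='dot' the Attention layer has no trainable weights, so the added Reshape layers do not introduce or change any parameters.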

@mehtamansi29
Collaborator

  • Reshape the MaxPooling output: the Reshape layer transforms the 4D output of the pooling layer to a 3D format with shape (batch_size, sequence_length, feature_dim).
  • Apply Attention: the Attention layer then works on the reshaped data.
  • Reshape back: the output from the Attention layer is reshaped back to 4D so that it can be fed into subsequent convolutional layers.

Thanks @divyashreepathihalli for the explanation.

Here is a model using the Attention layer with a convolutional network.

This is the convolution-with-Attention model architecture, created in the same way as you described. I'll create a tutorial example for a convolutional network with an Attention layer.

@HGS-mbayer
Author

Thanks for taking the time on this issue. The more I think about this, the more I can't help but believe a regression has occurred.

  1. This used to work, and there was no behavior deprecation, so it should still work. In other words, you could "just add it" before and should be able to "just add it" now too.
  2. The trained models from before work great, which makes me believe that the layer was working properly before.
    • There must have been a mechanism in place to deal with the incoming shape, or I would have expected this to fail before. It seems a little "against the grain", and not obvious, to have to worry about shape manipulation in order to use this layer, as in your example (see the sketch after this snippet):

      bs, h, w, c = keras.ops.shape(maxpool)
      ip_attr = keras.layers.Reshape((h * w, c))(maxpool)  # batch x (H.W) x channels
      ip_attr = keras.layers.Permute((2, 1))(ip_attr)
      ip_attr = keras.layers.Normalization()(ip_attr)
      att_out = keras.layers.Attention()([ip_attr, ip_attr])
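
One way to keep that shape manipulation out of the model definition is a small helper that wraps the reshapes around the Attention call. This is only a sketch (the attention_2d helper below is illustrative, not something Keras provides), assuming a channels-last 4D feature map:

import keras


def attention_2d(feature_map, **attention_kwargs):
    # Flatten the spatial dims into a sequence, apply self-attention,
    # then restore the original spatial layout.
    _, h, w, c = feature_map.shape
    seq = keras.layers.Reshape((h * w, c))(feature_map)
    att = keras.layers.Attention(**attention_kwargs)([seq, seq])
    return keras.layers.Reshape((h, w, c))(att)


# Usage on a small input. Note that flattening H*W into one sequence makes
# the attention matrix (H*W) x (H*W), which grows quickly with image size.
inputs = keras.layers.Input(shape=(64, 64, 4))
x = keras.layers.Conv2D(8, (3, 3), activation='relu')(inputs)
x = keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
x = attention_2d(x)  # in place of keras.layers.Attention()([x, x])
x = keras.layers.Conv2D(8, (3, 3), activation='relu')(x)
outputs = keras.layers.Conv2D(2, (1, 1), activation='softmax')(x)
model = keras.Model(inputs, outputs)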
