This codebase is a PyTorch implementation of various attention mechanisms, CNNs, Vision Transformers, and MLP-like models.
If it is helpful for your work, please give it a ⭐
git clone https://github.com/changzy00/pytorch-attention.git
cd pytorch-attention
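Each attention module is a drop-in nn.Module. A minimal sketch of plugging one into a plain conv block (assuming the commands above were run and the script lives in the repository root so that attention_mechanisms is importable):
import torch
import torch.nn as nn
from attention_mechanisms.se_module import SELayer

# Hypothetical block for illustration: conv -> BN -> ReLU, then channel attention.
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    SELayer(64),  # recalibrates the 64 feature channels produced above
)
x = torch.randn(2, 3, 32, 32)
print(block(x).shape)  # torch.Size([2, 64, 32, 32])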
-
Attention Mechanisms
- 1. Squeeze-and-Excitation Attention
- 2. Convolutional Block Attention Module
- 3. Bottleneck Attention Module
- 4. Double Attention
- 5. Style Attention
- 6. Global Context Attention
- 7. Selective Kernel Attention
- 8. Linear Context Attention
- 9. Gated Channel Attention
- 10. Efficient Channel Attention
- 11. Triplet Attention
- 12. Gaussian Context Attention
- 13. Coordinate Attention
- 14. SimAM
- 15. Dual Attention
-
Convolutional Neural Networks (CNNs)
- 1. NiN Model
- 2. ResNet Model
- 3. WideResNet Model
- 4. DenseNet Model
- 5. PyramidNet Model
- 6. MobileNetV1 Model
- 7. MobileNetV2 Model
- 8. MobileNetV3 Model
- 9. MnasNet Model
- 10. EfficientNetV1 Model
- 11. Res2Net Model
- 12. MobileNeXt Model
- 13. GhostNet Model
- 14. EfficientNetV2 Model
- 15. ConvNeXt Model
- 16. Unet Model
- 17. ESPNet Model
-
Squeeze-and-Excitation Networks (CVPR 2018) pdf
import torch
from attention_mechanisms.se_module import SELayer
x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
attn = SELayer(64)  # channel count must match the input
y = attn(x)
print(y.shape)  # torch.Size([2, 64, 32, 32]) -- the output shape is unchanged
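For intuition, the squeeze-and-excitation operation is compact enough to sketch from scratch. A minimal, hypothetical version (TinySE and its reduction parameter are illustrative, not the repository's own code):
import torch
import torch.nn as nn

class TinySE(nn.Module):
    """Minimal squeeze-and-excitation: global pool -> bottleneck MLP -> scale."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        n, c, _, _ = x.shape
        w = x.mean(dim=(2, 3))            # squeeze: global average pool -> (N, C)
        w = self.fc(w).view(n, c, 1, 1)   # excitation: per-channel gates in [0, 1]
        return x * w                      # reweight the input channels

x = torch.randn(2, 64, 32, 32)
print(TinySE(64)(x).shape)  # torch.Size([2, 64, 32, 32])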
-
CBAM: Convolutional Block Attention Module (ECCV 2018) pdf
import torch
from attention_mechanisms.cbam import CBAM
x = torch.randn(2, 64, 32, 32)
attn = CBAM(64)
y = attn(x)
print(y.shape)
-
BAM: Bottleneck Attention Module (BMVC 2018) pdf
import torch
from attention_mechanisms.bam import BAM
x = torch.randn(2, 64, 32, 32)
attn = BAM(64)
y = attn(x)
print(y.shape)
-
A2-Nets: Double Attention Networks (NeurIPS 2018) pdf
import torch
from attention_mechanisms.double_attention import DoubleAttention
x = torch.randn(2, 64, 32, 32)
attn = DoubleAttention(64, 32, 32)
y = attn(x)
print(y.shape)
-
SRM: A Style-based Recalibration Module for Convolutional Neural Networks (ICCV 2019) pdf
import torch
from attention_mechanisms.srm import SRM
x = torch.randn(2, 64, 32, 32)
attn = SRM(64)
y = attn(x)
print(y.shape)
-
GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond (ICCVW 2019) pdf
import torch
from attention_mechanisms.gc_module import GCModule
x = torch.randn(2, 64, 32, 32)
attn = GCModule(64)
y = attn(x)
print(y.shape)
-
Selective Kernel Networks (CVPR 2019) pdf
import torch
from attention_mechanisms.sk_module import SKLayer
x = torch.randn(2, 64, 32, 32)
attn = SKLayer(64)
y = attn(x)
print(y.shape)
-
Linear Context Transform Block (AAAI 2020) pdf
import torch
from attention_mechanisms.lct import LCT
x = torch.randn(2, 64, 32, 32)
attn = LCT(64, groups=8)
y = attn(x)
print(y.shape)
-
Gated Channel Transformation for Visual Recognition (CVPR 2020) pdf
import torch
from attention_mechanisms.gate_channel_module import GCT
x = torch.randn(2, 64, 32, 32)
attn = GCT(64)
y = attn(x)
print(y.shape)
-
ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks (CVPR 2020) pdf
import torch
from attention_mechanisms.eca import ECALayer
x = torch.randn(2, 64, 32, 32)
attn = ECALayer(64)
y = attn(x)
print(y.shape)
-
Rotate to Attend: Convolutional Triplet Attention Module (WACV 2021) pdf
import torch
from attention_mechanisms.triplet_attention import TripletAttention
x = torch.randn(2, 64, 32, 32)
attn = TripletAttention(64)
y = attn(x)
print(y.shape)
-
Gaussian Context Transformer (CVPR 2021) pdf
import torch
from attention_mechanisms.gct import GCT
x = torch.randn(2, 64, 32, 32)
attn = GCT(64)
y = attn(x)
print(y.shape)
-
Coordinate Attention for Efficient Mobile Network Design (CVPR 2021) pdf
import torch
from attention_mechanisms.coordatten import CoordinateAttention
x = torch.randn(2, 64, 32, 32)
attn = CoordinateAttention(64, 64)
y = attn(x)
print(y.shape)
-
SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks (ICML 2021) pdf
import torch
from attention_mechanisms.simam import simam_module
x = torch.randn(2, 64, 32, 32)
attn = simam_module(64)
y = attn(x)
print(y.shape)
-
Dual Attention Network for Scene Segmentation (CVPR 2019) pdf
import torch
from attention_mechanisms.dual_attention import PAM, CAM
x = torch.randn(2, 64, 32, 32)
# attn = PAM(64)  # position attention branch
attn = CAM()  # channel attention branch
y = attn(x)
print(y.shape)
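In the DANet paper the two branches run in parallel on the same feature map and are fused by element-wise sum. A minimal sketch of that fusion with the modules above:
import torch
from attention_mechanisms.dual_attention import PAM, CAM

x = torch.randn(2, 64, 32, 32)
pam, cam = PAM(64), CAM()
y = pam(x) + cam(x)  # DANet-style fusion of the position and channel branches
print(y.shape)  # torch.Size([2, 64, 32, 32])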
-
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ICLR 2021) pdf
import torch
from vision_transformers.ViT import VisionTransformer
x = torch.randn(2, 3, 224, 224)
model = VisionTransformer()
y = model(x)
print(y.shape)  # [2, 1000]
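The 1000-way output is an ImageNet-style logit vector. A short usage sketch (assuming the default 1000-class head) that turns it into a prediction:
import torch
from vision_transformers.ViT import VisionTransformer

model = VisionTransformer().eval()  # eval mode disables dropout for inference
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)               # [1, 1000]
probs = logits.softmax(dim=-1)      # class probabilities
print(probs.argmax(dim=-1))         # predicted class index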
-
XCiT: Cross-Covariance Image Transformers (NeurIPS 2021) pdf
import torch
from vision_transformers.xcit import xcit_nano_12_p16
x = torch.randn(2, 3, 224, 224)
model = xcit_nano_12_p16()
y = model(x)
print(y.shape)
-
Rethinking Spatial Dimensions of Vision Transformers (ICCV 2021) pdf
import torch
from vision_transformers.pit import pit_ti
x = torch.randn(2, 3, 224, 224)
model = pit_ti()
y = model(x)
print(y.shape)
-
CvT: Introducing Convolutions to Vision Transformers (ICCV 2021) pdf
import torch
from vision_transformers.cvt import cvt_13
x = torch.randn(2, 3, 224, 224)
model = cvt_13()
y = model(x)
print(y.shape)
-
Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction Without Convolutions (ICCV 2021) pdf
import torch
from vision_transformers.pvt import pvt_t
x = torch.randn(2, 3, 224, 224)
model = pvt_t()
y = model(x)
print(y.shape)
-
CMT: Convolutional Neural Networks Meet Vision Transformers (CVPR 2022) pdf
import torch
from vision_transformers.cmt import cmt_ti
x = torch.randn(2, 3, 224, 224)
model = cmt_ti()
y = model(x)
print(y.shape)
-
MetaFormer is Actually What You Need for Vision (CVPR 2022) pdf
import torch
from vision_transformers.poolformer import poolformer_12
x = torch.randn(2, 3, 224, 224)
model = poolformer_12()
y = model(x)
print(y.shape)
-
KVT: k-NN Attention for Boosting Vision Transformers (ECCV 2022) pdf
import torch
from vision_transformers.kvt import KVT
x = torch.randn(2, 3, 224, 224)
model = KVT()
y = model(x)
print(y.shape)
-
MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer (ICLR 2022) pdf
import torch
from vision_transformers.mobilevit import mobilevit_s
x = torch.randn(2, 3, 224, 224)
model = mobilevit_s()
y = model(x)
print(y.shape)
-
Pyramid Pooling Transformer for Scene Understanding (TPAMI 2022) pdf
import torch
from vision_transformers.p2t import p2t_tiny
x = torch.randn(2, 3, 224, 224)
model = p2t_tiny()
y = model(x)
print(y.shape)
-
EfficientFormer: Vision Transformers at MobileNet Speed (NeurIPS 2022) pdf
import torch
from vision_transformers.efficientformer import efficientformer_l1
x = torch.randn(2, 3, 224, 224)
model = efficientformer_l1()
y = model(x)
print(y.shape)
-
When Shift Operation Meets Vision Transformer: An Extremely Simple Alternative to Attention Mechanism (AAAI 2022) pdf
import torch
from vision_transformers.shiftvit import shift_t
x = torch.randn(2, 3, 224, 224)
model = shift_t()
y = model(x)
print(y.shape)
-
CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows (CVPR 2022) pdf
import torch
from vision_transformers.cswin import CSWin_64_12211_tiny_224
x = torch.randn(2, 3, 224, 224)
model = CSWin_64_12211_tiny_224()
y = model(x)
print(y.shape)
-
DilateFormer: Multi-Scale Dilated Transformer for Visual Recognition (TMM 2023) pdf
import torch
from vision_transformers.dilateformer import dilateformer_tiny
x = torch.randn(2, 3, 224, 224)
model = dilateformer_tiny()
y = model(x)
print(y.shape)
-
BViT: Broad Attention based Vision Transformer (TNNLS 2023) pdf
import torch
from vision_transformers.bvit import BViT_S
x = torch.randn(2, 3, 224, 224)
model = BViT_S()
y = model(x)
print(y.shape)
-
MOAT: Alternating Mobile Convolution and Attention Brings Strong Vision Models (ICLR 2023) pdf
import torch
from vision_transformers.moat import moat_0
x = torch.randn(2, 3, 224, 224)
model = moat_0()
y = model(x)
print(y.shape)
-
SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (NeurIPS 2021) pdf
import torch
from vision_transformers.segformer import SegFormer
x = torch.randn(2, 3, 512, 512)
model = SegFormer(num_classes=50)
y = model(x)
print(y.shape)
-
Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers (CVPR 2021) pdf
import torch
from vision_transformers.setr import SETR
x = torch.randn(2, 3, 480, 480)
model = SETR(num_classes=50)
y = model(x)
print(y.shape)
-
Network In Network (ICLR 2014) pdf
import torch
from cnns.NiN import NiN
x = torch.randn(2, 3, 224, 224)
model = NiN()
y = model(x)
print(y.shape)
-
Deep Residual Learning for Image Recognition (CVPR 2016) pdf
import torch
from cnns.resnet import resnet18
x = torch.randn(2, 3, 224, 224)
model = resnet18()
y = model(x)
print(y.shape)
-
Wide Residual Networks (BMVC 2016) pdf
import torch
from cnns.wideresnet import wideresnet
x = torch.randn(2, 3, 224, 224)
model = wideresnet()
y = model(x)
print(y.shape)
-
Densely Connected Convolutional Networks (CVPR 2017) pdf
import torch
from cnns.densenet import densenet121
x = torch.randn(2, 3, 224, 224)
model = densenet121()
y = model(x)
print(y.shape)
-
Deep Pyramidal Residual Networks (CVPR 2017) pdf
import torch
from cnns.pyramidnet import pyramidnet18
x = torch.randn(2, 3, 224, 224)
model = pyramidnet18()
y = model(x)
print(y.shape)
-
MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications (arXiv 2017) pdf
import torch
from cnns.mobilenetv1 import MobileNetv1
x = torch.randn(2, 3, 224, 224)
model = MobileNetv1()
y = model(x)
print(y.shape)
-
MobileNetV2: Inverted Residuals and Linear Bottlenecks (CVPR 2018) pdf
import torch
from cnns.mobilenetv2 import MobileNetv2
x = torch.randn(2, 3, 224, 224)
model = MobileNetv2()
y = model(x)
print(y.shape)
-
Searching for MobileNetV3 (ICCV 2019) pdf
import torch
from cnns.mobilenetv3 import mobilenetv3_small
x = torch.randn(2, 3, 224, 224)
model = mobilenetv3_small()
y = model(x)
print(y.shape)
-
MnasNet: Platform-Aware Neural Architecture Search for Mobile (CVPR 2019) pdf
import torch
from cnns.mnasnet import MnasNet
x = torch.randn(2, 3, 224, 224)
model = MnasNet()
y = model(x)
print(y.shape)
-
EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks (ICML 2019) pdf
import torch
from cnns.efficientnet import EfficientNet
x = torch.randn(2, 3, 224, 224)
model = EfficientNet()
y = model(x)
print(y.shape)
-
Res2Net: A New Multi-scale Backbone Architecture (TPAMI 2019) pdf
import torch
from cnns.res2net import res2net50
x = torch.randn(2, 3, 224, 224)
model = res2net50()
y = model(x)
print(y.shape)
-
Rethinking Bottleneck Structure for Efficient Mobile Network Design (ECCV 2020) pdf
import torch
from cnns.mobilenext import MobileNeXt
x = torch.randn(2, 3, 224, 224)
model = MobileNeXt()
y = model(x)
print(y.shape)
-
GhostNet: More Features from Cheap Operations (CVPR 2020) pdf
import torch
from cnns.ghostnet import ghostnet
x = torch.randn(2, 3, 224, 224)
model = ghostnet()
y = model(x)
print(y.shape)
-
EfficientNetV2: Smaller Models and Faster Training (ICML 2021) pdf
import torch
from cnns.efficientnet import EfficientNetV2
x = torch.randn(2, 3, 224, 224)
model = EfficientNetV2()
y = model(x)
print(y.shape)
-
A ConvNet for the 2020s (CVPR 2022) pdf
import torch
from cnns.convnext import convnext_18
x = torch.randn(2, 3, 224, 224)
model = convnext_18()
y = model(x)
print(y.shape)
-
U-Net: Convolutional Networks for Biomedical Image Segmentation (MICCAI 2015) pdf
import torch
from cnns.unet import Unet
x = torch.randn(2, 3, 512, 512)
model = Unet(10)
y = model(x)
print(y.shape)
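The segmentation models return per-pixel class scores. A short sketch (assuming an output of shape [N, num_classes, H, W]) that converts them into a label mask:
import torch
from cnns.unet import Unet

model = Unet(10)            # 10 segmentation classes
x = torch.randn(2, 3, 512, 512)
y = model(x)                # per-pixel class scores
mask = y.argmax(dim=1)      # label map, expected shape [2, 512, 512]
print(mask.shape)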
-
ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation (ECCV 2018) pdf
import torch
from cnns.espnet import ESPNet
x = torch.randn(2, 3, 512, 512)
model = ESPNet(10)
y = model(x)
print(y.shape)
-
MLP-Mixer: An all-MLP Architecture for Vision (NeurIPS 2021) pdf
import torch
from mlps.mlp_mixer import MLP_Mixer
x = torch.randn(2, 3, 224, 224)
model = MLP_Mixer()
y = model(x)
print(y.shape)
-
Pay Attention to MLPs (NeurIPS 2021) pdf
import torch
from mlps.gmlp import gMLP
x = torch.randn(2, 3, 224, 224)
model = gMLP()
y = model(x)
print(y.shape)
-
Global Filter Networks for Image Classification (NeurIPS 2021) pdf
import torch
from mlps.gfnet import GFNet
x = torch.randn(2, 3, 224, 224)
model = GFNet()
y = model(x)
print(y.shape)
-
Sparse MLP for Image Recognition: Is Self-Attention Really Necessary? (AAAI 2022) pdf
import torch
from mlps.smlp import sMLPNet
x = torch.randn(2, 3, 224, 224)
model = sMLPNet()
y = model(x)
print(y.shape)
-
DynaMixer: A Vision MLP Architecture with Dynamic Mixing (ICML 2022) pdf
import torch
from mlps.dynamixer import DynaMixer
x = torch.randn(2, 3, 224, 224)
model = DynaMixer()
y = model(x)
print(y.shape)
-
Patches Are All You Need? (TMLR 2022) pdf
import torch
from mlps.convmixer import ConvMixer
x = torch.randn(2, 3, 224, 224)
model = ConvMixer(128, 6)
y = model(x)
print(y.shape)
-
Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition (TPAMI 2022) pdf
import torch
from mlps.vip import vip_s7
x = torch.randn(2, 3, 224, 224)
model = vip_s7()
y = model(x)
print(y.shape)
-
CycleMLP: A MLP-like Architecture for Dense Prediction (ICLR 2022) pdf
import torch
from mlps.cyclemlp import CycleMLP_B1
x = torch.randn(2, 3, 224, 224)
model = CycleMLP_B1()
y = model(x)
print(y.shape)
-
Sequencer: Deep LSTM for Image Classification (NeurIPS 2022) pdf
import torch
from mlps.sequencer import sequencer_s
x = torch.randn(2, 3, 224, 224)
model = sequencer_s()
y = model(x)
print(y.shape)
-
MobileViG: Graph-Based Sparse Attention for Mobile Vision Applications (CVPRW 2023) pdf
import torch
from mlps.mobilevig import mobilevig_s
x = torch.randn(2, 3, 224, 224)
model = mobilevig_s()
y = model(x)
print(y.shape)