[CodeCamp2023-340] New Version of config Adapting MobileNet Algorithm (#1774)

* Add new-format configs adapting MobileNetV2 and MobileNetV3.
* Add a base model config for MobileNetV3 and make all MobileNetV3 training configs inherit from it.
* Remove the directory _base_/models/mobilenet_v3.
Showing 13 changed files with 554 additions and 0 deletions.
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.dataset import DefaultSampler

from mmpretrain.datasets import CIFAR10, PackInputs, RandomCrop, RandomFlip
from mmpretrain.evaluation import Accuracy

# dataset settings
dataset_type = CIFAR10
data_preprocessor = dict(
    num_classes=10,
    # RGB format normalization parameters
    mean=[125.307, 122.961, 113.8575],
    std=[51.5865, 50.847, 51.255],
    # loaded images are already RGB format
    to_rgb=False)

train_pipeline = [
    dict(type=RandomCrop, crop_size=32, padding=4),
    dict(type=RandomFlip, prob=0.5, direction='horizontal'),
    dict(type=PackInputs),
]

test_pipeline = [
    dict(type=PackInputs),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_root='data/cifar10',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_root='data/cifar10/',
        split='test',
        pipeline=test_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator
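Dataset files like the one above are not run on their own; under the new-format config system a training config pulls them in with read_base(), the same pattern the MobileNetV2 config later in this commit uses. A minimal sketch of that composition, where the dataset module name is assumed and the model/schedule modules are placeholders, not files added by this commit:

from mmengine.config import read_base

with read_base():
    from .._base_.datasets.cifar10_bs16 import *    # assumed name of the dataset file above
    from .._base_.models.some_model import *        # placeholder model config
    from .._base_.schedules.some_schedule import *  # placeholder schedule config
    from .._base_.default_runtime import *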
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.dataset import DefaultSampler

from mmpretrain.datasets import (AutoAugment, CenterCrop, ImageNet,
                                 LoadImageFromFile, PackInputs, RandomErasing,
                                 RandomFlip, RandomResizedCrop, ResizeEdge)
from mmpretrain.evaluation import Accuracy

# dataset settings
dataset_type = ImageNet
data_preprocessor = dict(
    num_classes=1000,
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]

train_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=RandomResizedCrop, scale=224, backend='pillow'),
    dict(type=RandomFlip, prob=0.5, direction='horizontal'),
    dict(
        type=AutoAugment,
        policies='imagenet',
        hparams=dict(pad_val=[round(x) for x in bgr_mean])),
    dict(
        type=RandomErasing,
        erase_prob=0.2,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=bgr_mean,
        fill_std=bgr_std),
    dict(type=PackInputs),
]

test_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'),
    dict(type=CenterCrop, crop_size=224),
    dict(type=PackInputs),
]

train_dataloader = dict(
    batch_size=128,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = dict(
    batch_size=128,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        split='val',
        pipeline=test_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
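The closing comment above suggests configuring a dedicated test set rather than reusing the validation loader. A hedged sketch of what that override could look like in a derived config; ann_file and data_prefix are assumed keyword arguments of mmpretrain's ImageNet dataset, not something defined in this diff:

test_dataloader = dict(
    batch_size=128,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/test.txt',  # assumed annotation file for the test split
        data_prefix='test',        # assumed image folder for the test split
        pipeline=test_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=False),
)
test_evaluator = val_evaluator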
mmpretrain/configs/_base_/datasets/imagenet_bs32_pil_resize.py (60 additions, 0 deletions)
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.dataset import DefaultSampler

from mmpretrain.datasets import (CenterCrop, ImageNet, LoadImageFromFile,
                                 PackInputs, RandomFlip, RandomResizedCrop,
                                 ResizeEdge)
from mmpretrain.evaluation import Accuracy

# dataset settings
dataset_type = ImageNet
data_preprocessor = dict(
    num_classes=1000,
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

train_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=RandomResizedCrop, scale=224, backend='pillow'),
    dict(type=RandomFlip, prob=0.5, direction='horizontal'),
    dict(type=PackInputs),
]

test_pipeline = [
    dict(type=LoadImageFromFile),
    dict(type=ResizeEdge, scale=256, edge='short', backend='pillow'),
    dict(type=CenterCrop, crop_size=224),
    dict(type=PackInputs),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=True),
)

val_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        split='val',
        pipeline=test_pipeline),
    sampler=dict(type=DefaultSampler, shuffle=False),
)
val_evaluator = dict(type=Accuracy, topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
                               ImageClassifier, LinearClsHead, MobileNetV2)

# model settings
model = dict(
    type=ImageClassifier,
    backbone=dict(type=MobileNetV2, widen_factor=1.0),
    neck=dict(type=GlobalAveragePooling),
    head=dict(
        type=LinearClsHead,
        num_classes=1000,
        in_channels=1280,
        loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
        topk=(1, 5),
    ))
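Because the model is now a plain Python dict referenced through importable symbols, downstream configs can tweak it with the merge pattern used by mobilenet_v3_large_8xb128_in1k.py later in this diff. A small hypothetical example (the 10-class target is illustrative, not part of this commit):

from mmengine.config import read_base

with read_base():
    from .._base_.models.mobilenet_v2_1x import *

# Retarget the classification head to a hypothetical 10-class dataset.
model.merge(dict(head=dict(num_classes=10, topk=(1, ))))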
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.model.weight_init import NormalInit
from torch.nn.modules.activation import Hardswish

from mmpretrain.models import (CrossEntropyLoss, GlobalAveragePooling,
                               ImageClassifier, MobileNetV3,
                               StackedLinearClsHead)

# model settings
model = dict(
    type=ImageClassifier,
    backbone=dict(type=MobileNetV3, arch='small'),
    neck=dict(type=GlobalAveragePooling),
    head=dict(
        type=StackedLinearClsHead,
        num_classes=1000,
        in_channels=576,
        mid_channels=[1024],
        dropout_rate=0.2,
        act_cfg=dict(type=Hardswish),
        loss=dict(type=CrossEntropyLoss, loss_weight=1.0),
        init_cfg=dict(
            type=NormalInit, layer='Linear', mean=0., std=0.01, bias=0.),
        topk=(1, 5)))
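Per the commit message, the MobileNetV3 training configs now inherit from this single base model config instead of a per-variant model directory; the large variant at the end of this diff shows the real pattern. A small-variant training config would presumably start along these lines (an illustrative sketch, not a file shown in this diff):

from mmengine.config import read_base

with read_base():
    from .._base_.models.mobilenet_v3_small import *
    from .._base_.datasets.imagenet_bs128_mbv3 import *
    from .._base_.default_runtime import *

# Schedule settings (e.g. an RMSprop recipe similar to the large variant below)
# would follow here; omitted because this sketch is illustrative only.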
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.optim import MultiStepLR
from torch.optim import SGD

# optimizer
optim_wrapper = dict(
    optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001))
# learning policy
param_scheduler = dict(
    type=MultiStepLR, by_epoch=True, milestones=[100, 150], gamma=0.1)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=200, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=128)
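The auto_scale_lr note refers to the usual OpenMMLab linear scaling rule: when auto scaling is enabled at launch time, the optimizer LR is multiplied by the ratio of the actual total batch size to base_batch_size. A quick sketch of the arithmetic, assuming that linear rule and an example single-GPU setup:

base_lr = 0.1
base_batch_size = 128        # value declared above
actual_batch_size = 1 * 64   # assumed launch setup: 1 GPU x 64 samples per GPU
scaled_lr = base_lr * actual_batch_size / base_batch_size
print(scaled_lr)             # 0.05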
mmpretrain/configs/_base_/schedules/imagenet_bs256_epochstep.py (20 additions, 0 deletions)
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.optim import StepLR
from torch.optim import SGD

# optimizer
optim_wrapper = dict(
    optimizer=dict(type=SGD, lr=0.045, momentum=0.9, weight_decay=0.00004))

# learning policy
param_scheduler = dict(type=StepLR, by_epoch=True, step_size=1, gamma=0.98)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=256)
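StepLR with by_epoch=True multiplies the LR by gamma every step_size epochs, so this recipe decays gently once per epoch. A quick check of where it ends up after the full 300 epochs:

final_lr = 0.045 * 0.98 ** 300
print(f'{final_lr:.2e}')  # roughly 1.05e-04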
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.
from mmengine.config import read_base

with read_base():
    from .._base_.datasets.imagenet_bs32_pil_resize import *
    from .._base_.default_runtime import *
    from .._base_.models.mobilenet_v2_1x import *
    from .._base_.schedules.imagenet_bs256_epochstep import *
mmpretrain/configs/mobilenet_v3/mobilenet_v3_large_8xb128_in1k.py (40 additions, 0 deletions)
# Copyright (c) OpenMMLab. All rights reserved.
# This is a BETA new format config file, and the usage may change recently.

# Refers to https://pytorch.org/blog/ml-models-torchvision-v0.9/#classification
from mmengine.config import read_base

with read_base():
    from .._base_.models.mobilenet_v3_small import *
    from .._base_.datasets.imagenet_bs128_mbv3 import *
    from .._base_.default_runtime import *

from mmengine.optim import StepLR
from torch.optim import RMSprop

# model settings
model.merge(
    dict(
        backbone=dict(arch='large'),
        head=dict(in_channels=960, mid_channels=[1280]),
    ))
# schedule settings
optim_wrapper = dict(
    optimizer=dict(
        type=RMSprop,
        lr=0.064,
        alpha=0.9,
        momentum=0.9,
        eps=0.0316,
        weight_decay=1e-5))

param_scheduler = dict(type=StepLR, by_epoch=True, step_size=2, gamma=0.973)

train_cfg = dict(by_epoch=True, max_epochs=600, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
# base_batch_size = (8 GPUs) x (128 samples per GPU)
auto_scale_lr = dict(base_batch_size=1024)
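Two quick sanity checks on this recipe: the comment's batch-size math is 8 GPUs x 128 samples = 1024, matching base_batch_size, and with a decay step every 2 epochs the LR is multiplied by gamma about 300 times over the 600 epochs (assuming one multiplication per step_size boundary):

decay_steps = 600 // 2                  # step_size=2 over max_epochs=600
final_lr = 0.064 * 0.973 ** decay_steps
print(f'{final_lr:.2e}')                # roughly 1.74e-05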