Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

moe&distribute #3

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified __pycache__/args.cpython-311.pyc
Binary file not shown.
Binary file modified __pycache__/my_dataset.cpython-311.pyc
Binary file not shown.
17 changes: 17 additions & 0 deletions aaa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from models.vit_moe import vit_base_patch16_224_in21k_multi_long_tail_MoE
import torch
model = vit_base_patch16_224_in21k_multi_long_tail_MoE(
embed_dim=512,
depth=4,
num_heads=2,
num_classes=2,
multi_tasks=6,long_tail=[True,False]*3,alpha=0.6,multi_gate=True).cuda()
a=torch.FloatTensor(4,256,512).cuda()
# model.eval()
out=model(a,task_id=None)
print(len(out),type(out),out.keys())
print(out[list(out.keys())[0]].size())
# print((out[0]==out[1]).sum()/out[1].numel())

# print(model.heads)
# print(model.heads)
42 changes: 22 additions & 20 deletions args.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@ def set_type(x,target_type):
parser.add_argument('--lr-head', type=lambda x: set_type(x,float), default=0.00005)
parser.add_argument('--loss-weights',type=lambda x: set_type(x,float), default=1)

# 数据集所在根目录
# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--model-name', default='', help='create model name')

# 预训练权重路径,如果不想载入就设置为空字符
parser.add_argument('--weights', type=str, default='',
help='initial weights path')
parser.add_argument('--freeze-layers', type=bool, default=False)
Expand All @@ -28,36 +23,41 @@ def set_type(x,target_type):
parser.add_argument('--base-path', default='/public_bme/data/jianght/datas/Pathology/class2')
parser.add_argument('--train-csv', default='train.csv')
parser.add_argument('--valid-csv', default='test.csv')
parser.add_argument('--positive-csv', default='test.csv')
parser.add_argument('--negative-csv', default='test.csv')
parser.add_argument('--head-idx', default=None, type=int)
parser.add_argument('--img-batch', default=100, type=int)

parser.add_argument('--img-batch', default=100, type=int,help=' image numbers of a sample')

parser.add_argument('--arch', default='resnet34')
parser.add_argument('--arch', default='resnet34',help='name of resnet,such as resnet34 or resnet50')
parser.add_argument('--res-weights', default='')
parser.add_argument('--res-savedir', default='./resnet')


parser.add_argument('--multi-tasks',type=int,default=2,help='num of tasks')
parser.add_argument('--num-classes', type=lambda x: set_type(x,int), default=2 ,help='an integer or a list of integers')
parser.add_argument('--loss-fns',type= lambda x: set_type(x,str), default='CELoss' ,help='a str or a list of strs')
parser.add_argument('--tasks',type= lambda x: set_type(x,str), default='fungus,label' ,help='a str or a list of tasks')
parser.add_argument('--long-tails',type= lambda x: set_type(x,str), default='False,False' ,help='whether use long_tails or not')
parser.add_argument('--alpha',type=lambda x: set_type(x,float),default=1,help='the value of alpha if long-tail')
parser.add_argument('--positive-csv', default='test.csv')
parser.add_argument('--negative-csv', default='test.csv')
parser.add_argument('--multi-tasks',type=int,default=2,help='num of multi-tasks')

parser.add_argument('--cont', action='store_true',help='need to count high positive or not')
parser.add_argument('--show-tasks',type=lambda x: set_type(x,int),default=None,help='index of tasks to show the resluts')
parser.add_argument('--needpatch', action='store_true')
parser.add_argument('--backbone',default='vit',choices=['vit','TransMIL','vit_res'])
parser.add_argument('--backbone',default='vit',choices=['vit','TransMIL','vit_res','vit_moe'])
parser.add_argument('--reduction', default='mean',choices=['mean','sum','none'])
parser.add_argument('--logdir',required=True)
parser.add_argument('--cont', action='store_true',help='need to count gaoyangxing or false')
parser.add_argument('--cont-task',type=int,default=1)
parser.add_argument('--show-tasks',type=lambda x: set_type(x,int),default=None,help='index of tasks to show the resluts')
parser.add_argument('--logdir',required=True,help='dir to save error log')




parser.add_argument('--depth',type=int,default=12)
parser.add_argument('--gate-dim',type=int,default=None)
parser.add_argument('--moe_experts',type=int,default=16)
parser.add_argument("--local-rank","--local_rank", help="local device id on current node",type=int,default=None)
parser.add_argument('--use_weight', action='store_true',help='need to balance the loss or not')
parser.add_argument('--moe_top_k',type=lambda x: set_type(x,int),default=4,help='top k of per task')


def init_args(args):
check_attrs = ['num_classes','loss_fns','tasks','long_tails','alpha','loss_weights','lr_head']
check_attrs = ['num_classes','loss_fns','tasks','long_tails','alpha','loss_weights','lr_head','moe_top_k']
for attr in check_attrs:
val = getattr(args,attr)
assert isinstance(val, (int, float, str,list)) , f'expect type of {attr} in [int,float,str] ,but get {val} {type(val)}'
Expand All @@ -79,6 +79,9 @@ def init_args(args):
args.negative_csv = os.path.join(args.base_path,args.negative_csv)

args.long_tails = [i == 'True' for i in args.long_tails] if isinstance(args.long_tails,list) else args.long_tails == 'True'
args.task_id = [i for i in range(args.multi_tasks) if args.lr_head[i]>1e-8]
print(args.task_id,'wdada')
# parser.add_argument('--task-id',type= lambda x: set_type(x,int),default=None)

if args.show_tasks is None:
args.show_tasks = list(range(args.multi_tasks))
Expand All @@ -88,7 +91,6 @@ def init_args(args):

def get_args():
args = parser.parse_args()

init_args(args)

return args
Expand Down
Binary file not shown.
Binary file not shown.
Loading