2 Star 1 Fork 0

Zehebi / MAG-CA-Net

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
args.py 4.28 KB
一键复制 编辑 原始数据 按行查看 历史
Zehebi 提交于 2024-01-16 17:41 . initial
import argparse
import torch
from torchvision import transforms
import os
def _str2bool(value):
    '''
    Convert a command-line string into a bool, for use as an argparse ``type=``.

    ``type=bool`` must not be used for flags: argparse calls ``bool(string)``,
    and every non-empty string (including ``"False"`` and ``"0"``) is truthy,
    so ``--flag False`` would silently yield ``True``.

    :param value: the raw option string (a bool is passed through unchanged).
    :return: the parsed boolean; an empty string maps to False, as ``bool('')`` did.
    :raises argparse.ArgumentTypeError: if the string is not a recognised
        boolean spelling (argparse turns this into a usage error).
    '''
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected, got {!r}".format(value))


def get_args(argv=None):
    '''
    Build the command-line parser for the training run and parse the arguments.

    :param argv: optional list of argument strings to parse. When None (the
        default), argparse reads ``sys.argv[1:]``, exactly as before; passing
        a list lets callers/tests build a namespace without the real command line.
    :return: argparse.Namespace with the options for the training log/run.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--py_name", default='train_relation_attention.py', type=str, help=" ")
    parser.add_argument("--input_size", default=224, type=int, help=" ")
    parser.add_argument("--data_root", default='/home/zhb/Desktop/experiment/EUS_bulk', type=str, help=" ")
    parser.add_argument("--batch_size", default=16, type=int, help=" ")
    parser.add_argument("--num_classes", default=3, type=int, help=" ")
    parser.add_argument("--model_structure", default='RelationNet_s1', type=str, help="Model name")
    parser.add_argument("--change_info", default='Add Relation attention module to Net', type=str, help="what has been changed change")
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
    parser.add_argument("--log_path", default='./logs/relationNet', type=str, help="Path for saving logs")
    parser.add_argument("--pth_path", default='../runs/relationNet', type=str, help="Path for saving pth")
    parser.add_argument("--ssl_pth", default='../new_save/model1/G_error/netG_error.pth', type=str, help="Path for saving ssl pre-trained model")
    parser.add_argument("--epochs", default=200, type=int, help="Total number of epochs.")
    parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
    parser.add_argument("--threads", default=0, type=int, help="Number of CPU threads for dataloaders.")
    parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
    # Boolean flags use _str2bool: with type=bool, "False"/"0" parsed as True.
    parser.add_argument("--ssl_option", default=False, type=_str2bool, help=" ")
    parser.add_argument("--ngpu", default=0, type=int, help="the nth gpu used for trainig")
    parser.add_argument("--resume_option", default=False, type=_str2bool, help=" ")
    # NOTE(review): this help text is copied from --ssl_pth; confirm what
    # --resume_pth is meant to point at and update the message.
    parser.add_argument("--resume_pth", default=' ', type=str, help="Path for saving ssl pre-trained model")
    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
    parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16",
                                                 "ViT-L_32", "ViT-H_14"],
                        default="ViT-B_16",
                        help="Which variant to use.")
    parser.add_argument('--split', type=str, default='non-overlap',
                        help="Split method")
    parser.add_argument('--slide_step', type=int, default=12,
                        help="Slide step for overlap split")
    # Single-dash long options are kept as-is for compatibility with existing callers.
    parser.add_argument('-pretrained', default=False, type=_str2bool, required=False,
                        help='If True, returns a model pre-trained on ImageNet')
    parser.add_argument('-progress', default=True, type=_str2bool, required=False,
                        help='If True, displays a progress bar of the download to stderr')
    parser.add_argument('-load_weights', default=True, type=_str2bool, required=False, help='')
    parser.add_argument('-cam_path', default='./checkpoint/best_loss.pth', type=str, required=False,
                        help='CAMPath')
    args = parser.parse_args(argv)
    return args
def test_kwargs(pth_name="pass.pth", args=None):
    '''
    Assemble the keyword arguments used by the evaluation/test script.

    :param pth_name: model weight file name, stored under the 'test_model' key.
    :param args: optional namespace already returned by ``get_args()``. When
        None (the default) the command line is parsed here, as before.
    :return: dict with keys 'transforms' (RGB pipeline), 'gray_transforms'
        (single-channel pipeline), 'test_model', 'num_class', 'batch_size'
        and 'data_path'.
    '''
    if args is None:
        args = get_args()

    # NOTE(review): these statistics match the values commonly used for
    # CIFAR-10; confirm they are the intended normalisation for this dataset.
    rgb_mean, rgb_std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
    val_transform = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(args.input_size),
                                        transforms.ToTensor(),
                                        transforms.Normalize(rgb_mean, rgb_std)])

    # Single-channel pipeline. Its crop now follows --input_size like the RGB
    # pipeline; it was hard-coded to 224, which is the --input_size default.
    gray_mean, gray_std = [0.5], [0.5]
    gray_transforms = transforms.Compose([transforms.Resize(256),
                                          transforms.CenterCrop(args.input_size),
                                          transforms.Grayscale(num_output_channels=1),
                                          transforms.ToTensor(),
                                          transforms.Normalize(gray_mean, gray_std)])

    return {
        'transforms': val_transform,
        'gray_transforms': gray_transforms,
        'test_model': pth_name,
        'num_class': args.num_classes,
        'batch_size': args.batch_size,
        # Derived from --data_root instead of a duplicated absolute path. With
        # the default data_root this is '/home/zhb/Desktop/experiment/EUS_bulk/test',
        # the value that was previously hard-coded.
        'data_path': os.path.join(args.data_root, 'test'),
    }
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/HangbinZheng/mag-ca-net.git
git@gitee.com:HangbinZheng/mag-ca-net.git
HangbinZheng
mag-ca-net
MAG-CA-Net
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891