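# NOTE: code-host page residue, not part of the program ("Code pull completed; the page will refresh automatically").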
import argparse
import torch
from torchvision import transforms
import os
def _str2bool(value):
    '''
    Convert a command-line string such as "False", "no" or "1" to a bool.

    argparse with ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string (so ``--ssl_option False`` used to yield True). This
    converter parses the text instead.

    :param value: the raw option text; a bool is passed through unchanged
    :return: the parsed boolean
    :raises argparse.ArgumentTypeError: if the text is not a recognised boolean
    '''
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def get_args(argv=None):
    '''
    Parse the command-line options for a training / evaluation run.

    Boolean options (--ssl_option, --resume_option, -pretrained, -progress,
    -load_weights) take an explicit value such as True/False, yes/no or 1/0.

    :param argv: list of argument strings to parse; None (the default) parses
        sys.argv[1:] as before.
    :return: an argparse.Namespace holding the parsed options
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--py_name", default='train_relation_attention.py', type=str, help=" ")
    parser.add_argument("--input_size", default=224, type=int, help=" ")
    parser.add_argument("--data_root", default='/home/zhb/Desktop/experiment/EUS_bulk', type=str, help=" ")
    parser.add_argument("--batch_size", default=16, type=int, help=" ")
    parser.add_argument("--num_classes", default=3, type=int, help=" ")
    parser.add_argument("--model_structure", default='RelationNet_s1', type=str, help="Model name")
    parser.add_argument("--change_info", default='Add Relation attention module to Net', type=str, help="what has been changed change")
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
    parser.add_argument("--log_path", default='./logs/relationNet', type=str, help="Path for saving logs")
    parser.add_argument("--pth_path", default='../runs/relationNet', type=str, help="Path for saving pth")
    parser.add_argument("--ssl_pth", default='../new_save/model1/G_error/netG_error.pth', type=str, help="Path for saving ssl pre-trained model")
    parser.add_argument("--epochs", default=200, type=int, help="Total number of epochs.")
    parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
    parser.add_argument("--threads", default=0, type=int, help="Number of CPU threads for dataloaders.")
    parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
    parser.add_argument("--ssl_option", default=False, type=_str2bool, help=" ")
    parser.add_argument("--ngpu", default=0, type=int, help="the nth gpu used for trainig")
    parser.add_argument("--resume_option", default=False, type=_str2bool, help=" ")
    # NOTE(review): this help text looks copied from --ssl_pth; confirm what
    # --resume_pth is meant to point at. The default ' ' is a single space.
    parser.add_argument("--resume_pth", default=' ', type=str, help="Path for saving ssl pre-trained model")
    parser.add_argument('--smoothing_value', type=float, default=0.0, help="Label smoothing value\n")
    parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16",
                                                 "ViT-L_32", "ViT-H_14"],
                        default="ViT-B_16",
                        help="Which variant to use.")
    parser.add_argument('--split', type=str, default='non-overlap',
                        help="Split method")
    parser.add_argument('--slide_step', type=int, default=12,
                        help="Slide step for overlap split")
    # Single-dash long options are kept as-is for compatibility with existing
    # command lines; argparse stores them under the name without the dash.
    parser.add_argument('-pretrained', default=False, type=_str2bool, required=False,
                        help='If True, returns a model pre-trained on ImageNet')
    parser.add_argument('-progress', default=True, type=_str2bool, required=False,
                        help='If True, displays a progress bar of the download to stderr')
    parser.add_argument('-load_weights', default=True, type=_str2bool, required=False, help='')
    parser.add_argument('-cam_path', default='./checkpoint/best_loss.pth', type=str, required=False,
                        help='CAMPath')
    args = parser.parse_args(argv)
    return args
def test_kwargs(pth_name="pass.pth", args=None):
    '''
    Build the keyword arguments used to configure a test / evaluation run.

    NOTE(review): the name starts with "test_", so pytest may try to collect it
    if this function is imported into a test module.

    :param pth_name: checkpoint file name, stored under 'test_model'
    :param args: an already-parsed options namespace (as returned by get_args);
        when None, the command line is parsed with get_args()
    :return: dict with keys 'transforms', 'gray_transforms', 'test_model',
        'num_class', 'batch_size' and 'data_path'
    '''
    if args is None:
        args = get_args()

    # RGB pipeline: resize to 256, center-crop to the network input size,
    # convert to a tensor and normalize per channel.
    # NOTE(review): these mean/std values match the ones commonly used for
    # CIFAR-10; confirm they are appropriate for this dataset.
    rgb_mean, rgb_std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
    val_transform = transforms.Compose([transforms.Resize(256),
                                        transforms.CenterCrop(args.input_size),
                                        transforms.ToTensor(),
                                        transforms.Normalize(rgb_mean, rgb_std)])

    # Single-channel pipeline. It crops to args.input_size as well, so both
    # pipelines produce the same spatial size when --input_size is changed.
    gray_mean, gray_std = [0.5], [0.5]
    gray_transforms = transforms.Compose([transforms.Resize(256),
                                          transforms.CenterCrop(args.input_size),
                                          transforms.Grayscale(num_output_channels=1),
                                          transforms.ToTensor(),
                                          transforms.Normalize(gray_mean, gray_std)])

    kwargs = {
        'transforms': val_transform,
        'gray_transforms': gray_transforms,
        'test_model': pth_name,
        'num_class': args.num_classes,
        'batch_size': args.batch_size,
        # Derived from --data_root rather than a duplicated absolute path; with
        # the default data_root this is '/home/zhb/Desktop/experiment/EUS_bulk/test'.
        'data_path': os.path.join(args.data_root, 'test'),
    }
    return kwargs
# NOTE: code-host page residue, not part of the program. The hosting page withheld the
# rest of this file as "possibly inappropriate content" and offered an appeal, so the
# original source may continue beyond this point.