# (Scraped web-page residue, originally in Chinese: "Code pull complete; the page will refresh automatically.")
import os
os.getcwd()
from model.MAG_CA_Net import MAG_CA_Net
import os
import torch
from PIL import Image
from torchvision import transforms
import argparse
def test(img_path, model, device):
    """Classify a single image with ``model`` and print the predicted class index.

    Args:
        img_path: Path to the image file to classify. It is also forwarded to the
            model as the ``PATH`` keyword argument.
        device: Device the input tensor is moved to (e.g. ``'cuda:0'``);
            presumably the same device ``model`` lives on.
        model: Network to run; called as ``model(img, PATH=img_path)``.

    Returns:
        The predicted class index as an ``int``. The earlier version returned
        ``None``; callers that ignore the result are unaffected.

    Raises:
        FileNotFoundError: If ``img_path`` does not exist.
    """
    # NOTE(review): these are CIFAR-10 channel statistics, while the
    # Resize(256)/CenterCrop(224) pipeline is ImageNet-style. Confirm they match
    # the normalisation used when the checkpoint was trained.
    mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
    _transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not os.path.exists(img_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(img_path))

    # Use a context manager so the file handle is released promptly.
    # Converting to RGB guarantees 3 channels, matching the 3-element mean/std
    # (grayscale/RGBA/palette images would otherwise fail in Normalize).
    with Image.open(img_path) as pil_img:
        img = _transforms(pil_img.convert('RGB'))
    img = torch.unsqueeze(img, dim=0).to(device)  # prepend a batch dimension of size 1

    model.eval()
    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(img, PATH=img_path)

    # Normalise over the last (class) axis. The previous dim=0 normalised over the
    # batch axis; assuming the model returns (1, num_classes) logits (confirm), that
    # yields all ones and argmax always returned 0. For 1-D outputs dim=-1 == dim=0.
    predict = torch.softmax(outputs, dim=-1)
    predict_cla = int(torch.argmax(predict).item())
    print('predicted result: ', predict_cla)
    return predict_cla
def load_model(model, state_dict, device):
    """Restore weights into ``model``, place it on ``device`` and switch it to eval mode.

    Args:
        model: The module that receives the weights.
        state_dict: Parameter/buffer mapping passed to ``model.load_state_dict``.
            With PyTorch's default ``strict=True`` the keys must match exactly.
        device: Target device for the module's parameters and buffers.

    Returns:
        The same ``model`` object. ``nn.Module.to`` and ``nn.Module.eval`` both
        return ``self``, so the chained call below yields the caller's module.
    """
    model.load_state_dict(state_dict)
    return model.to(device).eval()
def str2bool(value):
    """Convert a command-line string into a ``bool`` for argparse ``type=``.

    ``type=bool`` is a trap: argparse calls ``bool(text)``, which is ``True``
    for every non-empty string, including ``'False'``. Flags could therefore
    never be switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got '{}'".format(value))


def build_parser():
    """Build the command-line parser for single-image inference with MAG_CA_Net."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-pretrained', default=False, type=str2bool, required=False,
                        help='If True, returns a model pre-trained on ImageNet')
    parser.add_argument('-progress', default=True, type=str2bool, required=False,
                        help='If True, displays a progress bar of the download to stderr')
    parser.add_argument('-load_weights', default=False, type=str2bool, required=False, help='')
    # Help text means "path to the pretrained model checkpoint".
    parser.add_argument('-path', default='./checkpoint/best_acc_160.pth', type=str, required=False,
                        help='预训练模型路径')
    parser.add_argument('-cam_path', default='./checkpoint/best_loss.pth', type=str, required=False,
                        help='CAMPath')
    # Fall back to CPU so the script still runs on machines without CUDA.
    parser.add_argument('-device', default='cuda:0' if torch.cuda.is_available() else 'cpu',
                        type=str, required=False, help='device')
    # Previously hard-coded; the default keeps the old behaviour.
    parser.add_argument('-img', default='./dataset/test.jpg', type=str, required=False,
                        help='Path of the image to classify')
    return parser


def main(argv=None):
    """Parse arguments, build the network, load the checkpoint and classify one image.

    Args:
        argv: Optional argument list; ``None`` lets argparse read ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    # Bind the instance to a new name instead of shadowing the imported class.
    net = MAG_CA_Net(args)
    # map_location lets a checkpoint saved on GPU load onto the selected device.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    params = torch.load(args.path, map_location=args.device)
    model = load_model(net, params, args.device)
    test(args.img, model, args.device)


if __name__ == '__main__':
    main()
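# (Scraped web-page residue, originally in Chinese: "Some content here may be unsuitable for
# display and has been hidden. You can review and modify it with the editing tools. If you
# confirm the content contains no inappropriate language, pure advertising/traffic diversion,
# violence, vulgar or pornographic material, infringement, piracy, false or worthless content,
# or anything that violates national laws and regulations, click Submit to appeal and we will
# process it as soon as possible.")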