
Zehebi / MAG-CA-Net

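This repository does not declare an open-source license file (LICENSE); before using it, check the project description and the upstream dependencies of its code.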
predict.py 2.06 KB
Zehebi committed on 2024-01-16 17:41: initial
```python
import argparse
import os

import torch
from PIL import Image
from torchvision import transforms

from model.MAG_CA_Net import MAG_CA_Net


def str2bool(value):
    # argparse's type=bool treats any non-empty string (even 'False') as True
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', 't', 'yes', 'y', '1'):
        return True
    if value.lower() in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("boolean value expected, got '{}'".format(value))


def test(img_path, model, device):
    # Per-channel normalization statistics (these match commonly used CIFAR-10 values)
    mean, std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
    _transforms = transforms.Compose([transforms.Resize(256),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean, std)])
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    # Convert to RGB so grayscale or RGBA inputs also yield 3 channels
    img = Image.open(img_path).convert('RGB')
    img = _transforms(img)
    # Add a batch dimension: [3, 224, 224] -> [1, 3, 224, 224]
    img = torch.unsqueeze(img, dim=0).to(device)
    model.eval()
    # Predict the class without tracking gradients
    with torch.no_grad():
        outputs = model(img, PATH=img_path)
    # Softmax over the class (last) dimension; dim=0 is the batch axis
    predict = torch.softmax(outputs, dim=-1)
    predict_cla = torch.argmax(predict).cpu().numpy()
    print('predicted result: ', predict_cla)


def load_model(model, state_dict, device):
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-pretrained', default=False, type=str2bool, required=False,
                        help='If True, returns a model pre-trained on ImageNet')
    parser.add_argument('-progress', default=True, type=str2bool, required=False,
                        help='If True, displays a progress bar of the download to stderr')
    parser.add_argument('-load_weights', default=False, type=str2bool, required=False, help='')
    parser.add_argument('-path', default='./checkpoint/best_acc_160.pth', type=str, required=False,
                        help='path to the pre-trained model weights')
    parser.add_argument('-cam_path', default='./checkpoint/best_loss.pth', type=str, required=False,
                        help='CAMPath')
    parser.add_argument('-device', default='cuda:0', type=str, required=False, help='device')
    args = parser.parse_args()

    # Use a separate name so the MAG_CA_Net class is not shadowed by its instance
    net = MAG_CA_Net(args)
    # map_location lets a checkpoint saved on GPU load onto the selected device (e.g. 'cpu')
    params = torch.load(args.path, map_location=args.device)
    model = load_model(net, params, args.device)
    jpg_path = "./dataset/test.jpg"
    test(jpg_path, model, args.device)
```
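The preprocessing in `test` first resizes the image so its shorter side is 256 pixels, then takes a 224x224 center crop, so the network always receives a fixed-size input. The sketch below computes that geometry in plain Python for a given image size. It is an approximation of what torchvision's `Resize(int)` and `CenterCrop` do (the longer side is truncated to an int, the crop offset is rounded), and the helper names `resize_shorter_side` and `center_crop_box` are hypothetical; exact rounding may differ between torchvision versions.

```python
def resize_shorter_side(width, height, size=256):
    """Scale (width, height) so the shorter side equals `size`, keeping the aspect ratio.

    Assumption: the longer side is truncated to an int, approximating torchvision's Resize(int).
    """
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """Return the (left, top, right, bottom) box of a centered crop x crop square."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


# Example: a 640x480 landscape photo
w, h = resize_shorter_side(640, 480)
print((w, h))                 # (341, 256)
print(center_crop_box(w, h))  # (58, 16, 282, 240)
```

Whatever the original aspect ratio, the crop box is always 224 pixels on each side; only the discarded margins change.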
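After `ToTensor` scales 8-bit pixel values into [0, 1], `Normalize(mean, std)` computes `(x - mean[c]) / std[c]` for each channel `c`. The sketch below reproduces this arithmetic for a single RGB pixel using the mean and std values from `test`; `normalize_pixel` is a hypothetical helper for illustration, not part of the repository.

```python
# Per-channel statistics copied from predict.py
MEAN = [0.4914, 0.4822, 0.4465]
STD = [0.2023, 0.1994, 0.2010]


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply ToTensor-style scaling (/255) followed by per-channel (x - mean) / std."""
    return [(value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std)]


print([round(v, 3) for v in normalize_pixel((255, 0, 128))])  # [2.514, -2.418, 0.276]
```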
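The prediction step turns logits into probabilities with softmax and then picks the largest. If the model returns logits shaped [1, num_classes], softmax has to run over the class axis (the last one). Normalizing over dim 0, the batch axis of size 1, turns every entry into 1.0, and argmax then returns index 0 for any input, because PyTorch's `argmax` returns the first of several equal maxima. The pure-Python sketch below shows both cases on a nested list standing in for the output tensor; the logits are made up and `softmax`/`argmax` here are hypothetical stand-ins, not the PyTorch functions.

```python
import math


def softmax(values):
    """Numerically stable softmax over a flat list of logits."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the first maximal value (ties resolve to the lowest index)."""
    return max(range(len(values)), key=lambda i: values[i])


# Hypothetical logits for one image and three classes, shaped [1, num_classes]
logits = [[0.2, 2.5, -1.0]]

# Softmax over the class axis (dim=-1): a proper distribution over classes
per_class = [softmax(row) for row in logits]
print(argmax(per_class[0]))  # 1

# Softmax over the batch axis (dim=0): each column holds one value, so each becomes 1.0
columns = list(zip(*logits))  # [(0.2,), (2.5,), (-1.0,)]
per_batch = [softmax(list(col))[0] for col in columns]
print(per_batch)          # [1.0, 1.0, 1.0]
print(argmax(per_batch))  # 0, whatever the logits were
```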
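The boolean flags `-pretrained`, `-progress` and `-load_weights` arrive as strings. With `type=bool`, argparse calls `bool()` on the raw argument, and because every non-empty string is truthy, `-pretrained False` would still switch the option on. The sketch below demonstrates this with the standard library and shows a string-to-boolean converter; `str2bool` is a hypothetical helper (a common workaround), not something argparse provides.

```python
import argparse


def str2bool(value):
    """Map common textual spellings of a boolean to True/False."""
    lowered = value.lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("boolean value expected, got '{}'".format(value))


naive = argparse.ArgumentParser()
naive.add_argument('-pretrained', default=False, type=bool)
print(naive.parse_args(['-pretrained', 'False']).pretrained)  # True (surprising)

fixed = argparse.ArgumentParser()
fixed.add_argument('-pretrained', default=False, type=str2bool)
print(fixed.parse_args(['-pretrained', 'False']).pretrained)  # False
print(fixed.parse_args([]).pretrained)                        # False (the default)
```

A converter keeps the existing `-flag value` command-line syntax; switching to `action='store_true'` would also work but changes how the flags are passed.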
Clone: https://gitee.com/HangbinZheng/mag-ca-net.git or git@gitee.com:HangbinZheng/mag-ca-net.git (repository HangbinZheng/mag-ca-net, "MAG-CA-Net", branch master)
