1 Star 0 Fork 1

Huterox / HuLook

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
classfiy.py 2.07 KB
一键复制 编辑 原始数据 按行查看 历史
Huterox 提交于 2022-08-01 15:28 . v1.0
import argparse
from PIL import Image
from Utils.DataSet.MyDataSet import MyDataSet
from Utils.DataSet.TransformAtions import TransFormAtions
import argparse
import torch
from torch.utils.data import DataLoader
from Models.FeatureNet import YOLOFeature
from Config.ConfigPre import *
import outProcessClassfiy
def detect():
ways = opt.valid_imgs
transformations = TransFormAtions()
net = YOLOFeature(Classes)
state_dict_load = torch.load(opt.path_state_dict)
net.load_state_dict(state_dict_load)
if(ways):
test_data = MyDataSet(data_dir=opt.valid_dir, transform=transformations.valid_transform,ClassesName=ClassesName)
valid_loader = DataLoader(dataset=test_data, batch_size=1)
net.eval()
with torch.no_grad():
for i, data in enumerate(valid_loader):
# forward
inputs, labels = data
outputs = net(inputs)
_, predicted = torch.max(outputs.data, 1)
# 输出处理器
outProcessClassfiy.Function(predicted.numpy()[0])
else:
#指定的是单张图片,少给我来奇奇怪怪的输入,这个版本容错很差滴!!!
path_img = opt.valid_dir
if(".jpg" not in path_img):
raise Exception("小爷打不开这图片")
image = Image.open(path_img)
image = transformations.valid_transform(image)
image = torch.reshape(image, (1, 3, 32, 32))
net.eval()
with torch.no_grad():
out = net(image)
outProcessClassfiy.Function(out.argmax(1).item())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# False表示识别单张图片,True表示多张图片,此时指定路径即可。
parser.add_argument('--valid_imgs',type=bool,default=False)
parser.add_argument('--valid_dir', type=str, default=r'F:\projects\PythonProject\HuLook\Data\PreData\train\猫羽雫\1.jpg')
parser.add_argument('--path_state_dict', type=str, default=r'runs\trainpre\epx0\weights\best.pth')
opt = parser.parse_args()
detect()
Python
1
https://gitee.com/Huterox/hu-look.git
git@gitee.com:Huterox/hu-look.git
Huterox
hu-look
HuLook
master

搜索帮助