1 Star 0 Fork 1

Huterox / HuLook

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
detect.py 5.28 KB
一键复制 编辑 原始数据 按行查看 历史
Huterox 提交于 2022-08-01 15:28 . v1.0
import cv2
import torchvision.transforms as transforms
from Models.Yolo import YOLO
import argparse
import torch
from Config.ConfigTrain import *
import numpy as np
from PIL import Image,ImageDraw,ImageFont
def iou(box_one, box_two):
    """Return the intersection-over-union of two axis-aligned boxes.

    Only the first four items of each argument are read, interpreted as
    ``(x1, y1, x2, y2)`` with ``x1 < x2`` and ``y1 < y2``; any trailing
    items (confidence, class scores, ...) are ignored.

    Args:
        box_one: Indexable box description.
        box_two: Indexable box description.

    Returns:
        ``0`` when the boxes do not overlap (touching edges count as no
        overlap), otherwise ``intersection_area / union_area``.
    """
    left = max(box_one[0], box_two[0])
    top = max(box_one[1], box_two[1])
    right = min(box_one[2], box_two[2])
    bottom = min(box_one[3], box_two[3])
    if left >= right or top >= bottom:
        return 0

    intersection = (right - left) * (bottom - top)
    area_one = (box_one[2] - box_one[0]) * (box_one[3] - box_one[1])
    area_two = (box_two[2] - box_two[0]) * (box_two[3] - box_two[1])
    # The overlap is counted once in each area, so subtract it once to get
    # the true union; dividing by the plain sum caps the result at 0.5.
    union = area_one + area_two - intersection
    return intersection / union
def NMS(bounding_boxes,S=7,B=2,img_size=448,confidence_threshold=0.5,iou_threshold=0.0,possible_pred=0.4):
    """Decode grid predictions into pixel boxes and apply non-maximum suppression.

    ``bounding_boxes`` is indexed as ``[batch][row][col][k]``. For every cell,
    items 0-4 and 5-9 are two candidate boxes ``(x, y, w, h, confidence)`` and
    items 10+ are the per-class scores. ``x``/``y`` are offsets inside the
    cell (in units of one cell), ``w``/``h`` are fractions of ``img_size``.

    Args:
        bounding_boxes: Tensor of raw predictions (moved to CPU and converted
            to nested lists here).
        S: Number of grid cells per side.
        B: Unused; kept for interface compatibility. The two-boxes-per-cell
            layout is hard-coded through indices 4 and 9.
        img_size: Side length in pixels of the square network input.
        confidence_threshold: Minimum box confidence to keep a candidate.
        iou_threshold: A candidate survives a kept box only if their IoU is
            less than or equal to this value.
        possible_pred: Minimum ``confidence * best class score`` to keep a
            candidate.

    Returns:
        A flat list of kept boxes (candidates of every batch item are pooled
        together). Each box is a list whose first six items are
        ``[x1, y1, x2, y2, score, class_index]`` with integer pixel corners
        and ``score = confidence * class score``; the remaining items are
        left-over per-class scores.
    """
    predictions = bounding_boxes.cpu().detach().numpy().tolist()
    grid_size = img_size / S
    candidates = []

    for image_pred in predictions:
        for row in range(S):
            for col in range(S):
                cell = image_pred[row][col]
                # Keep whichever of the two boxes of this cell is more confident.
                box = cell[5:10] if cell[4] < cell[9] else cell[0:5]
                class_scores = cell[10:]
                box.extend(class_scores)
                if box[4] < confidence_threshold:
                    continue
                if box[4] * max(class_scores) < possible_pred:
                    continue

                # Cell-relative centre and image-relative size -> pixel corners.
                center_x = int(grid_size * col + box[0] * grid_size)
                center_y = int(grid_size * row + box[1] * grid_size)
                width = int(box[2] * img_size)
                height = int(box[3] * img_size)
                box[0] = max(0, int(center_x - width / 2))
                box[1] = max(0, int(center_y - height / 2))
                box[2] = min(img_size - 1, int(center_x + width / 2))
                box[3] = min(img_size - 1, int(center_y + height / 2))
                candidates.append(box)

    # Highest confidence first: the strongest box must be the one that
    # suppresses its overlapping neighbours. Filtering below preserves the
    # order, so a single sort is enough.
    candidates.sort(key=lambda candidate: candidate[4], reverse=True)

    nms_boxes = []
    while candidates:
        best = candidates[0]
        class_index = int(np.argmax(best[5:]))
        # Final score = class score x objectness confidence.
        best[4] = best[4] * best[5 + class_index]
        best[5] = class_index
        nms_boxes.append(best)
        candidates = [
            other for other in candidates[1:]
            if iou(best, other) <= iou_threshold
        ]
    return nms_boxes
def detect():
    """Run the detector on one image, draw the results, then show/save them.

    All settings are read from the module-level ``opt`` namespace:
    ``valid_dir`` (image path), ``path_state_dict`` (weights file),
    ``confidence`` / ``iou`` / ``possible_pre`` (thresholds passed to ``NMS``),
    ``show_img`` (open a preview window) and ``save_dir`` (output image path,
    skipped when empty).

    Raises:
        FileNotFoundError: If the input image cannot be read.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),  # height * width * channel -> channel * height * width
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    image_path = opt.valid_dir
    img_data = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if img_data is None:
        raise FileNotFoundError("cannot read image: {}".format(image_path))
    img_data = cv2.resize(img_data, (448, 448), interpolation=cv2.INTER_AREA)
    net_input = transform(img_data).unsqueeze(0)

    net = YOLO(B=2,classes_num=Classes)
    # Neither the network nor the input is moved off the CPU in this
    # function, so load the checkpoint tensors onto the CPU as well.
    state_dict_load = torch.load(opt.path_state_dict, map_location="cpu")
    net.load_state_dict(state_dict_load)
    net.eval()
    with torch.no_grad():
        bounding_boxes = net(net_input)
    NMS_boxes = NMS(bounding_boxes,confidence_threshold=opt.confidence,iou_threshold=opt.iou,possible_pred=opt.possible_pre)

    # If this ever needs to be split out, wrapping the drawing code below
    # into its own function is all that is required.
    font = ImageFont.truetype(r'font/simsun.ttc', 20, encoding='utf-8')
    for box in NMS_boxes:
        class_name = ClassesName[int(box[5])]
        score = round(box[4], 2)
        img_data = cv2.rectangle(img_data, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 1)
        # Draw the label through PIL so that Chinese class names render.
        pil_img = Image.fromarray(cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_img)
        draw.text((box[0], box[1]),"{}:{}".format(class_name, score),(148,175,100),font)
        print("class_name:{} confidence:{}".format(class_name,score))
        img_data = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

    if opt.show_img:
        cv2.imshow("img_detection", img_data)
        cv2.waitKey()
        cv2.destroyAllWindows()
    if opt.save_dir:
        # cv2.imwrite reports failure through its return value only.
        if not cv2.imwrite(opt.save_dir, img_data):
            print("failed to save image to {}".format(opt.save_dir))
def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string as True, so
    ``--show_img False`` would still enable the option. This converter accepts
    the usual spellings instead; the empty string stays False as before.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Input image and trained weights.
    parser.add_argument('--valid_dir', type=str, default=r'F:\projects\PythonProject\HuLook\Data\DetData\train\images\002.jpg')
    parser.add_argument('--path_state_dict', type=str, default=r'F:\projects\PythonProject\HuLook\runs\traindetect\epx0\weights\best.pth')
    # Thresholds forwarded to NMS by detect().
    parser.add_argument("--iou",type=float,default=0.2)
    parser.add_argument("--confidence",type=float,default=0.5)
    parser.add_argument("--possible_pre",type=float,default=0.35)
    # Output options: preview window and optional output image path.
    parser.add_argument("--show_img",type=str2bool,default=True)
    parser.add_argument("--save_dir",type=str,default="")
    opt = parser.parse_args()
    detect()
Python
1
https://gitee.com/Huterox/hu-look.git
git@gitee.com:Huterox/hu-look.git
Huterox
hu-look
HuLook
master

搜索帮助