1 Star 0 Fork 1

Huterox / HuLook

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
preTrain.py 8.08 KB
一键复制 编辑 原始数据 按行查看 历史
Huterox 提交于 2022-08-01 15:51 . v1.1
import argparse
import copy
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from Models.FeatureNet import YOLOFeature
from Utils import ModelUtils
from Config.ConfigPre import *
from Utils.DataSet.MyDataSet import MyDataSet
from Utils.DataSet.TransformAtions import TransFormAtions
from Utils import SaveModel
from Utils import Log


def _select_device(device_arg):
    """Pick the training device and announce it on the console.

    Args:
        device_arg: value of ``--device``: a CUDA index string such as "0",
            or the literal "cpu".

    Returns:
        ``torch.device("cuda:<device_arg>")`` when CUDA is available and
        ``device_arg`` is not "cpu", otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available() and device_arg != 'cpu':
        # NOTE(review): device_arg is not validated; a malformed or out-of-range
        # index surfaces as a torch error here.
        device = torch.device("cuda:" + device_arg)
        # Pass the device explicitly: without an argument get_device_name()
        # reports the *current* CUDA device, which need not be the selected one.
        print("\033[0;31;0m使用GPU训练中:{}\033[0m".format(torch.cuda.get_device_name(device)))
        return device
    print("\033[0;31;40m使用CPU训练\033[0m")
    return torch.device("cpu")


def _resolve_data_dir(cli_dir, default_subdir, error_msg):
    """Return the dataset directory to use, failing fast if it is missing.

    Args:
        cli_dir: directory given on the command line; may be empty.
        default_subdir: sub-directory of ``Data_Root`` used when ``cli_dir`` is empty.
        error_msg: message of the exception raised when the directory is missing.

    Raises:
        Exception: if the resolved directory does not exist.
    """
    # os.path.join instead of a hard-coded "\\" so the path also works off Windows.
    data_dir = cli_dir or os.path.join(Data_Root, default_subdir)
    if not os.path.exists(data_dir):
        raise Exception(error_msg)
    return data_dir


def _emit(info, color, fo, show_console):
    """Write ``info`` to the log file ``fo`` and, optionally, to the console in ``color``."""
    print(info, file=fo)
    if show_console:
        print(color + info + "\033[0m")


def _train_one_epoch(net, loader, criterion, optimizer, device, epoch, fo, args):
    """Run one optimisation pass over ``loader``, logging every ``args.log_interval`` batches.

    Returns:
        ``(accuracy, mean_loss)`` over the whole epoch: accuracy is
        correct / seen samples, mean_loss is the average per-batch loss.
        Both are 0.0 when ``loader`` yields no batches.
    """
    net.train()
    correct = 0
    total = 0
    n_batches = 0
    loss_sum = 0.0     # whole-epoch loss, for the epoch mean
    loss_window = 0.0  # loss since the last log line, for the interval mean
    for i, (inputs, labels) in enumerate(loader):
        n_batches += 1
        inputs, labels = inputs.to(device), labels.to(device)
        # forward
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        # backward + weight update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() keeps the counters as plain Python numbers instead of device tensors.
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
        batch_loss = loss.item()
        loss_sum += batch_loss
        loss_window += batch_loss
        if args.log_interval > 0 and (i + 1) % args.log_interval == 0:
            info = "训练:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, args.epochs, i + 1, len(loader), loss_window / args.log_interval, correct / total
            )
            _emit(info, "\033[0;33;0m", fo, args.show_log_console)
            loss_window = 0.0
    accuracy = correct / total if total else 0.0
    mean_loss = loss_sum / n_batches if n_batches else 0.0
    return accuracy, mean_loss


def _validate(net, loader, criterion, device, epoch, fo, args):
    """Evaluate ``net`` on ``loader`` without gradients and log one summary line.

    Returns:
        ``(accuracy, total_loss)``, where total_loss is the sum of per-batch
        losses; accuracy is 0.0 when ``loader`` yields no samples.
    """
    net.eval()
    correct = 0
    total = 0
    n_batches = 0
    loss_total = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            n_batches += 1
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss_total += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total if total else 0.0
    info_val = "测试:\tEpoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
        epoch, args.epochs, n_batches, len(loader), loss_total, accuracy
    )
    _emit(info_val, "\033[0;31;0m", fo, args.show_log_console)
    return accuracy, loss_total


def train():
    """Pre-train the ``YOLOFeature`` network as a classifier.

    Configuration is read from the module-level ``opt`` (parsed in ``__main__``)
    and from names star-imported from ``Config.ConfigPre`` (``Data_Root``,
    ``Train``, ``Valid``, ``ClassesName``, ``Classes``).

    Side effects:
        - creates a run directory via ``SaveModel.CreatRun``;
        - writes a text log through ``Log.PrintLog``;
        - optionally writes TensorBoard scalars under ``<run>/logs``;
        - passes the best and last weights to ``SaveModel.Save_Model``.

    Raises:
        Exception: if the training or validation directory does not exist.
    """
    ModelUtils.set_seed()
    device = _select_device(opt.device)

    # Run directory and logging sinks.
    run_path = SaveModel.CreatRun(0, "pre")
    path_board = None
    writer = None
    if opt.tensorboardopen:
        path_board = os.path.join(run_path, "logs")
        writer = SummaryWriter(path_board)
    fo = Log.PrintLog(run_path)
    try:
        # Datasets and loaders.
        transformations = TransFormAtions()
        train_data_dir = _resolve_data_dir(opt.train_dir, Train, "训练集路径错误")
        valid_data_dir = _resolve_data_dir(opt.valid_dir, Valid, "测试集路径错误")
        train_data = MyDataSet(data_dir=train_data_dir, transform=transformations.train_transform,
                               ClassesName=ClassesName)
        valid_data = MyDataSet(data_dir=valid_data_dir, transform=transformations.valid_transform,
                               ClassesName=ClassesName)
        train_loader = DataLoader(dataset=train_data, batch_size=opt.batch_size, num_workers=opt.works,
                                  shuffle=True)
        valid_loader = DataLoader(dataset=valid_data, batch_size=opt.batch_size)

        # Network, loss, optimiser and LR schedule.
        # NOTE(review): presumably Classes is the number of output classes — confirm in ConfigPre.
        net = YOLOFeature(Classes)
        net.initialize_weights()
        net = net.to(device)
        criterion = nn.CrossEntropyLoss()
        # NOTE(review): momentum=0.09 is unusually small (0.9 is the common value); confirm it is intended.
        optimizer = optim.SGD(net.parameters(), lr=opt.lr, momentum=0.09)
        # scheduler.step() runs once per epoch, so the LR is multiplied by 0.01 every 10 epochs.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.01)

        best_weight = None
        best_acc = 0.0
        val_time = 0  # TensorBoard step counter for validation runs
        for epoch in range(opt.epochs):
            print("正在进行第{}轮训练".format(epoch + 1))
            train_acc, train_loss = _train_one_epoch(
                net, train_loader, criterion, optimizer, device, epoch, fo, opt
            )
            if writer is not None:
                writer.add_scalar("训练准确率", train_acc, epoch)
                writer.add_scalar("训练损失均值", train_loss, epoch)
            # NOTE(review): the best checkpoint is chosen by *training* accuracy;
            # validation accuracy is usually the better criterion.
            if train_acc > best_acc:
                # state_dict() returns references to the live parameter tensors, which
                # later optimizer steps overwrite in place; deep-copy to freeze them.
                best_weight = copy.deepcopy(net.state_dict())
                best_acc = train_acc
            scheduler.step()

            if opt.val_interval > 0 and (epoch + 1) % opt.val_interval == 0:
                val_acc, val_loss = _validate(net, valid_loader, criterion, device, epoch, fo, opt)
                if writer is not None:
                    writer.add_scalar("测试准确率", val_acc, val_time)
                    writer.add_scalar("测试损失总值", val_loss, val_time)
                val_time += 1

        last_weight = net.state_dict()
        # If no epoch ever improved on 0% accuracy, fall back to the final weights
        # instead of handing None to Save_Model.
        if best_weight is None:
            best_weight = last_weight
        SaveModel.Save_Model(run_path, best_weight, last_weight)
    finally:
        # Release the log file and writer even if training fails part-way.
        fo.close()
        if writer is not None:
            print("tensorboard dir is:", path_board)
            writer.close()
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False" and "0") as True, so flags like ``--show_log_console False`` could
    never be turned off. This converter accepts the usual spellings instead.

    Args:
        value: the raw argument string; a ``bool`` is returned unchanged.

    Returns:
        True for "true"/"t"/"yes"/"y"/"1" and False for "false"/"f"/"no"/"n"/"0"
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse reports a
        usage error instead of guessing.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--log_interval', type=int, default=10)
    # Run validation every N training epochs.
    parser.add_argument('--val_interval', type=int, default=1)
    parser.add_argument('--train_dir', type=str, default='')
    parser.add_argument('--valid_dir', type=str, default='')
    # DataLoader worker count; on macOS this may need to be 1 (local training on a Mac is not recommended).
    parser.add_argument('--works', type=int, default=2)
    parser.add_argument('--show_log_console', type=str2bool, default=True)
    parser.add_argument('--device', type=str, default="0", help="默认使用显卡加速训练参数选择:0,1,2...or cpu")
    parser.add_argument('--tensorboardopen', type=str2bool, default=True)
    opt = parser.parse_args()
    train()
    # View the curves with: tensorboard --logdir=runs/trainpre/epx0/logs
Python
1
https://gitee.com/Huterox/hu-look.git
git@gitee.com:Huterox/hu-look.git
Huterox
hu-look
HuLook
master

搜索帮助