代码拉取完成,页面将自动刷新
#!/usr/bin/env python3
# coding: utf-8
import os.path as osp
from pathlib import Path
import numpy as np
import argparse
import time
import logging
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import mobilenet_v1
import torch.backends.cudnn as cudnn
from utils.ddfa import DDFADataset, ToTensorGjz, NormalizeGjz
from utils.ddfa import str2bool, AverageMeter
from utils.io import mkdir
from vdc_loss import VDCLoss
from wpdc_loss import WPDCLoss
# Module-level state shared by the functions below.
# ``args`` is replaced by the parsed namespace in parse_args(); ``lr`` holds
# the learning rate most recently assigned by adjust_learning_rate().
args, lr = None, None

# Backbone names accepted by --arch; each is looked up as a constructor on
# the ``mobilenet_v1`` module.
arch_choices = [
    'mobilenet_2',
    'mobilenet_1',
    'mobilenet_075',
    'mobilenet_05',
    'mobilenet_025',
]
def parse_args():
    """Parse the command line into the module-level ``args`` namespace.

    After parsing, ``--devices-id`` and ``--milestones`` are converted from
    comma-separated strings into lists of ints, and the directory part of
    ``--snapshot`` is created with ``mkdir`` so checkpoints can be written
    there.

    Returns:
        The parsed ``argparse.Namespace``. It is also stored in the global
        ``args``, which the rest of this script reads.
    """
    global args

    parser = argparse.ArgumentParser(description='3DMM Fitting')
    parser.add_argument('-j', '--workers', default=6, type=int)
    parser.add_argument('--epochs', default=40, type=int)
    parser.add_argument('--start-epoch', default=1, type=int)
    parser.add_argument('-b', '--batch-size', default=128, type=int)
    parser.add_argument('-vb', '--val-batch-size', default=32, type=int)
    parser.add_argument('--base-lr', '--learning-rate', default=0.001, type=float)
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float)
    parser.add_argument('--print-freq', '-p', default=20, type=int)
    parser.add_argument('--resume', default='', type=str, metavar='PATH')
    parser.add_argument('--devices-id', default='0,1', type=str)
    parser.add_argument('--filelists-train', default='', type=str)
    parser.add_argument('--filelists-val', default='', type=str)
    parser.add_argument('--root', default='')
    parser.add_argument('--snapshot', default='', type=str)
    parser.add_argument('--log-file', default='output.log', type=str)
    parser.add_argument('--log-mode', default='w', type=str)
    parser.add_argument('--size-average', default='true', type=str2bool)
    parser.add_argument('--num-classes', default=62, type=int)
    parser.add_argument('--arch', default='mobilenet_1', type=str,
                        choices=arch_choices)
    parser.add_argument('--frozen', default='false', type=str2bool)
    parser.add_argument('--milestones', default='15,25,30', type=str)
    parser.add_argument('--task', default='all', type=str)
    parser.add_argument('--test_initial', default='false', type=str2bool)
    parser.add_argument('--warmup', default=-1, type=int)
    parser.add_argument('--param-fp-train', default='', type=str)
    parser.add_argument('--param-fp-val', default='')
    parser.add_argument('--opt-style', default='resample', type=str)  # resample
    parser.add_argument('--resample-num', default=132, type=int)
    parser.add_argument('--loss', default='vdc', type=str)
    args = parser.parse_args()

    # Comma-separated lists -> lists of ints. Blank items (e.g. from a
    # trailing comma in '15,25,30,') are skipped instead of crashing in int('').
    args.devices_id = [int(d) for d in args.devices_id.split(',') if d.strip()]
    args.milestones = [int(m) for m in args.milestones.split(',') if m.strip()]
    if not args.devices_id:
        # main() indexes devices_id[0]; fail here with a usage message instead.
        parser.error('--devices-id must list at least one device id')

    # With the default --snapshot '' (or a bare file prefix) the directory
    # part is '', which is not a directory to create.
    snapshot_dir = osp.split(args.snapshot)[0]
    if snapshot_dir:
        mkdir(snapshot_dir)
    return args
def print_args(args):
    """Emit one INFO log record per attribute of ``args``, as ``name: value``.

    Attributes are logged in the order ``vars(args)`` yields them.
    """
    for name, value in vars(args).items():
        logging.info(': '.join((name, str(value))))
def adjust_learning_rate(optimizer, epoch, milestones=None, gamma=0.2):
    """Set the learning rate of every param group of ``optimizer`` for ``epoch``.

    The rate is ``args.base_lr * gamma ** n`` where ``n`` is:

    * ``1`` while ``epoch <= args.warmup`` (warmup runs one decay step below
      the base rate);
    * otherwise the index of the first milestone with ``epoch <= milestone``,
      or ``len(milestones)`` once ``epoch`` is past all of them.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts), e.g. a
            ``torch.optim.Optimizer``.
        epoch: current epoch number.
        milestones: list/tuple of epoch numbers after which the rate is
            multiplied by ``gamma`` once more. ``None`` (the default) or an
            empty sequence means no step decay after warmup.
        gamma: multiplicative decay factor per milestone (default 0.2).

    Side effects:
        Writes the new rate to each ``param_group['lr']`` and to the module
        global ``lr``.
    """
    global lr
    if milestones is None:
        milestones = ()

    if epoch <= args.warmup:
        n = 1
    else:
        # Index of the first milestone not yet passed; past all -> len().
        n = next((i for i, m in enumerate(milestones) if epoch <= m),
                 len(milestones))

    lr = args.base_lr * (gamma ** n)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save`` and log the path.

    Args:
        state: object to store; ``main`` passes a dict holding the epoch
            number and the model ``state_dict``.
        filename: destination path.
    """
    destination = filename
    torch.save(state, destination)
    logging.info('Save checkpoint to {}'.format(destination))
def train(train_loader, model, criterion, optimizer, epoch):
    """Train ``model`` for one epoch over ``train_loader``.

    For every batch, the target is moved to the current CUDA device. The loss
    ``criterion(output, target)`` is computed and one optimizer step is taken.
    Every ``args.print_freq`` batches the learning rate, data-loading time,
    batch time and loss (current and running average) are logged.

    Args:
        train_loader: iterable of ``(input, target)`` batches supporting
            ``len()`` (a ``DataLoader`` in ``main``).
        model: the network; switched to training mode here.
        criterion: callable returning a scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in log messages.

    Raises:
        Exception: if ``args.loss`` is not 'vdc', 'wpdc' or 'pdc'
            (case-insensitive).
    """
    # Every supported loss is invoked identically, so validate the name once
    # up front rather than re-checking it on every batch.
    if args.loss.lower() not in ('vdc', 'wpdc', 'pdc'):
        raise Exception(f'Unknown loss {args.loss}')

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    model.train()
    end = time.time()
    for i, (inputs, target) in enumerate(train_loader):
        # Time spent waiting for this batch; measured before the forward pass
        # so model compute is not counted as data time.
        data_time.update(time.time() - end)

        target = target.cuda(non_blocking=True)
        output = model(inputs)
        loss = criterion(output, target)
        losses.update(loss.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time for the whole iteration
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            # Read the rate from the optimizer itself so logging does not
            # depend on the module-level ``lr`` having been set.
            current_lr = optimizer.param_groups[0]['lr']
            logging.info(f'Epoch: [{epoch}][{i}/{len(train_loader)}]\t'
                         f'LR: {current_lr:8f}\t'
                         f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         f'Loss {losses.val:.4f} ({losses.avg:.4f})')
def validate(val_loader, model, criterion, epoch):
    """Evaluate ``model`` on ``val_loader`` and log the mean batch loss.

    Runs in eval mode under ``torch.no_grad()``. The logged time is the
    wall-clock time of the whole pass.

    Args:
        val_loader: iterable of ``(input, target)`` batches supporting
            ``len()``.
        model: the network; switched to eval mode here (``train`` switches
            it back).
        criterion: callable returning a scalar loss tensor.
        epoch: epoch number, used only in the log message.

    Returns:
        The mean of the per-batch losses as a float, or ``nan`` if the loader
        yielded no batches.
    """
    model.eval()
    end = time.time()
    batch_losses = []
    with torch.no_grad():
        for inputs, target in val_loader:
            target = target.cuda(non_blocking=True)
            output = model(inputs)
            # .item() copies the scalar to a Python float: np.mean cannot
            # convert CUDA tensors to numpy arrays.
            batch_losses.append(criterion(output, target).item())
    elapse = time.time() - end

    # Unweighted mean over batches: a smaller final batch counts as much as a
    # full one. An empty loader gives nan without numpy's empty-mean warning.
    loss = float(np.mean(batch_losses)) if batch_losses else float('nan')
    logging.info(f'Val: [{epoch}][{len(val_loader)}]\t'
                 f'Loss {loss:.4f}\t'
                 f'Time {elapse:.3f}')
    return loss
def _build_criterion():
    """Create the loss selected by ``args.loss`` (case-insensitive) on the GPU.

    Raises:
        Exception: for a loss name other than 'wpdc', 'vdc' or 'pdc'.
    """
    loss_name = args.loss.lower()
    if loss_name == 'wpdc':
        criterion = WPDCLoss(opt_style=args.opt_style).cuda()
        logging.info('Use WPDC Loss')
    elif loss_name == 'vdc':
        criterion = VDCLoss(opt_style=args.opt_style).cuda()
        logging.info('Use VDC Loss')
    elif loss_name == 'pdc':
        # ``size_average`` is deprecated in torch.nn losses; ``reduction``
        # expresses the same choice (True -> 'mean', False -> 'sum').
        reduction = 'mean' if args.size_average else 'sum'
        criterion = nn.MSELoss(reduction=reduction).cuda()
        logging.info('Use PDC loss')
    else:
        raise Exception(f'Unknown Loss {args.loss}')
    return criterion


def _build_loaders():
    """Build the ``(train_loader, val_loader)`` pair described by ``args``.

    Both datasets use the same input normalisation. The training loader
    shuffles and drops the last incomplete batch. The validation loader keeps
    the file order and every sample.
    """
    normalize = NormalizeGjz(mean=127.5, std=128)  # may need optimization
    train_dataset = DDFADataset(
        root=args.root,
        filelists=args.filelists_train,
        param_fp=args.param_fp_train,
        transform=transforms.Compose([ToTensorGjz(), normalize]),
    )
    val_dataset = DDFADataset(
        root=args.root,
        filelists=args.filelists_val,
        param_fp=args.param_fp_val,
        transform=transforms.Compose([ToTensorGjz(), normalize]),
    )
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              num_workers=args.workers, shuffle=True,
                              pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.val_batch_size,
                            num_workers=args.workers, shuffle=False,
                            pin_memory=True)
    return train_loader, val_loader


def main():
    """Parse arguments, build model/loss/data, then train and validate.

    Each epoch adjusts the learning rate, trains once over the training set,
    saves ``{snapshot}_checkpoint_epoch_{epoch}.pth.tar`` and validates.
    """
    parse_args()  # populates the global ``args``

    # logging setup: every record goes to both the log file and the console
    logging.basicConfig(
        format='[%(asctime)s] [p%(process)s] [%(pathname)s:%(lineno)d] [%(levelname)s] %(message)s',
        level=logging.INFO,
        handlers=[
            logging.FileHandler(args.log_file, mode=args.log_mode),
            logging.StreamHandler(),
        ],
    )
    print_args(args)

    # step1: define the model structure
    model = getattr(mobilenet_v1, args.arch)(num_classes=args.num_classes)
    # Make devices_id[0] the current device before wrapping; fixes
    # `ERROR: all tensors must be on devices[0]`.
    torch.cuda.set_device(args.devices_id[0])
    model = nn.DataParallel(model, device_ids=args.devices_id).cuda()

    # step2: loss and optimization method
    criterion = _build_criterion()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.base_lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # step2.1: resume
    # NOTE: checkpoints hold only the epoch and model weights, so the
    # optimizer (momentum) state is not restored and --start-epoch must be
    # passed explicitly when resuming.
    if args.resume:
        if Path(args.resume).is_file():
            logging.info(f'=> loading checkpoint {args.resume}')
            # Load tensors onto CPU storage regardless of the saving device.
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)['state_dict']
            model.load_state_dict(checkpoint)
        else:
            logging.warning(f'=> no checkpoint found at {args.resume}')

    # step3: data
    train_loader, val_loader = _build_loaders()

    # step4: run
    cudnn.benchmark = True
    if args.test_initial:
        logging.info('Testing from initial')
        validate(val_loader, model, criterion, args.start_epoch)

    for epoch in range(args.start_epoch, args.epochs + 1):
        adjust_learning_rate(optimizer, epoch, args.milestones)
        train(train_loader, model, criterion, optimizer, epoch)

        filename = f'{args.snapshot}_checkpoint_epoch_{epoch}.pth.tar'
        save_checkpoint({'epoch': epoch, 'state_dict': model.state_dict()}, filename)
        validate(val_loader, model, criterion, epoch)


if __name__ == '__main__':
    main()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。