1 Star 0 Fork 0

米斯特龙 / semantic-segmentation-pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
test.py 5.73 KB
一键复制 编辑 原始数据 按行查看 历史
# System libs
import os
import argparse
from distutils.version import LooseVersion
# Numerical libs
import numpy as np
import torch
import torch.nn as nn
from scipy.io import loadmat
import csv
# Our libs
from mit_semseg.dataset import TestDataset
from mit_semseg.models import ModelBuilder, SegmentationModule
from mit_semseg.utils import colorEncode, find_recursive, setup_logger
from mit_semseg.lib.nn import user_scattered_collate, async_copy_to
from mit_semseg.lib.utils import as_numpy
from PIL import Image
from tqdm import tqdm
from mit_semseg.config import cfg
# Color palette handed to colorEncode when painting predicted labels.
colors = loadmat('data/color150.mat')['colors']

# Class-index -> display-name table. visualize_result looks names up as
# names[label + 1], so keys here are offset by one from the predicted labels.
# Column 0 holds the index; column 5 holds ';'-separated names, of which only
# the first is kept.
names = {}
# newline='' is what the csv module expects for file objects; an explicit
# encoding keeps decoding independent of the platform's locale.
with open('data/object150_info.csv', newline='', encoding='utf-8') as f:
    reader = csv.reader(f)
    next(reader)  # skip the header row
    for row in reader:
        names[int(row[0])] = row[5].split(";")[0]
def visualize_result(data, pred, cfg):
    """Print a prediction's class breakdown and save a side-by-side image.

    Args:
        data: ``(img, info)`` pair. ``img`` is concatenated horizontally with
            the colorized prediction (presumably an HxWx3 array with the same
            spatial size as ``pred`` -- TODO confirm); ``info`` is the source
            image path, used in the log line and to name the output file.
        pred: label map; each label ``v`` is reported as ``names[v + 1]``.
        cfg: config node; the visualization is written to ``cfg.TEST.result``.
    """
    img, info = data

    # Print classes in descending order of pixel coverage.
    pred = np.int32(pred)
    pixs = pred.size
    uniques, counts = np.unique(pred, return_counts=True)
    print("Predictions in [{}]:".format(info))
    for idx in np.argsort(counts)[::-1]:
        ratio = counts[idx] / pixs * 100
        # Counts are visited in descending order, so once one class falls
        # under the threshold every remaining class does too.
        if ratio <= 0.1:
            break
        name = names[int(uniques[idx]) + 1]
        print(" {}: {:.2f}%".format(name, ratio))

    # Colorize the prediction and stack it to the right of the original image.
    pred_color = colorEncode(pred, colors).astype(np.uint8)
    im_vis = np.concatenate((img, pred_color), axis=1)

    # Always emit "<input stem>.png". PIL chooses the output format from the
    # suffix, so keeping the input's extension (e.g. ".jpeg") would write a
    # lossy JPEG; basename also handles OS-specific path separators.
    stem = os.path.splitext(os.path.basename(info))[0]
    Image.fromarray(im_vis).save(os.path.join(cfg.TEST.result, stem + '.png'))
def _predict_sample(segmentation_module, sample, gpu):
    """Return the per-pixel arg-max label map (as numpy) for one test sample."""
    img_ori = sample['img_ori']
    seg_size = (img_ori.shape[0], img_ori.shape[1])
    img_resized_list = sample['img_data']

    with torch.no_grad():
        scores = torch.zeros(1, cfg.DATASET.num_class, seg_size[0], seg_size[1])
        scores = async_copy_to(scores, gpu)
        for img in img_resized_list:
            # Feed one resized input at a time; the original image and the
            # file path are stripped from what is sent to the model.
            feed_dict = sample.copy()
            feed_dict['img_data'] = img
            del feed_dict['img_ori']
            del feed_dict['info']
            feed_dict = async_copy_to(feed_dict, gpu)

            # forward pass
            pred_tmp = segmentation_module(feed_dict, segSize=seg_size)
            # Average over the scales actually evaluated for this sample.
            scores = scores + pred_tmp / len(img_resized_list)

        _, pred = torch.max(scores, dim=1)
        return as_numpy(pred.squeeze(0).cpu())


def test(segmentation_module, loader, gpu):
    """Run multi-scale inference over ``loader`` and save one visualization per image.

    For every sample, the module's outputs for each entry of ``'img_data'``
    are averaged into a score map of the original image's height/width, the
    arg-max class is taken per pixel, and the result is handed to
    ``visualize_result``.

    Args:
        segmentation_module: model invoked as ``module(feed_dict, segSize=...)``.
        loader: DataLoader yielding batches of sample dicts carrying
            ``'img_ori'``, ``'img_data'`` and ``'info'``.
        gpu: device id passed to ``async_copy_to``.
    """
    segmentation_module.eval()

    # Context manager guarantees the progress bar is closed, even on error.
    with tqdm(total=len(loader)) as pbar:
        for batch_data in loader:
            # NOTE(review): batches come from user_scattered_collate and are
            # assumed to be a list of per-sample dicts (the original indexed
            # [0]). Process every sample so batch_size > 1 does not silently
            # skip images; with batch_size == 1 this is unchanged.
            for sample in batch_data:
                pred = _predict_sample(segmentation_module, sample, gpu)
                visualize_result(
                    (sample['img_ori'], sample['info']),
                    pred,
                    cfg
                )
            pbar.update(1)
def main(cfg, gpu, num_workers=5):
    """Build the segmentation model from ``cfg`` and run inference on ``cfg.list_test``.

    Args:
        cfg: config node providing ``MODEL``, ``DATASET``, ``TEST`` and
            ``list_test``.
        gpu: CUDA device index made current and used for inputs.
        num_workers: number of DataLoader worker processes (default 5, the
            previously hard-coded value).
    """
    torch.cuda.set_device(gpu)

    # Network Builders
    net_encoder = ModelBuilder.build_encoder(
        arch=cfg.MODEL.arch_encoder,
        fc_dim=cfg.MODEL.fc_dim,
        weights=cfg.MODEL.weights_encoder)
    net_decoder = ModelBuilder.build_decoder(
        arch=cfg.MODEL.arch_decoder,
        fc_dim=cfg.MODEL.fc_dim,
        num_class=cfg.DATASET.num_class,
        weights=cfg.MODEL.weights_decoder,
        use_softmax=True)

    crit = nn.NLLLoss(ignore_index=-1)

    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)

    # Dataset and Loader
    dataset_test = TestDataset(
        cfg.list_test,
        cfg.DATASET)
    loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=cfg.TEST.batch_size,
        shuffle=False,
        collate_fn=user_scattered_collate,
        num_workers=num_workers,
        # Inference must visit every image, so never drop a final partial batch.
        drop_last=False)

    segmentation_module.cuda()

    # Main loop
    test(segmentation_module, loader_test, gpu)

    print('Inference done!')
if __name__ == '__main__':
    assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
        'PyTorch>=0.4.0 is required'

    parser = argparse.ArgumentParser(
        description="PyTorch Semantic Segmentation Testing"
    )
    parser.add_argument(
        "--imgs",
        required=True,
        type=str,
        help="an image path, or a directory name"
    )
    parser.add_argument(
        "--cfg",
        default="config/ade20k-resnet50dilated-ppm_deepsup.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "--gpu",
        default=0,
        type=int,
        help="gpu id for evaluation"
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()

    cfg.merge_from_file(args.cfg)
    cfg.merge_from_list(args.opts)
    # cfg is deliberately left unfrozen: the arch names, weight paths and
    # list_test are filled in below, after merging.

    logger = setup_logger(distributed_rank=0)  # TODO
    logger.info("Loaded configuration file {}".format(args.cfg))
    logger.info("Running with config:\n{}".format(cfg))

    cfg.MODEL.arch_encoder = cfg.MODEL.arch_encoder.lower()
    cfg.MODEL.arch_decoder = cfg.MODEL.arch_decoder.lower()

    # Paths of the model weights, resolved under cfg.DIR.
    cfg.MODEL.weights_encoder = os.path.join(
        cfg.DIR, 'encoder_' + cfg.TEST.checkpoint)
    cfg.MODEL.weights_decoder = os.path.join(
        cfg.DIR, 'decoder_' + cfg.TEST.checkpoint)
    # Validate with parser.error rather than assert so the checks still run
    # under `python -O`.
    missing = [p for p in (cfg.MODEL.weights_encoder, cfg.MODEL.weights_decoder)
               if not os.path.exists(p)]
    if missing:
        parser.error("checkpoint does not exist! Missing: {}".format(
            ", ".join(missing)))

    # generate testing image list
    if os.path.isdir(args.imgs):
        imgs = find_recursive(args.imgs)
    elif os.path.isfile(args.imgs):
        imgs = [args.imgs]
    else:
        # Neither a directory nor an existing file: reject up front instead
        # of failing later inside the data loader.
        imgs = []
    if not imgs:
        parser.error("imgs should be a path to image (.jpg) or directory.")
    cfg.list_test = [{'fpath_img': x} for x in imgs]

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(cfg.TEST.result, exist_ok=True)

    main(cfg, args.gpu)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/Mr_xiaolong/semantic-segmentation-pytorch.git
git@gitee.com:Mr_xiaolong/semantic-segmentation-pytorch.git
Mr_xiaolong
semantic-segmentation-pytorch
semantic-segmentation-pytorch
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891