Mikael / ML-GCN

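This repository does not declare an open-source license (no LICENSE file); check the project description and its upstream code dependencies before use.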
inference.py 1.34 KB
G5 committed on 2022-03-23 13:25: inference
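import torch
import torchvision.transforms as transforms

# mix_resnet101, Voc2007Classification and Warp come from the repo's local modules.
from models import *
from voc import *

num_classes = 20  # PASCAL VOC 2007 has 20 object categories
patches = 1

# Build the model and load the trained VOC checkpoint.
model = mix_resnet101(num_classes=num_classes, pretrained=False, freeze=0,
                      base_patches=patches, mix_layers=2, t=0,
                      adj_file='data/voc/voc_adj.pkl')
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine.
checkpoint = torch.load('checkpoint/voc/voc_checkpoint.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # inference mode: disables dropout, uses running BatchNorm statistics

# VOC 2007 trainval split plus the GloVe word embeddings of the 20 labels.
train_dataset = Voc2007Classification('data/voc', 'trainval',
                                      inp_name='data/voc/voc_glove_word2vec.pkl')
normalize = transforms.Normalize(mean=model.image_normalization_mean,
                                 std=model.image_normalization_std)
train_dataset.transform = transforms.Compose([
    Warp(448),
    transforms.ToTensor(),
    normalize,
])

# sample[0] is the transformed image, sample[2] the label word embeddings.
sample, target = train_dataset[0]
image, word_vecs = sample[0], sample[2]

# VOC labels: 1 = present, -1 = absent, 0 = "difficult".
# Clone first so the labels stored in the dataset are not modified in place,
# then count difficult objects as present and map absent to 0.
target = target.clone()
target[target == 0] = 1
target[target == -1] = 0
print(image.shape)
print(target)

# Add a batch dimension of size 1: (C, H, W) -> (1, C, H, W), (20, 300) -> (1, 20, 300).
feature = image.unsqueeze(0).float()
inp = torch.as_tensor(word_vecs).unsqueeze(0).float()

with torch.no_grad():  # no gradients are needed for inference
    output = model(feature, inp)  # raw class scores, one per category
print(output)
print(target.float())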
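The script prints raw class scores next to the binarized target. A minimal sketch of turning such scores into predicted VOC label names follows. It assumes the model emits one logit per class (as when training with a multi-label soft-margin loss), a decision threshold of 0.5 on the sigmoid probability, and the standard VOC 2007 category order; `VOC_CLASSES`, `binarize_voc_target` and `decode_scores` are hypothetical helpers, not part of this repository.

```python
from typing import List

import torch

# Standard PASCAL VOC 2007 category order (assumed to match the dataset's label columns).
VOC_CLASSES = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
    'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]


def binarize_voc_target(target: torch.Tensor) -> torch.Tensor:
    """Map VOC labels {1, 0 (difficult), -1} to {1, 1, 0} without mutating the input."""
    out = target.clone()
    out[out == 0] = 1   # difficult counts as present (must run before the next line)
    out[out == -1] = 0  # absent becomes 0
    return out


def decode_scores(scores: torch.Tensor, threshold: float = 0.5) -> List[str]:
    """Return the class names whose sigmoid probability exceeds `threshold`.

    `scores` holds raw logits with shape (1, num_classes) or (num_classes,);
    treating them as logits is an assumption about the model's output.
    """
    probs = torch.sigmoid(scores.reshape(-1))
    hits = torch.nonzero(probs > threshold).flatten().tolist()
    return [VOC_CLASSES[i] for i in hits]


# Usage with synthetic values: "person" present, "chair" marked difficult.
target = torch.full((20,), -1.0)
target[14] = 1.0
target[8] = 0.0
print(binarize_voc_target(target).nonzero().flatten().tolist())  # [8, 14]

scores = torch.full((1, 20), -3.0)
scores[0, 14] = 2.0
print(decode_scores(scores))  # ['person']
```

Thresholding the sigmoid probability at 0.5 is equivalent to thresholding the logit at 0; per-class thresholds tuned on a validation split are a common alternative.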
https://gitee.com/Mikael_ac/ml-gcn.git
git@gitee.com:Mikael_ac/ml-gcn.git