代码拉取完成,页面将自动刷新
from operator import mod
import torch
from models import *
from voc import *
import torchvision.transforms as transforms
from einops import rearrange
num_classes=20
patches=1
model = mix_resnet101(num_classes=num_classes, pretrained=False,freeze=0, base_patches=patches, mix_layers=2, t=0, adj_file='data/voc/voc_adj.pkl')
checkpoint = torch.load('checkpoint/voc/voc_checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
train_dataset = Voc2007Classification('data/voc', 'trainval', inp_name='data/voc/voc_glove_word2vec.pkl')
normalize = transforms.Normalize(mean=model.image_normalization_mean, std=model.image_normalization_std)
transforms_data = transforms.Compose([
Warp(448),
transforms.ToTensor(),
normalize,
])
train_dataset.transform=transforms_data
# print(train_dataset[0])
input,target=train_dataset[0]
target[target == 0] = 1
target[target == -1] = 0
print(input[0])
print(target)
f= lambda x:rearrange(x,'(n c) h w -> n c h w',n=1)
f2 = lambda x:rearrange(x,'(n c) w -> n c w',n=1)
feature = f(input[0])
inp = f2(torch.tensor(input[2]))
feature_var = torch.autograd.Variable(feature).float()
target_var = torch.autograd.Variable(target).float()
inp_var = torch.autograd.Variable(inp).float().detach()
print(model(feature_var,inp_var))
print(target_var)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。