1 Star 0 Fork 0

phoneProject / VGG

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
model.py 3.11 KB
一键复制 编辑 原始数据 按行查看 历史
tsl 提交于 2021-06-16 23:05 . VGG-pytorch
import torch
import torch.nn as nn
class VGG(nn.Module):
    """VGG-16 image classifier: 13 3x3 conv layers + 3 fully connected layers.

    Conv layout (output channels, M = 2x2 max-pool):
    64, 64, M, 128, 128, M, 256, 256, 256, M, 512, 512, 512, M,
    512, 512, 512, M; then Linear 25088 -> 4096 -> 4096 -> ``num_classes``.

    Args:
        num_classes: Number of output logits produced by ``forward``.

    Input is expected as a ``(batch, 3, H, W)`` tensor. An adaptive average
    pool resizes the last feature map to 7x7 before the classifier, so any
    ``H, W >= 32`` works (for 224x224 input that pool is an identity).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        # A single parameter-free ReLU module is shared by every layer.
        self.relu = nn.ReLU()

        # Stage 1: 3 -> 64 channels, spatial size halved by pool1.
        self.conv1_1 = self._conv3x3(3, 64)
        self.conv1_2 = self._conv3x3(64, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 64 -> 128 channels.
        self.conv2_1 = self._conv3x3(64, 128)
        self.conv2_2 = self._conv3x3(128, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: 128 -> 256 channels (three convs, as in VGG-16).
        self.conv3_1 = self._conv3x3(128, 256)
        self.conv3_2 = self._conv3x3(256, 256)
        self.conv3_3 = self._conv3x3(256, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 4: 256 -> 512 channels.
        self.conv4_1 = self._conv3x3(256, 512)
        self.conv4_2 = self._conv3x3(512, 512)
        self.conv4_3 = self._conv3x3(512, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 5: 512 -> 512 channels.
        self.conv5_1 = self._conv3x3(512, 512)
        self.conv5_2 = self._conv3x3(512, 512)
        self.conv5_3 = self._conv3x3(512, 512)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Force a 7x7 map so the first Linear layer's size is input-independent.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        # Classifier (attribute names kept for state_dict compatibility).
        self.liner6_1 = nn.Linear(in_features=512 * 7 * 7, out_features=4096)
        self.dropout6_1 = nn.Dropout(p=0.5)
        self.liner6_2 = nn.Linear(in_features=4096, out_features=4096)
        self.dropout6_2 = nn.Dropout(p=0.5)
        self.liner6_3 = nn.Linear(in_features=4096, out_features=num_classes)

    @staticmethod
    def _conv3x3(in_channels, out_channels):
        """Return a 3x3, stride-1, padding-1 conv (keeps spatial size)."""
        return nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``.

        Args:
            x: Image batch of shape ``(batch, 3, H, W)`` with ``H, W >= 32``.
        """
        stages = (
            ((self.conv1_1, self.conv1_2), self.pool1),
            ((self.conv2_1, self.conv2_2), self.pool2),
            ((self.conv3_1, self.conv3_2, self.conv3_3), self.pool3),
            ((self.conv4_1, self.conv4_2, self.conv4_3), self.pool4),
            ((self.conv5_1, self.conv5_2, self.conv5_3), self.pool5),
        )
        for convs, pool in stages:
            for conv in convs:
                x = self.relu(conv(x))
            x = pool(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout6_1(self.relu(self.liner6_1(x)))
        x = self.dropout6_2(self.relu(self.liner6_2(x)))
        # No activation or dropout on the output: these are raw logits.
        return self.liner6_3(x)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/giteebytsl/vgg.git
git@gitee.com:giteebytsl/vgg.git
giteebytsl
vgg
VGG
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891