1 Star 0 Fork 1

diycp2015 / imageProcessing

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目，私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件（LICENSE），使用请关注具体项目描述及其代码上游依赖。
克隆/下载
mnist_test.py 2.33 KB
一键复制 编辑 原始数据 按行查看 历史
北部湾的落日 提交于 2018-05-09 09:58：Initial commit
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# MNIST data input (labels are one-hot encoded because one_hot=True).
mnist = input_data.read_data_sets("./tensorflow/MNIST_data", one_hot=True)

# Training hyper-parameters; not referenced by the inference code below.
batch_size = 100
display_step = 1

# Network parameters
n_input = 784   # 28 * 28 pixels, flattened into one row per image
n_classes = 10  # one logit per digit class

# tf Graph input: x holds flattened images, y one-hot labels
# (y is not fed anywhere in this script).
x = tf.placeholder("float", [None, n_input])
y = tf.placeholder("float", [None, n_classes])
#pre-define
def conv2d(x, W):
    """Convolve ``x`` with filter ``W`` using unit stride and 'SAME' padding.

    Args:
        x: input tensor; the [1, 1, 1, 1] stride list assumes TF's default
            NHWC layout (batch, height, width, channels).
        W: filter tensor laid out as [height, width, in_channels, out_channels].

    Returns:
        The convolution result. With stride 1 and 'SAME' padding the spatial
        size of the output equals that of the input.
    """
    unit_stride = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_stride, padding='SAME')
def max_pool_2x2(x):
    """Halve the height and width of ``x`` with a 2x2, stride-2 max-pool.

    Uses 'SAME' padding, so an odd spatial size is rounded up rather than
    truncated. Assumes NHWC layout (TF's default), which is why the window
    and stride lists are [1, 2, 2, 1].
    """
    window = [1, 2, 2, 1]
    # The same 2x2 window doubles as the stride, so pooled regions don't overlap.
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
#Create model
def multilayer_preceptron(x, weights, biases):
    """Build the forward pass of a small CNN and return its class logits.

    The name is historical and kept so callers keep working; the network is
    convolutional, not a multilayer perceptron:

        reshape to 28x28x1 -> [conv -> ReLU -> 2x2 max-pool] x 2
        -> flatten -> fully connected + ReLU -> linear output layer

    Args:
        x: batch of flattened images; each row is reshaped to 28x28x1, so it
            must hold 784 values.
        weights: dict with keys 'conv1', 'conv2', 'fc1' and 'out'.
        biases: dict with keys 'conv_b1', 'conv_b2', 'fc1_b' and 'out_b'.

    Returns:
        Unnormalized logits (no softmax is applied), one row per image.
    """
    # Restore the 2-D image structure (single grayscale channel).
    features = tf.reshape(x, [-1, 28, 28, 1])
    # Two identical conv -> ReLU -> pool stages, each with its own parameters.
    for weight_key, bias_key in (('conv1', 'conv_b1'), ('conv2', 'conv_b2')):
        activated = tf.nn.relu(conv2d(features, weights[weight_key]) + biases[bias_key])
        features = max_pool_2x2(activated)
    # Two 'SAME' stride-2 pools take 28x28 down to 7x7; the 64 must match the
    # number of output channels of weights['conv2'].
    flattened = tf.reshape(features, [-1, 7 * 7 * 64])
    hidden = tf.nn.relu(tf.matmul(flattened, weights['fc1']) + biases['fc1_b'])
    return tf.matmul(hidden, weights['out']) + biases['out_b']
# Trainable parameters of the CNN, keyed by the names multilayer_preceptron() reads.
# NOTE: tf.train.Saver() (created later with no var_list) matches these variables
# to the checkpoint by their auto-generated names ("Variable", "Variable_1", ...),
# which depend on creation order. Do not reorder, insert, or remove tf.Variable
# calls here, or restoring the checkpoint will load values into the wrong tensors.
# The tf.random_normal initial values (default stddev 1.0) are never used by this
# script: no initializer is run, and the values are overwritten by restore().
weights={
# 5x5 convolution filters: 1 input channel -> 32 feature maps
'conv1':tf.Variable(tf.random_normal([5,5,1,32])),
# 5x5 convolution filters: 32 -> 64 feature maps
'conv2':tf.Variable(tf.random_normal([5,5,32,64])),
# fully connected layer: flattened 7x7x64 pooled features -> 256 hidden units
'fc1':tf.Variable(tf.random_normal([7*7*64,256])),
# output layer: 256 hidden units -> n_classes logits
'out':tf.Variable(tf.random_normal([256,n_classes]))
}
# Bias vectors, one per layer; sizes match the output width of the layer above.
biases={
'conv_b1':tf.Variable(tf.random_normal([32])),
'conv_b2':tf.Variable(tf.random_normal([64])),
'fc1_b':tf.Variable(tf.random_normal([256])),
'out_b':tf.Variable(tf.random_normal([n_classes]))
}
# Construct model: logits tensor for the images fed through placeholder x.
pred = multilayer_preceptron(x, weights, biases)
# Saver over all global variables; restore() matches them to the checkpoint by name.
model_saver = tf.train.Saver()
# Launch the graph
with tf.Session() as sess:
    # Checkpoint prefix "mnist/cpk", relative to the working directory.
    # restore() only reads it; nothing is created on disk here.
    model_dir = "mnist"
    model_name = "cpk"
    model_path = os.path.join(model_dir, model_name)
    model_saver.restore(sess, model_path)

    sample_index = 100
    # Keep a batch dimension of 1 so the shape is (1, 784), matching placeholder x.
    img = mnist.test.images[sample_index].reshape(-1, 784)
    # Labels are one-hot NumPy rows; taking argmax on the host avoids adding a new
    # tf.argmax op to the default graph just to decode them.
    img_label = mnist.test.labels[sample_index].argmax()
    ret = sess.run(pred, feed_dict={x: img})
    # ret has one row of logits; convert that row's argmax to a Python int so that
    # "%d" formats a scalar rather than a 1-element array (a deprecated NumPy conversion).
    num_pred = int(ret.argmax(axis=1)[0])
    print("预测值:%d\n" % num_pred)
    print("真实值:", img_label)
    print("模型恢复成功")
1
https://gitee.com/diycp2015/imageProcessing.git
git@gitee.com:diycp2015/imageProcessing.git
diycp2015
imageProcessing
imageProcessing
master

搜索帮助