代码拉取完成,页面将自动刷新
同步操作将从 北部湾的落日/imageProcessing 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./tensorflow/MNIST_data", one_hot=True) #MNIST数据输入
import tensorflow as tf
import os
batch_size = 100
display_step = 1
#Network Parameters
n_input = 784
n_classes = 10
#tf Graph input
x = tf.placeholder("float",[None,n_input])
y = tf.placeholder("float",[None,n_classes])
#pre-define
def conv2d(x,W):
return tf.nn.conv2d(x,W,
strides=[1,1,1,1],
padding='SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x,ksize=[1,2,2,1],
strides=[1,2,2,1],
padding='SAME')
#Create model
def multilayer_preceptron(x,weights,biases):
#now,we want to change this to a CNN network
#first,reshape the data to 4_D
x_image=tf.reshape(x,[-1,28,28,1])
#then apply cnn layers
h_conv1=tf.nn.relu(conv2d(x_image,weights['conv1'])+biases['conv_b1'])
h_pool1=max_pool_2x2(h_conv1)
h_conv2=tf.nn.relu(conv2d(h_pool1,weights['conv2'])+biases['conv_b2'])
h_pool2=max_pool_2x2(h_conv2)
h_pool2_flat=tf.reshape(h_pool2,[-1,7*7*64])
h_fc1=tf.nn.relu(tf.matmul(h_pool2_flat,weights['fc1'])+biases['fc1_b'])
out_layer=tf.matmul(h_fc1,weights['out'])+biases['out_b']
return out_layer
weights={
'conv1':tf.Variable(tf.random_normal([5,5,1,32])),
'conv2':tf.Variable(tf.random_normal([5,5,32,64])),
'fc1':tf.Variable(tf.random_normal([7*7*64,256])),
'out':tf.Variable(tf.random_normal([256,n_classes]))
}
biases={
'conv_b1':tf.Variable(tf.random_normal([32])),
'conv_b2':tf.Variable(tf.random_normal([64])),
'fc1_b':tf.Variable(tf.random_normal([256])),
'out_b':tf.Variable(tf.random_normal([n_classes]))
}
#Construct model
pred = multilayer_preceptron(x,weights,biases)
#create class Saver
model_saver = tf.train.Saver()
#Launch the gtrph
with tf.Session() as sess:
#create dir for model saver
model_dir = "mnist"
model_name = "cpk"
model_path=os.path.join(model_dir,model_name)
model_saver.restore(sess,model_path)
img=mnist.test.images[100].reshape(-1,784)
img_label=sess.run(tf.argmax(mnist.test.labels[100]))
ret=sess.run(pred,feed_dict={x:img})
num_pred=sess.run(tf.argmax(ret,1))
print("预测值:%d\n" % num_pred)
print("真实值:",img_label)
print("模型恢复成功")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。