代码拉取完成,页面将自动刷新
同步操作将从 北部湾的落日/imageProcessing 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
from PIL import Image
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import re
mnist = input_data.read_data_sets("./tensorflow/MNIST_data", one_hot=True)
'''
x不是一个特定的值,而是一个占位符placeholder,我们在TensorFlow运行计算时输入这个值。
我们希望能够输入任意数量的MNIST图像,每一张图展平成784维的向量。我们用2维的浮点数张量来表示这些图,
这个张量的形状是[None,784 ]。(这里的None表示此张量的第一个维度可以是任何长度的。)
'''
x = tf.placeholder(tf.float32, [None, 784])
'''
用全为零的张量来初始化W和b
'''
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
'''
训练模型
用tf.matmul(X,W)表示x乘以W,对应之前等式里面的
'''
y = tf.nn.softmax(tf.matmul(x,W) + b)
# #评估指标
# y_ = tf.placeholder("float", [None,10])
# #交叉熵计算
# cross_entropy = -tf.reduce_sum(y_*tf.log(y))
# '''
# 用梯度下降算法(gradient descent algorithm)以0.01的学习速率最小化交叉熵
# '''
# train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# # Add an op to initialize the variables.
# init = tf.initialize_all_variables()
#saver model
model_saver = tf.train.Saver()
# sess = tf.Session()
# sess.run(init)
# for i in range(1000):
# batch_xs, batch_ys = mnist.train.next_batch(100)
# sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
#
# correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
#
# # create dir for model saver
# model_path = "./tensorflow/tmp/model.ckpt"
# save_path =model_saver.save(sess, model_path)
# print("Model saved in file: ", save_path)
# #图片读取
# def readImages(filename):
# label=1
# # imageList = []
# with tf.Session() as sess:
# # for filename in filenameList:
# # print(filename)
# # 读取图像的原始数据
# image_raw_data = tf.gfile.FastGFile(filename,'rb').read() # 必须是 ‘rb’ 模式打开,否则会报错
# # 将图像使用 jpeg 的格式解码从而得到图像对应的三维矩阵
# # tf.image.decode_jpeg 函数对 png 格式的图像进行解码。解码之后的结果为一个张量,
# # 在使用它的取值之前需要明确调用运行的过程。
# print(filename)
# img_data = tf.image.decode_jpeg(image_raw_data)
# # arr = np.reshape(img_data.eval(sess), [-1]) # 多维矩阵转一维矩阵
# arr = sess.run(tf.reshape(img_data.eval(), [-1]))
# # imageList.append(arr)
# print(tf.shape(arr))
# return np.array(arr),label
def image_to_array(path):
im = Image.open(path)
# w, h = im.size
# r, g, b = im.split() # rgb通道分离
r_arr = np.array(im).reshape(-1,28*28)
# g_arr = np.array(g).reshape(28)
# b_arr = np.array(b).reshape(28)
# plt.imshow(im)
# plt.show()
if(np.shape(r_arr)[0]>1):
r_arr = r_arr[0].reshape(-1,28*28)
return r_arr,1
#Launch the gtrph
with tf.Session() as sess:
#create dir for model saver
model_path = "./tensorflow/tmp/model.ckpt"
model_saver.restore(sess,model_path)
# img=mnist.test.images[20].reshape(-1,784)
# img_label=sess.run(tf.argmax(mnist.test.labels[20]))
image_path = "./tensorflow/mnist_digits_images/2.jpg"
img,img_label = image_to_array(image_path)
print(np.shape(img))
ret=sess.run(y,feed_dict={x:img})
num_pred=sess.run(tf.argmax(ret,1))
print("预测值:%d\n" % num_pred)
print("真实值:",img_label)
print("模型恢复成功")
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。