1 Star 0 Fork 1

diycp2015 / imageProcessing

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
readMnistTestDataset.py 3.01 KB
一键复制 编辑 原始数据 按行查看 历史
北部湾的落日 提交于 2018-05-09 09:58 . Initial commit
#!/usr/bin/python3.6
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from PIL import Image
# 声明图片宽高
rows = 28
cols = 28
# 要提取的图片数量
images_to_extract = 1000
# 当前路径下的保存目录
save_dir = "./tensorflow/mnist_test_images"
# 读入mnist数据
mnist = input_data.read_data_sets("./tensorflow/MNIST_data/", one_hot=False)
# 创建会话
sess = tf.Session()
# 获取图片总数
shape = sess.run(tf.shape(mnist.test.images))
images_count = shape[0]
pixels_per_image = shape[1]
# 获取标签总数
shape = sess.run(tf.shape(mnist.test.labels))
labels_count = shape[0]
# mnist.train.labels是一个二维张量,为便于后续生成数字图片目录名,有必要一维化(后来发现只要把数据集的one_hot属性设为False,mnist.train.labels本身就是一维)
# labels = sess.run(tf.argmax(mnist.train.labels, 1))
labels = mnist.test.labels
# 检查数据集是否符合预期格式
if (images_count == labels_count) and (shape.size == 1):
print("数据集总共包含 %s 张图片,和 %s 个标签" % (images_count, labels_count))
print("每张图片包含 %s 个像素" % (pixels_per_image))
print("数据类型:%s" % (mnist.test.images.dtype))
# mnist图像数据的数值范围是[0,1],需要扩展到[0,255],以便于人眼观看
if mnist.test.images.dtype == "float32":
print("准备将数据类型从[0,1]转为binary[0,255]...")
for i in range(0, images_to_extract):
for n in range(pixels_per_image):
if mnist.test.images[i][n] != 0:
mnist.test.images[i][n] = 255
# 由于数据集图片数量庞大,转换可能要花不少时间,有必要打印转换进度
if ((i + 1) % 50) == 0:
print("图像浮点数值扩展进度:已转换 %s 张,共需转换 %s 张" % (i + 1, images_to_extract))
# 创建数字图片的保存目录
for i in range(10):
dir = "%s/%s/" % (save_dir, i)
if not os.path.exists(dir):
print("目录 ""%s"" 不存在!自动创建该目录..." % dir)
os.makedirs(dir)
# 通过python图片处理库,生成图片
indices = [0 for x in range(0, 10)]
for i in range(0, images_to_extract):
img = Image.new("L", (cols, rows))
for m in range(rows):
for n in range(cols):
img.putpixel((n, m), int(mnist.test.images[i][n + m * cols]))
# 根据图片所代表的数字label生成对应的保存路径
digit = labels[i]
path = "%s/%s/%s.bmp" % (save_dir, labels[i], indices[digit])
indices[digit] += 1
img.save(path)
# 由于数据集图片数量庞大,保存过程可能要花不少时间,有必要打印保存进度
if ((i + 1) % 50) == 0:
print("图片保存进度:已保存 %s 张,共需保存 %s 张" % (i + 1, images_to_extract))
else:
print("图片数量和标签数量不一致!")
1
https://gitee.com/diycp2015/imageProcessing.git
git@gitee.com:diycp2015/imageProcessing.git
diycp2015
imageProcessing
imageProcessing
master

搜索帮助

53164aa7 5694891 3bd8fe86 5694891