from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import re
import time
import numpy as np
from datetime import datetime
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# Constants used for dealing with the files, matches convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'

# If a model is trained with multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
# names of the summaries when visualizing a model.
TOWER_NAME = 'tower'

# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999  # The decay to use for the moving average.
NUM_EPOCHS_PER_DECAY = 350.0  # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.
INITIAL_LEARNING_RATE = 0.1  # Initial learning rate.

# Global constants describing the MNIST data set.
# train() derives the learning-rate decay schedule from this value; without it
# train() fails with a NameError.
# NOTE(review): 55000 is the size of the standard MNIST training split (60000
# images minus the default 5000-image validation split) -- confirm that it
# matches the number of records actually written to TRAIN_FILE.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 55000

# Command-line flags; values are read through FLAGS.<name> after parsing.
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer('batch_size', 64,
                            """Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/home/norman/MNIST_data',
                           """Path to the MNIST data directory.""")
tf.app.flags.DEFINE_string('train_dir', '/home/norman/MNIST_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('num_gpus', 2,
                            """How many GPUs to use.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
tf.app.flags.DEFINE_boolean('tb_logging', False,
                            """Whether to log to Tensorboard.""")
tf.app.flags.DEFINE_integer('num_epochs', 10,
                            """Number of epochs to run trainer.""")
def read_and_decode(filename_queue):
    """Read one serialized example from a TFRecord queue and decode it.

    Args:
        filename_queue: Queue of TFRecord file names, e.g. the result of
            tf.train.string_input_producer().

    Returns:
        A tuple (image, label): `image` is a flat float32 tensor scaled to
        [-0.5, 0.5] and `label` is an int32 scalar tensor.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })

    # Convert from a scalar string tensor (whose single string has
    # length IMAGE_PIXELS) to a uint8 tensor with shape [IMAGE_PIXELS].
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    # tf.decode_raw yields an unknown static length, but the batching queue
    # used by inputs() needs fully defined shapes. inference() reshapes every
    # image to 28x28x1, hence 28 * 28 pixels per record.
    # NOTE(review): assumes each record stores exactly 784 uint8 pixels --
    # confirm against the script that wrote the TFRecord files.
    image.set_shape([28 * 28])

    # OPTIONAL: Could reshape into a 28x28 image and apply distortions
    # here. Since we are not applying any distortions in this
    # example, and the next step expects the image to be flattened
    # into a vector, we don't bother.

    # Convert from [0, 255] -> [-0.5, 0.5] floats.
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

    # Convert label from a scalar int64 tensor to an int32 scalar.
    label = tf.cast(features['label'], tf.int32)

    return image, label
def inputs(train, batch_size, num_epochs):
    """Reads input data num_epochs times.

    Args:
        train: Selects between the training (True) and validation (False)
            data.
        batch_size: Number of examples per returned batch.
        num_epochs: Number of times to read the input data, or 0/None to
            train forever.

    Returns:
        A tuple (images, labels), where:
        * images is a float tensor with shape [batch_size, IMAGE_PIXELS]
          in the range [-0.5, 0.5].
        * labels is an int32 tensor with shape [batch_size] with the true
          label, a number in the range [0, NUM_CLASSES).

        Note that a tf.train.QueueRunner is added to the graph, which
        must be run using e.g. tf.train.start_queue_runners().
    """
    # string_input_producer treats None (not 0) as "cycle forever".
    if not num_epochs:
        num_epochs = None
    filename = os.path.join(FLAGS.data_dir,
                            TRAIN_FILE if train else VALIDATION_FILE)

    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs)

        # Even when reading in multiple threads, share the filename
        # queue.
        image, label = read_and_decode(filename_queue)

        # Shuffle the examples and collect them into batch_size batches.
        # (Internally uses a RandomShuffleQueue.)
        # We run this in two threads to avoid being a bottleneck.
        images, sparse_labels = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, num_threads=2,
            capacity=1000 + 3 * batch_size,
            # Ensures a minimum amount of shuffling of examples.
            min_after_dequeue=1000)

        return images, sparse_labels
def inference(images):
    """Build the MNIST model.

    Args:
        images: Images returned from MNIST or inputs(). They are reshaped
            to [-1, 28, 28, 1] before the first convolution.

    Returns:
        The unscaled logits produced by the 'softmax_linear' layer.
    """
    # We instantiate all variables using tf.get_variable() instead of
    # tf.Variable() in order to share variables across multiple GPU training
    # runs. If we only ran this model on a single GPU, we could simplify this
    # function by replacing all instances of tf.get_variable()
    # with tf.Variable().
    #
    # NOTE(review): several call arguments were missing from this function
    # and have been restored: the conv kernel initialisation (stddev=5e-2,
    # wd=0.0), the LRN op names, and the bias initialisers of 'local3' (0.1)
    # and 'softmax_linear' (0.0). Those values follow the TensorFlow CIFAR-10
    # tutorial model this code mirrors -- confirm them before relying on
    # existing checkpoints or published results.

    # Reshape to use within a convolutional neural net.
    # Last dimension is for "features" - there is only one here, since images
    # are grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc.
    x_image = tf.reshape(images, [-1, 28, 28, 1])

    # conv1
    with tf.variable_scope('conv1') as scope:
        kernel = _variable_with_weight_decay('weights',
                                             shape=[5, 5, 1, 32],
                                             stddev=5e-2,
                                             wd=0.0)
        biases = _variable_on_cpu('biases', [32], tf.constant_initializer(0.0))
        # 'SAME' padding keeps the 28x28 extent, so that two 2x2 poolings
        # end at the 7x7 maps that 'local3' flattens below.
        conv = tf.nn.conv2d(x_image, kernel, strides=[1, 1, 1, 1],
                            padding='SAME')
        pre_activation = tf.nn.bias_add(conv, biases)
        conv1 = tf.nn.relu(pre_activation, name=scope.name)

    # pool1
    pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                           padding='SAME', name='pool1')
    # norm1
    norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                      name='norm1')

    # conv2
    with tf.variable_scope('conv2') as scope:
        kernel = _variable_with_weight_decay('weights',
                                             shape=[5, 5, 32, 64],
                                             stddev=5e-2,
                                             wd=0.0)
        conv = tf.nn.conv2d(norm1, kernel, strides=[1, 1, 1, 1], padding='SAME')
        biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
        pre_activation = tf.nn.bias_add(conv, biases)
        conv2 = tf.nn.relu(pre_activation, name=scope.name)

    # norm2
    norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                      name='norm2')
    # pool2
    pool2 = tf.nn.max_pool(norm2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                           padding='SAME', name='pool2')

    # local3
    with tf.variable_scope('local3') as scope:
        # Move everything into depth so we can perform a single matrix multiply.
        reshape = tf.reshape(pool2, [-1, 7 * 7 * 64])
        dim = reshape.get_shape()[1].value
        weights = _variable_with_weight_decay('weights', shape=[dim, 1024],
                                              stddev=0.04, wd=0.004)
        biases = _variable_on_cpu('biases', [1024],
                                  tf.constant_initializer(0.1))
        local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases,
                            name=scope.name)

    # local4
    with tf.variable_scope('local4') as scope:
        weights = _variable_with_weight_decay('weights', shape=[1024, 10],
                                              stddev=0.04, wd=0.004)
        biases = _variable_on_cpu('biases', [10], tf.constant_initializer(0.1))
        local4 = tf.nn.relu(tf.matmul(local3, weights) + biases,
                            name=scope.name)

    # linear layer(WX + b),
    # We don't apply softmax here because
    # tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
    # and performs the softmax internally for efficiency.
    with tf.variable_scope('softmax_linear') as scope:
        weights = _variable_with_weight_decay('weights', [10, 10],
                                              stddev=1 / 192.0, wd=0.0)
        biases = _variable_on_cpu('biases', [10],
                                  tf.constant_initializer(0.0))
        softmax_linear = tf.add(tf.matmul(local4, weights), biases,
                                name=scope.name)

    return softmax_linear
def _variable_with_weight_decay(name, shape, stddev, wd):
    """Helper to create an initialized Variable with weight decay.

    Note that the Variable is initialized with a truncated normal
    distribution. A weight decay is added only if one is specified.

    Args:
        name: name of the variable
        shape: list of ints
        stddev: standard deviation of a truncated Gaussian
        wd: add L2Loss weight decay multiplied by this float. If None, weight
            decay is not added for this Variable.

    Returns:
        Variable Tensor
    """
    dtype = tf.float32
    var = _variable_on_cpu(
        name,
        shape,
        tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
    if wd is not None:
        # The decay term is collected under 'losses' so that loss() and
        # tower_loss() can add it to the cross entropy.
        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var
def _variable_on_cpu(name, shape, initializer):
    """Helper to create a Variable stored on CPU memory.

    Args:
        name: name of the variable
        shape: list of ints
        initializer: initializer for Variable

    Returns:
        Variable Tensor
    """
    # Pin the variable to the CPU device; tf.get_variable (rather than
    # tf.Variable) lets every tower reuse the same variable by name.
    with tf.device('/cpu:0'):
        dtype = tf.float32
        var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
    return var
def _activation_summary(x):
    """Helper to create summaries for activations.

    Creates a summary that provides a histogram of activations.
    Creates a summary that measures the sparsity of activations.
    Nothing is created unless FLAGS.tb_logging is set.

    Args:
        x: Tensor

    Returns:
        nothing
    """
    # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
    # session. This helps the clarity of presentation on tensorboard.
    if FLAGS.tb_logging:
        tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
        tf.summary.histogram(tensor_name + '/activations', x)
        tf.summary.scalar(tensor_name + '/sparsity',
                          tf.nn.zero_fraction(x))
def loss(logits, labels):
    """Add the cross entropy to the 'losses' collection and sum the collection.

    The mean cross entropy over the batch is added to the 'losses'
    collection, which also holds the weight decay terms registered by
    _variable_with_weight_decay().

    Args:
        logits: Logits from inference().
        labels: Labels from inputs(). 1-D tensor of shape [batch_size]

    Returns:
        Loss tensor of type float.
    """
    # Calculate the average cross entropy loss across the batch.
    labels = tf.cast(labels, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='cross_entropy_per_example')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    tf.add_to_collection('losses', cross_entropy_mean)

    # The total loss is defined as the cross entropy loss plus all of the weight
    # decay terms (L2 loss).
    return tf.add_n(tf.get_collection('losses'), name='total_loss')
def tower_loss(scope):
    """Calculate the total loss on a single tower running the MNIST model.

    Args:
        scope: unique prefix string identifying the MNIST tower,
            e.g. 'tower_0'

    Returns:
        Tensor of shape [] containing the total loss for a batch of data
    """
    # Input images and labels.
    images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                            num_epochs=FLAGS.num_epochs)

    # Build inference Graph.
    logits = inference(images)

    # Build the portion of the Graph calculating the losses. Note that we will
    # assemble the total_loss using a custom function below.
    _ = loss(logits, labels)

    # Assemble all of the losses for the current tower only.
    losses = tf.get_collection('losses', scope)

    # Calculate the total loss for the current tower.
    total_loss = tf.add_n(losses, name='total_loss')

    # Attach a scalar summary to all individual losses and the total loss.
    if FLAGS.tb_logging:
        for loss_term in losses + [total_loss]:
            # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU
            # training session. This helps the clarity of presentation on
            # tensorboard.
            loss_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', loss_term.op.name)
            tf.summary.scalar(loss_name, loss_term)

    return total_loss
def average_gradients(tower_grads):
    """Calculate average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all
    towers.

    Args:
        tower_grads: List of lists of (gradient, variable) tuples. The outer
            list is over towers. The inner list is over the individual
            (gradient, variable) pairs computed on that tower.

    Returns:
        List of pairs of (gradient, variable) where the gradient has been
        averaged across all towers.
    """
    average_grads = []
    # zip(*...) regroups the per-tower lists so that each iteration sees the
    # same variable's entry from every tower.
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
        # Add a leading 'tower' dimension to every gradient so they can be
        # concatenated and averaged along it.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]

        # Average over the 'tower' dimension.
        grad = tf.concat(grads, 0)
        grad = tf.reduce_mean(grad, 0)

        # Keep in mind that the Variables are redundant because they are shared
        # across towers. So .. we will just return the first tower's pointer to
        # the Variable.
        v = grad_and_vars[0][1]
        average_grads.append((grad, v))
    return average_grads
def train():
    """Train the MNIST model on FLAGS.num_gpus towers.

    Builds one model replica ("tower") per GPU with shared variables,
    averages the tower gradients and applies them with a Nesterov momentum
    optimizer. Training runs until the input queues raise OutOfRangeError,
    i.e. after FLAGS.num_epochs passes over the data.

    Raises:
        ValueError: if FLAGS.num_gpus is smaller than 1.
    """
    if FLAGS.num_gpus < 1:
        # Without at least one tower there is no loss to optimize.
        raise ValueError('num_gpus must be >= 1, got %d' % FLAGS.num_gpus)

    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Create a variable to count the number of train() calls. This equals
        # the number of batches processed * FLAGS.num_gpus.
        global_step = tf.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0), trainable=False)

        # Calculate the learning rate schedule.
        num_batches_per_epoch = (NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                                 FLAGS.batch_size)
        decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                        global_step,
                                        decay_steps,
                                        LEARNING_RATE_DECAY_FACTOR,
                                        staircase=True)

        opt = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True,
                                         use_locking=True)

        # Calculate the gradients for each model tower.
        tower_grads = []
        with tf.variable_scope(tf.get_variable_scope()):
            # range (not the Python-2-only xrange) keeps this runnable on
            # both Python 2 and 3.
            for i in range(FLAGS.num_gpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope(
                            '%s_%d' % (TOWER_NAME, i)) as scope:
                        # Calculate the loss for one tower of the MNIST model.
                        # This function constructs the entire MNIST model but
                        # shares the variables across all towers.
                        loss = tower_loss(scope)

                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

                        # Retain the summaries from the final tower.
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)

                        # Calculate the gradients for the batch of data on this
                        # MNIST tower.
                        grads = opt.compute_gradients(loss, gate_gradients=0)

                        # Keep track of the gradients across all towers.
                        tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = average_gradients(tower_grads)

        # Add histograms for gradients.
        if FLAGS.tb_logging:
            for grad, var in grads:
                if grad is not None:
                    summaries.append(
                        tf.summary.histogram(var.op.name + '/gradients', grad))

        # Add a summary to track the learning rate.
        summaries.append(tf.summary.scalar('learning_rate', lr))

        train_op = opt.apply_gradients(grads, global_step=global_step)

        # Add histograms for trainable variables.
        if FLAGS.tb_logging:
            for var in tf.trainable_variables():
                summaries.append(tf.summary.histogram(var.op.name, var))

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), sharded=True)

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge(summaries)

        # The op for initializing the variables. Local variables are
        # included because the epoch counter of the input producer is one.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        # Start running operations on the Graph. allow_soft_placement must be
        # set to True to build towers on GPU, as some of the ops do not have GPU
        # implementations.
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement))
        sess.run(init_op)

        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

        step = 0
        try:
            while not coord.should_stop():
                start_time = time.time()

                # Run one step of the model. The return values are
                # the activations from the `train_op` (which is
                # discarded) and the `loss` op. To inspect the values
                # of your ops or variables, you may include them in
                # the list passed to sess.run() and the value tensors
                # will be returned in the tuple from the call.
                _, loss_value = sess.run([train_op, loss])

                duration = time.time() - start_time

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                # Print an overview fairly often.
                if step % 100 == 0:
                    num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / FLAGS.num_gpus

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), step, loss_value,
                                        examples_per_sec, sec_per_batch))

                if FLAGS.tb_logging:
                    if step % 10 == 0:
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)

                # Save the model checkpoint periodically.
                # NOTE(review): the second condition compares the step count
                # with num_epochs * batch_size, which does not look like the
                # number of the final step -- confirm the intended value.
                if step % 1000 == 0 or (
                        step + 1) == FLAGS.num_epochs * FLAGS.batch_size:
                    checkpoint_path = os.path.join(FLAGS.train_dir,
                                                   'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=step)

                step += 1
        except tf.errors.OutOfRangeError:
            # Raised by the input queues once num_epochs have been consumed.
            print('Done training for %d epochs, %d steps.' % (
                FLAGS.num_epochs, step))
        finally:
            # When done, ask the threads to stop.
            coord.request_stop()

        # Wait for threads to finish.
        coord.join(threads)
        # Flush pending events before releasing the session.
        summary_writer.close()
        sess.close()
def evaluate():
    """Eval MNIST on the test set using the latest checkpoint.

    Restores the most recent checkpoint found in FLAGS.train_dir, runs the
    whole test set through inference() in a single batch and prints the
    fraction of examples whose top-1 prediction is correct. Prints a message
    and returns early when no checkpoint exists.
    """
    with tf.Graph().as_default():
        # Get images and labels for MNIST.
        mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False)
        # The training pipeline (read_and_decode) feeds pixels scaled to
        # [-0.5, 0.5]; shift the test images into the same range.
        # NOTE(review): assumes read_data_sets returns floats in [0, 1] (its
        # default dtype) -- confirm.
        images = mnist.test.images - 0.5
        # tf.nn.in_top_k only accepts int32/int64 targets.
        labels = mnist.test.labels.astype(np.int32)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = inference(images)

        # Calculate predictions.
        top_k_op = tf.nn.in_top_k(predictions=logits, targets=labels, k=1)

        # Create saver to restore the learned variables for eval.
        saver = tf.train.Saver()

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                # Nothing to evaluate: the variables would be uninitialized.
                print('No checkpoint file found')
                return
            # Restores from checkpoint
            saver.restore(sess, ckpt.model_checkpoint_path)

            # Number of test examples classified correctly.
            true_count = np.sum(sess.run([top_k_op]))

        # Compute precision.
        precision = true_count / labels.shape[0]
        print('%s: precision = %.3f' % (datetime.now(), precision))
def main(argv=None):  # pylint: disable=unused-argument
    """Script entry point: train the model, report wall time, then evaluate.

    Args:
        argv: Unused; present because tf.app.run() passes the remaining
            command-line arguments.
    """
    # NOTE(review): the calls between the two timestamps were missing.
    # train() is timed and evaluate() runs afterwards because these are the
    # only two top-level routines in this file -- confirm the intended order
    # and whether evaluation should be part of the timed section.
    start_time = time.time()
    train()
    duration = time.time() - start_time
    print('Total Duration (%.3f sec)' % duration)
    evaluate()
if __name__ == '__main__':
    # tf.app.run() parses the command-line flags and then calls main().
    tf.app.run()