.TOWER_NAME, i)) as scope:
# Calculate the loss for one tower of the CIFAR model. This function
# constructs the entire CIFAR model but shares the variables across
# all towers.
loss = tower_loss(scope)
# Reuse variables for the next tower.
tf.get_variable_scope().reuse_variables()
# Retain the summaries from the final tower.
# summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
# Calculate the gradients for the batch of data on this CIFAR tower.
grads = opt.compute_gradients(loss)
# Keep track of the gradients across all towers.
tower_grads.append(grads)
# We must calculate the mean of each gradient. Note that this is the
# synchronization point across all towers.
grads = average_gradients(tower_grads)
# Add a summary to track the learning rate.
# summaries.append(tf.scalar_summary('learning_rate', lr))
# Add histograms for gradients.
# for grad, var in grads:
# if grad is not None:
# summaries.append(
# tf.histogram_summary(var.op.name + '/gradients', grad))
# Apply the gradients to adjust the shared variables.
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
# Add histograms for trainable variables.
# for var in tf.trainable_variables():
# summaries.append(tf.histogram_summary(var.op.name, var))
# Track the moving averages of all trainable variables.
# variable_averages = tf.train.ExponentialMovingAverage(
# cifar10.MOVING_AVERAGE_DECAY, global_step)
# variables_averages_op = variable_averages.apply(tf.trainable_variables())
# Group all updates to into a single train op.
# train_op = tf.group(apply_gradient_op, variables_averages_op)
# Create a saver.
saver = tf.train.Saver(tf.all_variables())
# Build the summary operation from the last tower summaries.
# summary_op = tf.merge_summary(summaries)
# Build an initialization operation to run below.
init = tf.global_variables_initializer()
# Start running operations on the Graph. allow_soft_placement must be set to
# True to build towers on GPU, as some of the ops do not have GPU
# implementations.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(init)
# Start the queue runners.
tf.train.start_queue_runners(sess=sess)
# summary_writer = tf.train.SummaryWriter(train_dir, sess.graph)
for step in range(max_steps):
start_time = time.time()
_, loss_value = sess.run([apply_gradient_op, loss])
duration = time.time() - start_time
assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
if step % 10 == 0:
num_examples_per_step = batch_size * num_gpus
examples_per_sec = num_examples_per_step / duration
sec_per_batch = duration / num_gpus
format_str = ('step %d, loss = %.2f (%.1f examples/sec; %.3f '
'sec/batch)')
print (format_str % (step, loss_value,
examples_per_sec, sec_per_batch))
# if step % 100 == 0:
# summary_str = sess.run(summary_op)
# summary_writer.add_summary(summary_str, step)
# Save the model checkpoint periodically.
if step % 1000 == 0 or (step + 1) == max_steps:
# checkpoint_path = os.path.join(train_dir, 'model.ckpt')
saver.save(sess, '/tmp/cifar10_train/model.ckpt', global_step=step)
def _prepare_and_train():
    """Prepare the dataset, then run the tower-based training loop.

    ``cifar10.maybe_download_and_extract`` presumably fetches and unpacks the
    CIFAR-10 data only when it is not already present (inferred from its
    name; confirm in the cifar10 module).

    NOTE(review): the tf.gfile-based reset of train_dir (delete it if it
    exists, then recreate it) is disabled in this script. Nothing here clears
    checkpoints left by an earlier run before training starts.
    """
    cifar10.maybe_download_and_extract()
    train()


_prepare_and_train()
参考资料：《TensorFlow实战》
欢迎付费咨询（150元/小时），我的微信：qingxingfengzi
|