Multi-GPU training with Tensorflow Estimators

September 6, 2018

Tensorflow Estimators handle much of the boilerplate of training a neural network like saving checkpoints and summaries, running the training loop, periodically evaluating on the validation set, and so on. If you’re already using Estimators, it’s easy to move from training on a single GPU to synchronous multi-GPU training with only three new lines of code.

First, wrap your model_fn with tf.contrib.estimator.replicate_model_fn:

nn = tf.estimator.Estimator(

Second, wrap your optimizer with TowerOptimizer in your model_fn:

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

Unfortunately there’s one additional (undocumented) step that you have to take to get the code to work. If you only use the two lines above, you’ll see an error that will look something like this:

ValueError: Variable does not exist, or was not created with
tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?

To fix this, you need to wrap the definition of your model with a variable_scope that has reuse set to tf.AUTO_REUSE:

def model_fn(features, labels, mode, params):
  with tf.variable_scope('my_model', reuse=tf.AUTO_REUSE):
    net = tf.identity(features, 'input')
    net = tf.layers.dense(net, 1024, activation=tf.nn.relu)

The reason for this is that behind the scenes, replicate_model_fn will try to put all the variables in your net on each GPU. However, you want to share the variables across all the GPUs. This means that the variables need to be created with reuse=True for all but one of the GPUs. (One of the GPUs needs to have reuse=False so that the variables can be created — you can’t reuse what doesn’t exist in the first place!) Setting reuse=tf.AUTO_REUSE takes care of this variable management for you.

Note that replicate_model_fn is now deprecated, although its proposed replacement, MirroredStrategy doesn’t seem to be ready to use quite yet (at least I haven’t been able to get it to work).