Editor’s Note: Luke Metz is a speaker for ODSC East 2022 this April 19th-21st. Be sure to check out his talk, “Learned Optimizers,” there!

As machine learning models continue to grow, the cost and time needed to train them have become increasingly unwieldy. These rising costs make it harder both to train models on new data and to perform research that improves future versions of these models. Traditionally, such models are trained with hand-designed optimization algorithms such as stochastic gradient descent, or more sophisticated ones such as Adam. This post is about learned optimizers, which, instead of relying on hand-designed update rules, learn the optimization procedure that best suits the goal, resulting in faster optimization!

I believe learned optimization, and more generally meta-learning, is a natural next step in the deep learning revolution. In the same way that learned features and compute replaced hand-designed features, learned algorithms and compute will replace hand-designed ones.

I am personally quite excited about this line of research and have been pursuing it over the last few years, training more and more capable learned optimizers.

In this post, I want to give a flavor of what training and using a learned optimizer looks like. We will make use of our team’s new open-source package, `learned_optimization`, which enables learned optimizer and meta-learning research in JAX, Google’s new ML framework.

This post is brief. For a more detailed introduction, including getting-started Colab notebooks and examples, see our documentation.

Training a learned optimizer with `learned_optimization`

To kick things off, let’s first install learned_optimization and import the modules we will need.

!pip install git+https://github.com/google/learned_optimization.git

import jax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as onp
import tqdm # For fancy progress bars
from learned_optimization.tasks.fixed import conv
from learned_optimization.tasks import base as tasks_base
from learned_optimization import eval_training
from learned_optimization.optimizers import base as opt_base
from learned_optimization.learned_optimizers import adafac_mlp_lopt
from learned_optimization.outer_trainers import truncated_pes
from learned_optimization.outer_trainers import lopt_truncated_step
from learned_optimization.outer_trainers import truncated_grad
from learned_optimization.outer_trainers import gradient_learner
from learned_optimization.outer_trainers import truncation_schedule

Target Task

When creating a learned optimizer, we must define a task, or set of tasks, on which the learned optimizer should perform well. A task consists of a model architecture, a dataset, and a loss function. It could represent any kind of optimization problem, but in this example we will focus on neural network training problems.

The task we will work with is a small convnet trained on CIFAR-10. It operates on batches of CIFAR-10 images resized to 16×16 and uses three hidden convolutional layers (with 32, 64, and 64 channels) to make predictions. This task is already implemented in `learned_optimization`, in the `learned_optimization.tasks.fixed.conv` module.

We choose this problem as it is extremely small, and thus fast to experiment with.

task = conv.Conv_Cifar10_16_32x64x64()

We can initialize the neural network weights, sample a batch of data, compute the loss, and use JAX to compute the gradient of the loss with respect to the weights.

key = jax.random.PRNGKey(0)
weights = task.init(key)
batch = next(task.datasets.train)
print("loss:", task.loss(weights, key, batch))
grad = jax.grad(task.loss)(weights, key, batch)
print("Gradient shapes:")
jax.tree_util.tree_map(lambda x: x.shape, grad)
loss: 2.3102112
Gradient shapes:
{'conv2_d': {'b': (32,), 'w': (3, 3, 3, 32)},
 'conv2_d_1': {'b': (64,), 'w': (3, 3, 32, 64)},
 'conv2_d_2': {'b': (64,), 'w': (3, 3, 64, 64)},
 'linear': {'b': (10,), 'w': (64, 10)}}
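To make this interface concrete without downloading CIFAR-10, here is a minimal sketch of a hand-written training loop over an object with the same `init(key)` / `loss(params, key, batch)` shape used above. `ToyRegressionTask` and its `sample_batch` method are hypothetical stand-ins (a noise-free linear regression problem), not part of `learned_optimization`, and the update rule is plain SGD.

```python
import jax
import jax.numpy as jnp


class ToyRegressionTask:
  """Hypothetical task mimicking the init/loss interface (not a library class)."""

  def __init__(self, dim=4):
    self.dim = dim
    self.true_w = jnp.arange(1.0, dim + 1.0)  # weights we hope to recover

  def init(self, key):
    return {"w": jax.random.normal(key, (self.dim,)) * 0.1}

  def sample_batch(self, key, batch_size=32):
    x = jax.random.normal(key, (batch_size, self.dim))
    return x, x @ self.true_w

  def loss(self, params, key, batch):
    del key  # unused; kept to match the (params, key, batch) signature
    x, y = batch
    return jnp.mean((x @ params["w"] - y) ** 2)


task = ToyRegressionTask()
key = jax.random.PRNGKey(0)
params = task.init(key)
loss_and_grad = jax.jit(jax.value_and_grad(task.loss))

for step in range(100):
  key, batch_key = jax.random.split(key)
  batch = task.sample_batch(batch_key)
  loss, grads = loss_and_grad(params, batch_key, batch)
  # Plain SGD applied to every leaf of the parameter pytree.
  params = jax.tree_util.tree_map(lambda p, g: p - 0.05 * g, params, grads)

print("final loss:", float(loss))
```

Replacing the SGD line with a call to an optimizer object is exactly the piece the rest of this post swaps out.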


Before talking about learned optimizers, let’s introduce the more standard, hand-designed optimizer interface. An optimizer builds an initial state from the parameters (`opt.init`), and then repeatedly takes the current state plus gradient values and produces a new state, which contains the new parameter values. To demonstrate, we can construct the SGD optimizer and use it to take one step (`opt.update`) with a made-up gradient.

opt = opt_base.SGD(0.1)
params = jnp.ones([3])
opt_state = opt.init(params)
grads = jnp.ones([3])
new_opt_state = opt.update(opt_state, grads)
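The interface separates the optimizer’s state from its update rule. As an illustration only, here is a hypothetical `MiniSGD` class with the same `init` / `update` / `get_params` shape; it is a sketch under that assumption, not how `opt_base.SGD` is actually implemented.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class MiniSGDState:
  params: Any     # any pytree of arrays
  iteration: int  # number of updates applied so far


class MiniSGD:
  """Hypothetical sketch of an optimizer with init/update/get_params."""

  def __init__(self, learning_rate):
    self.learning_rate = learning_rate

  def init(self, params):
    return MiniSGDState(params=params, iteration=0)

  def update(self, opt_state, grad):
    # Apply p <- p - lr * g to every leaf of the parameter pytree.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - self.learning_rate * g, opt_state.params, grad)
    return MiniSGDState(params=new_params, iteration=opt_state.iteration + 1)

  def get_params(self, opt_state):
    return opt_state.params


opt = MiniSGD(0.1)
opt_state = opt.init(jnp.ones([3]))
new_opt_state = opt.update(opt_state, jnp.ones([3]))
print(opt.get_params(new_opt_state))  # each entry is 1 - 0.1 * 1
```

Keeping the iteration counter in the state matters for learned optimizers, which can condition their updates on how far training has progressed.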

Learned Optimizers

A learned optimizer is a parametric optimizer, namely an optimizer that is a function of some set of parameters. We can initialize the weights of this learned optimizer and use those weights to get an instance of an optimizer with which to do updates.

Much like with neural networks, there is a family of different kinds of learned optimizers we can use. The learned optimizer architecture we will use in this post was introduced in Practical tradeoffs between memory, compute, and performance in learned optimizers. It consists of a small neural network that is applied to each parameter.
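The core idea of applying one small network to every parameter can be sketched in a few lines of JAX. This toy version is a heavy simplification under stated assumptions: it uses only two per-parameter features (the gradient and a momentum accumulator) and a one-hidden-layer MLP, whereas the architecture used in this post computes 39 input features. `init_mlp` and `per_param_update` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, num_features=2, hidden_size=32):
  k1, k2 = jax.random.split(key)
  return {
      "w0": jax.random.normal(k1, (num_features, hidden_size)) * 0.1,
      "b0": jnp.zeros(hidden_size),
      "w1": jax.random.normal(k2, (hidden_size, 1)) * 0.01,
      "b1": jnp.zeros(1),
  }


def per_param_update(mlp, grad, momentum, beta=0.9, step_scale=1e-3):
  """Applies the same small MLP independently to every scalar parameter."""
  new_momentum = beta * momentum + (1 - beta) * grad
  # Stack per-parameter features along a trailing axis: shape (..., 2).
  features = jnp.stack([grad, new_momentum], axis=-1)
  hidden = jax.nn.relu(features @ mlp["w0"] + mlp["b0"])
  step = (hidden @ mlp["w1"] + mlp["b1"])[..., 0]  # back to grad's shape
  return step_scale * step, new_momentum


mlp = init_mlp(jax.random.PRNGKey(0))
param = jnp.ones((3, 4))
grad = jnp.full((3, 4), 0.5)
step, momentum = per_param_update(mlp, grad, jnp.zeros_like(grad))
new_param = param - step
print(step.shape)  # one output per parameter, whatever the tensor shape
```

Because the MLP weights are shared across every parameter, the learned optimizer stays tiny regardless of the size of the model it trains.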

lopt = adafac_mlp_lopt.AdafacMLPLOpt(hidden_size=32)

We can randomly initialize a set of weights and look at their structure. First, there are only a small number of learnable parameters: 2,409 in total. Most of these weights parameterize a small neural network: w0 maps from 39 input features to a hidden size of 32, w1 maps from 32 to 32, and w2 maps the hidden units to 2 outputs. The remaining seven values are the `*_decays` entries, which, as their names suggest, are decay rates.

lopt_weights = lopt.init(jax.random.PRNGKey(0))
shapes = jax.tree_util.tree_map(lambda x: x.shape, lopt_weights)
# Count elements of each array; summing the shape tuples would add up dimensions instead.
num_params = sum(x.size for x in jax.tree_util.tree_leaves(lopt_weights))
print("Total params:", num_params)
shapes
Total params: 2409
{'adafactor_decays': (3,),
 'momentum_decays': (3,),
 'nn': {'~': {'b0': (32,),
   'b1': (32,),
   'b2': (2,),
   'w0': (39, 32),
   'w1': (32, 32),
   'w2': (32, 2)}},
 'rms_decays': (1,)}
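One pitfall when counting parameters in a JAX pytree: shape tuples are themselves pytree nodes, so flattening a tree of shapes yields individual dimensions rather than one entry per array. The small example below, with made-up arrays, contrasts summing dimensions with summing array sizes.

```python
import jax
import jax.numpy as jnp

weights = {"w0": jnp.zeros((39, 32)), "b0": jnp.zeros((32,))}

# Counting via shapes: each tuple is flattened into separate ints,
# so this adds up dimensions (32 + 39 + 32), not element counts.
shapes = jax.tree_util.tree_map(lambda x: x.shape, weights)
wrong = sum(jax.tree_util.tree_leaves(shapes))

# Counting via the arrays themselves: each leaf contributes x.size elements.
right = sum(x.size for x in jax.tree_util.tree_leaves(weights))

print(wrong, right)
```

Here `wrong` is 103 while the true count `right` is 39 * 32 + 32 = 1280.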

We can use these weights to construct an instance of the optimizer as follows.

opt = lopt.opt_fn(lopt_weights)

This optimizer can then be used like before, with two differences: `init` is also told how many steps training will run for (`num_steps`), and `update` takes an additional `loss` argument.

opt_state = opt.init(params, num_steps=10)
grads = jnp.ones([3])
loss = 1.0
new_opt_state = opt.update(opt_state, loss=loss, grad=grads)
new_params = opt.get_params(new_opt_state)

At this point, however, the lopt_weights are initialized completely randomly! This will not make for a very good optimizer without training them.


Before training our learned optimizer, let’s first run some baselines.

Our goal is to train this little convnet faster than hand-designed optimizers can. As such, we will compare against training this convnet task with Adam, sweeping over nine learning rates and, for each learning rate, five different random initializations.

We will use the `learned_optimization.eval_training` module to make this easier, in particular the `single_task_training_curves` function, which iterates for `num_steps` steps, each step computing a gradient and applying the provided optimizer. In addition to training, this function periodically evaluates the model’s performance (every `eval_every` steps).
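To make the behavior of such a helper concrete, here is a hypothetical, much-simplified analogue written for a deterministic loss function. It is not `single_task_training_curves` itself: it evaluates the training loss directly instead of averaging over separate evaluation batches, and `update_fn` is an assumed plain `(params, grads) -> params` callable. The returned dictionary reuses the `"eval/xs"` and `"eval/train/loss"` keys that this post’s plotting code reads.

```python
import jax
import jax.numpy as jnp


def training_curve(loss_fn, params, update_fn, num_steps, eval_every):
  """Hypothetical sketch: train for num_steps, logging the loss periodically.

  Records (step, loss) every eval_every steps, plus once after the last update.
  """
  grad_fn = jax.jit(jax.value_and_grad(loss_fn))
  xs, losses = [], []
  for step in range(num_steps):
    loss, grads = grad_fn(params)
    if step % eval_every == 0:
      xs.append(step)
      losses.append(float(loss))
    params = update_fn(params, grads)
  xs.append(num_steps)
  losses.append(float(loss_fn(params)))
  return {"eval/xs": xs, "eval/train/loss": losses}


quadratic = lambda p: jnp.sum((p - 3.0) ** 2)  # minimum at p = 3
sgd = lambda p, g: p - 0.1 * g
curve = training_curve(quadratic, jnp.zeros(2), sgd, num_steps=50, eval_every=10)
print(curve["eval/xs"], curve["eval/train/loss"])
```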

key = jax.random.PRNGKey(0)
curves_for_lr = {}
for lr in [1e-4, 3e-4, 1e-3, 3e-3, 5e-3, 7e-3, 1e-2, 2e-2, 3e-2]:
  opt = opt_base.Adam(lr)
  curves_for_lr[lr] = []
  for s in range(5):
    key1, key = jax.random.split(key)
    curves = eval_training.single_task_training_curves(task, opt,
                                                      num_steps=200, key=key1, eval_every=5,
                                                      eval_batches=10, last_eval_batches=30)
    curves_for_lr[lr].append(curves)

We can now plot the results. On the left, we see learning curves for each learning rate. On the right, we can see the average performance achieved (orange) and the performance at the end of training (blue) as a function of learning rate.

fig, axs = plt.subplots(1,2, figsize=(15, 5))
for lr, curves in curves_for_lr.items():
  x = curves[0]["eval/xs"]
  y = onp.mean([c["eval/train/loss"] for c in curves], axis=0)
  axs[0].plot(x, y)
axs[0].set_xlabel("training iteration")
axs[0].set_ylabel("training loss")

xs = []
ys = []
ys_mean = []
for lr, curves in curves_for_lr.items():
  last_value = onp.mean([c["eval/train/loss"] for c in curves], axis=0)[-1]
  mean_value = onp.mean([c["eval/train/loss"] for c in curves])
  xs.append(lr)
  ys.append(last_value)
  ys_mean.append(mean_value)
axs[1].semilogx(xs, ys, "o-", label="last loss")
axs[1].semilogx(xs, ys_mean, "o-", label="mean loss")
axs[1].set_xlabel("learning rate")
axs[1].set_ylabel("training loss")
axs[1].legend()

From this, it looks like a learning rate of ~2e-3 is roughly the best we can do, reaching a minimum loss value of around 1.75.

Training the learned optimizer

Training a learned optimizer entails training the inner problem (our small convnet) over and over again. In each iteration, we estimate a “meta-gradient”: a direction in which to move the weights of the learned optimizer to improve its ability to optimize this task. As with standard gradient-based training, we then take a small step in this direction and repeat.
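Here is a toy version of that loop, small enough to run on a CPU. Everything in it is a hypothetical illustration: the “learned optimizer” is just SGD with a single learnable log step size, the inner problem is a fixed quadratic, and the meta-gradient comes from backpropagating through a 20-step unroll, one of the simplest ways to estimate it. Each outer iteration takes a small step along the negative meta-gradient.

```python
import jax
import jax.numpy as jnp

# Hypothetical inner problem: a fixed quadratic 0.5 * sum(a * w**2).
curvatures = jnp.array([1.0, 0.5, 0.1])


def inner_loss(w):
  return 0.5 * jnp.sum(curvatures * w ** 2)


def meta_loss(log_lr, w0, num_inner_steps=20):
  """Average inner loss over an SGD unroll whose step size is exp(log_lr)."""
  lr = jnp.exp(log_lr)

  def step(w, _):
    w = w - lr * jax.grad(inner_loss)(w)
    return w, inner_loss(w)

  _, losses = jax.lax.scan(step, w0, None, length=num_inner_steps)
  return jnp.mean(losses)


# Meta-gradient by backpropagating through the whole unroll.
meta_grad_fn = jax.jit(jax.value_and_grad(meta_loss))

w0 = jnp.ones(3)
log_lr = jnp.log(0.01)  # deliberately small initial step size
outer_lr = 0.3
for outer_step in range(200):
  loss, meta_grad = meta_grad_fn(log_lr, w0)
  log_lr = log_lr - outer_lr * meta_grad  # small step along -meta-gradient

print("learned step size:", float(jnp.exp(log_lr)))
```

The meta-parameter here is a single scalar, but the same outer loop applies unchanged when it is the full weight pytree of a learned optimizer.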

learned_optimization supports a number of different ways to estimate this gradient, ranging from finite differences and computing gradients with backprop to more sophisticated techniques such as Persistent Evolution Strategies (PES).

For this example, we will make use of the PES gradient estimator as it has been demonstrated to work well for training learned optimizers. PES works by trying to improve the average loss the learned optimizer obtains over the course of training.
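A minimal sketch of the antithetic form of PES, on the same kind of toy problem (SGD with a learnable log step size on a quadratic), may help show what “persistent” means. This is not `learned_optimization`’s `TruncatedPES`; all names are hypothetical. Two particles run with the meta-parameter perturbed by +ε and -ε, each keeps its own inner state across truncations, and the perturbations are summed into `xi` so that later truncations still credit earlier perturbations. The estimate is `xi * (loss_pos - loss_neg) / (2 * sigma**2)`.

```python
import jax
import jax.numpy as jnp

curvatures = jnp.array([1.0, 0.5, 0.1])  # hypothetical quadratic inner problem


def inner_loss(w):
  return 0.5 * jnp.sum(curvatures * w ** 2)


def unroll(log_lr, w, num_steps):
  """Runs num_steps of SGD with step size exp(log_lr); returns (w, mean loss)."""
  lr = jnp.exp(log_lr)

  def step(w, _):
    w = w - lr * jax.grad(inner_loss)(w)
    return w, inner_loss(w)

  w, losses = jax.lax.scan(step, w, None, length=num_steps)
  return w, jnp.mean(losses)


def pes_step(key, theta, state, sigma=0.1, trunc_length=5):
  """One antithetic PES estimate for a scalar meta-parameter theta.

  state = (w_pos, w_neg, xi): the persistent inner states of the +/- particles
  and the sum of all perturbations applied since the inner problem started.
  """
  w_pos, w_neg, xi = state
  eps = sigma * jax.random.normal(key, ())
  xi = xi + eps  # accumulate perturbations across truncations
  w_pos, loss_pos = unroll(theta + eps, w_pos, trunc_length)
  w_neg, loss_neg = unroll(theta - eps, w_neg, trunc_length)
  grad = xi * (loss_pos - loss_neg) / (2 * sigma ** 2)
  return grad, (w_pos, w_neg, xi)


key = jax.random.PRNGKey(0)
theta = jnp.log(0.05)
w_init = jnp.ones(3)
state = (w_init, w_init, jnp.zeros(()))
for _ in range(4):  # 4 truncations x 5 steps = one 20-step inner problem
  key, subkey = jax.random.split(key)
  g, state = pes_step(subkey, theta, state)
  theta = theta - 0.1 * g  # outer update after every truncation
```

When an inner problem finishes, both particles restart from a fresh initialization and `xi` is reset to zero. A single estimate is noisy; in practice it is averaged over many particles, which is one reason to train many inner problems in parallel.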

learned_optimization’s gradient estimators work on objects called TruncatedStep. These encapsulate all the details related to learned optimizers and expose a simple interface so that the same gradient estimators can be used for different kinds of meta-learned systems — not just learned optimizers.

For now, though, we will construct this `TruncatedStep` object for learned optimizers. In this step, we specify the truncation schedule, which sets how many iterations each inner problem runs for. We ran our baselines for 200 iterations, so we will use the same length here.

Training the learned optimizer is expensive. To make the computation run faster, we make use of vectorization: we use our learned optimizer to train multiple convnets at the same time, leveraging accelerator hardware. We specify the number of convnets with the `num_tasks` argument.
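The general JAX mechanism for this kind of batching is `jax.vmap`, which turns a function written for one inner problem into one that processes a whole batch of them in a single call. The sketch below is a hypothetical illustration with a trivial inner problem (each task pulls its own weights toward a random target), not the code inside `VectorizedLOptTruncatedStep`.

```python
import jax
import jax.numpy as jnp

num_tasks = 32


def init_task(key):
  # Hypothetical inner problem: each task has its own random target.
  return {"w": jnp.zeros(4), "target": jax.random.normal(key, (4,))}


def sgd_step(state, lr):
  grad = 2 * (state["w"] - state["target"])  # gradient of ||w - target||^2
  return {"w": state["w"] - lr * grad, "target": state["target"]}


keys = jax.random.split(jax.random.PRNGKey(0), num_tasks)
states = jax.vmap(init_task)(keys)  # every leaf gains a leading axis of size 32
# Map over the task axis of the state; the learning rate is shared (None).
batched_step = jax.jit(jax.vmap(sgd_step, in_axes=(0, None)))
for _ in range(10):
  states = batched_step(states, 0.1)
print(states["w"].shape)
```

All 32 tasks advance together in one compiled call, which is what makes running many inner problems per meta-gradient estimate affordable on an accelerator.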

max_length = 200
trunc_sched = truncation_schedule.ConstantTruncationSchedule(max_length)
truncated_step = lopt_truncated_step.VectorizedLOptTruncatedStep(
    tasks_base.single_task_to_family(task), lopt, trunc_sched, num_tasks=32, random_initial_iteration_offset=max_length)

Now we can construct the gradient estimator.

grad_estimator = truncated_pes.TruncatedPES(truncated_step=truncated_step,
                                            trunc_length=20)  # inner steps per unroll

Next, we specify how to use these gradients to update the weights of the learned optimizer. For this, we will use the Adam optimizer with clipped gradients. This combination has proven successful in the past, though the learning rate often needs to be tuned.

outer_learning_rate = 3e-3
theta_opt = opt_base.GradientClipOptimizer(opt_base.Adam(outer_learning_rate))
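Clipping helps because a learned optimizer with bad weights can make an inner run diverge, which can produce enormous meta-gradient estimates; clipping bounds the effect of any single bad estimate. Below is a minimal sketch of clipped Adam in plain JAX. It assumes elementwise value clipping to `[-clip, clip]` with an arbitrary `clip=3.0`; the exact rule and threshold used by `opt_base.GradientClipOptimizer` may differ, so treat the function names and defaults as hypothetical.

```python
import jax
import jax.numpy as jnp


def clipped_adam_init(params):
  zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
  return {"m": zeros, "v": zeros, "t": 0}


def clipped_adam_update(params, grads, state, lr=3e-3, clip=3.0,
                        b1=0.9, b2=0.999, eps=1e-8):
  """Clip each gradient entry to [-clip, clip], then take a standard Adam step."""
  grads = jax.tree_util.tree_map(lambda g: jnp.clip(g, -clip, clip), grads)
  t = state["t"] + 1
  m = jax.tree_util.tree_map(lambda m, g: b1 * m + (1 - b1) * g, state["m"], grads)
  v = jax.tree_util.tree_map(lambda v, g: b2 * v + (1 - b2) * g ** 2, state["v"], grads)

  def step(p, m, v):
    m_hat = m / (1 - b1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - b2 ** t)  # bias-corrected second moment
    return p - lr * m_hat / (jnp.sqrt(v_hat) + eps)

  new_params = jax.tree_util.tree_map(step, params, m, v)
  return new_params, {"m": m, "v": v, "t": t}


theta = {"w": jnp.array([0.0, 0.0])}
state = clipped_adam_init(theta)
huge_grad = {"w": jnp.array([1e6, -1e-3])}  # one exploding entry, one tiny one
theta, state = clipped_adam_update(theta, huge_grad, state)
print(theta["w"])
```

On the first step, Adam moves each entry by roughly `lr` in the direction opposite its gradient’s sign, so the exploding entry and the tiny one move by about the same amount.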

The `SingleMachineGradientLearner` class then combines the learned optimizer, this gradient estimator, and the outer optimizer (clipped Adam).

gradient_estimators = [grad_estimator]
outer_trainer = gradient_learner.SingleMachineGradientLearner(
    lopt, gradient_estimators, theta_opt)

Now we can train the weights of the learned optimizer! First, we initialize the state of the `outer_trainer`.

# Initialize weights of learned optimizer + state of workers.
key = jax.random.PRNGKey(int(onp.random.randint(0, int(2**30))))
outer_trainer_state = outer_trainer.init(key)
all_losses = []
losses = []

Then we iterate to train the weights of the learned optimizer. Each step of `outer_trainer.update` does one unroll (of length 20) and computes a gradient estimate for each different inner-problem instance, averages the meta-gradients, and applies Adam to update the weights of the learned optimizer.

For the sake of this post, and to make things faster to run, we only meta-train for 1000 iterations. This should take ~10 min on a good accelerator — I ran this on a single chip of a TPUv3. This is enough to outperform the baselines (as we will see) but using more compute pretty much always improves performance.
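The schedule arithmetic is worth spelling out: with 200-step inner problems and 20-step truncations, each inner problem spans 10 outer updates before a new one starts, so over 1000 outer iterations each task works through about 100 inner problems. The sketch below simulates only this bookkeeping for 32 tasks in plain Python. It assumes, based on its name, that `random_initial_iteration_offset` starts each task at a random iteration so that tasks finish at different times rather than all restarting together; the variable names are hypothetical.

```python
import random

max_length, trunc_length, num_tasks = 200, 20, 32
rng = random.Random(0)
# Assumed behavior: each task starts at a random inner iteration in [0, max_length).
iterations = [rng.randrange(max_length) for _ in range(num_tasks)]

resets_per_update = []
for outer_step in range(30):
  resets = 0
  for i in range(num_tasks):
    iterations[i] += trunc_length  # one truncation per outer update
    if iterations[i] >= max_length:  # inner problem finished: start a new one
      iterations[i] = 0
      resets += 1
  resets_per_update.append(resets)

print(resets_per_update)
```

Every task restarts exactly once per 10 updates, but because the start points are staggered, only a fraction of the 32 tasks restarts at any given update.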

outer_iterations = 1000
for i in tqdm.trange(outer_iterations):
  key1, key = jax.random.split(key)
  outer_trainer_state, loss, metrics = outer_trainer.update(
      outer_trainer_state, key1)
  losses.append(loss)
  if i % 50 == 0:
    all_losses.append(onp.mean(losses))
    losses = []

Let’s see how we did! The following plot shows outer-iteration (each iteration in which we update the learned optimizer’s weights) versus outer-loss (our measure of how well the learned optimizer is optimizing).

Losses going down — great!

plt.plot(onp.arange(len(all_losses))*50, all_losses,  "o-")
plt.xlabel("outer-iteration updates")
plt.ylabel("average loss of convnet inner-problem (loss from PES)")
plt.ylim(1.7, 2.3)


Evaluating the trained model

The plot above shows losses averaged across every convnet being trained. This is a somewhat abstract measurement; what we really want to see is that the learned optimizer trains faster than the baselines. To check this, we will evaluate our optimizer with the same `single_task_training_curves` function we used for our baselines.

To do this, we first need to construct an optimizer instance from the learned optimizer weights found during meta-training.

theta = outer_trainer.get_meta_params(outer_trainer_state)
opt = lopt.opt_fn(theta)

And then we can run the trainer.

key = jax.random.PRNGKey(1)
lopt_curves = []
for i in range(5):
  key1, key = jax.random.split(key)
  lopt_curves.append(eval_training.single_task_training_curves(task, opt,
                                                          num_steps=200, key=key1, eval_every=5,
                                                          eval_batches=20, last_eval_batches=30))

Finally, we can plot the result, with the learned optimizer in black and the Adam baselines in color. We see that our learned optimizer is faster than the baselines and reaches a minimum.

fig, ax = plt.subplots(1,1, figsize=(8, 5))
for lr, curves in curves_for_lr.items():
  x = curves[0]["eval/xs"]
  y = onp.mean([c["eval/train/loss"] for c in curves], axis=0)
  ax.plot(x, y)
x = lopt_curves[0]["eval/xs"]
y = onp.mean([c["eval/train/loss"] for c in lopt_curves], axis=0)
ax.plot(x,y, color="k")
ax.set_xlabel("training iteration")
ax.set_ylabel("training loss")


I hope this post gives a brief preview of how to train a learned optimizer!

What we show here is small in scale, small enough to run inside a Colab notebook (this same post is available in notebook form; for faster training, be sure to change the runtime type to use a GPU or TPU, or, faster still, use a GCP instance).

As the amount of compute in the world grows, I am excited to see what learned optimizers will enable. My research agenda is to work towards more general-purpose learned optimizers by training them on a wide variety of tasks. We have published some results[1][2][3], but work is ongoing! If this tutorial has piqued your interest, give this a try and train your own learned optimizer! Also, check out my talk at the upcoming ODSC East conference, “Learned Optimizers.”

About the author/ODSC East 2022 speaker on Learned Optimizers

Luke Metz is a research scientist at Google Brain working on meta-learning and learned optimizers. He’s interested in building general-purpose, learned learning algorithms that not only perform well, but also generalize to new types of never-before-seen problems.