@amoudgl · Last active January 3, 2024
Code for our ICML23W paper "Learning to Optimize with Recurrent Hierarchical Transformers"
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
A transformer-based learned optimizer that synthesizes inter-tensor
communication with self-attention and propagates a CLS token as a hidden
state to keep track of optimization history.
This optimizer was introduced in:
https://openreview.net/forum?id=MusMaHCrs2
Acknowledgements:
* We use learned_optimization library for meta-training:
https://github.com/google/learned_optimization/
* Haiku transformer implementation:
https://github.com/google-deepmind/dm-haiku/blob/master/examples/transformer/
"""
from typing import Any, Optional, Tuple, Sequence
import dataclasses
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as onp
import warnings
import functools
import flax
import gin
from jax import lax
from learned_optimization import summary
from learned_optimization import tree_utils
from learned_optimization.learned_optimizers import base as lopt_base
from learned_optimization.learned_optimizers import common
from learned_optimization.optimizers import base as opt_base
PRNGKey = jnp.ndarray
def layer_norm(x: jax.Array) -> jax.Array:
"""Applies a unique LayerNorm to x with default settings."""
ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
return ln(x)
class MultiHeadAttention(hk.Module):
"""Multi-headed attention (MHA) module.
This module is intended for attending over sequences of vectors.
Rough sketch:
- Compute keys (K), queries (Q), and values (V) as projections of inputs.
- Attention weights are computed as W = softmax(QK^T / sqrt(key_size)).
- Output is another projection of WV^T.
For more detail, see the original Transformer paper:
"Attention is all you need" https://arxiv.org/abs/1706.03762.
Glossary of shapes:
- T: Sequence length.
- D: Vector (embedding) size.
- H: Number of attention heads.
"""
def __init__(
self,
num_heads: int,
key_size: int,
w_init_scale: Optional[float] = None,
*,
w_init: Optional[hk.initializers.Initializer] = None,
with_bias: bool = True,
b_init: Optional[hk.initializers.Initializer] = None,
value_size: Optional[int] = None,
model_size: Optional[int] = None,
name: Optional[str] = None,
):
"""Initialises the module.
Args:
num_heads: Number of independent attention heads (H).
key_size: The size of keys (K) and queries used for attention.
w_init_scale: DEPRECATED. Please use w_init instead.
w_init: Initialiser for weights in the linear map. Once `w_init_scale` is
fully deprecated `w_init` will become mandatory. Until then it has a
default value of `None` for backwards compatibility.
with_bias: Whether to add a bias when computing various linear
projections.
b_init: Optional initializer for bias. By default, zero.
value_size: Optional size of the value projection (V). If None, defaults
to the key size (K).
model_size: Optional size of the output embedding (D'). If None, defaults
to the key size multiplied by the number of heads (K * H).
name: Optional name for this module.
"""
super().__init__(name=name)
self.num_heads = num_heads
self.key_size = key_size
self.value_size = value_size or key_size
self.model_size = model_size or key_size * num_heads
# Backwards-compatibility for w_init_scale.
if w_init_scale is not None:
warnings.warn(
"w_init_scale is deprecated; please pass an explicit weight "
"initialiser instead.",
DeprecationWarning,
)
if w_init and w_init_scale:
raise ValueError("Please provide only `w_init`, not `w_init_scale`.")
if w_init is None and w_init_scale is None:
raise ValueError(
"Please provide a weight initializer: `w_init`. "
"`w_init` will become mandatory once `w_init_scale` is "
"fully deprecated."
)
if w_init is None:
w_init = hk.initializers.VarianceScaling(w_init_scale)
self.w_init = w_init
self.with_bias = with_bias
self.b_init = b_init
def __call__(
self,
query: jax.Array,
key: jax.Array,
value: jax.Array,
mask: Optional[jax.Array] = None,
) -> jax.Array:
"""Computes (optionally masked) MHA with queries, keys & values.
This module broadcasts over zero or more 'batch-like' leading dimensions.
Args:
query: Embeddings sequence used to compute queries; shape [..., T', D_q].
key: Embeddings sequence used to compute keys; shape [..., T, D_k].
value: Embeddings sequence used to compute values; shape [..., T, D_v].
mask: Optional mask applied to attention weights; shape [..., H=1, T', T].
Returns:
A new sequence of embeddings, consisting of a projection of the
attention-weighted value projections; shape [..., T', D'].
"""
# In shape hints below, we suppress the leading dims [...] for brevity.
# Hence e.g. [A, B] should be read in every case as [..., A, B].
*leading_dims, sequence_length, _ = query.shape
projection = self._linear_projection
# Compute key/query/values (overload K/Q/V to denote the respective sizes).
query_heads = projection(query, self.key_size, "query") # [T', H, Q=K]
key_heads = projection(key, self.key_size, "key") # [T, H, K]
value_heads = projection(value, self.value_size, "value") # [T, H, V]
# Compute attention weights.
attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads)
attn_logits = attn_logits / onp.sqrt(self.key_size).astype(key.dtype)
if mask is not None:
if mask.ndim != attn_logits.ndim:
raise ValueError(
f"Mask dimensionality {mask.ndim} must match logits dimensionality "
f"{attn_logits.ndim}."
)
attn_logits = jnp.where(mask, attn_logits, -1e30)
attn_weights = jax.nn.softmax(attn_logits) # [H, T', T]
# Weight the values by the attention and flatten the head vectors.
attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1)) # [T', H*V]
# Apply another projection to get the final embeddings.
final_projection = hk.Linear(
self.model_size,
w_init=self.w_init,
with_bias=self.with_bias,
b_init=self.b_init,
)
return final_projection(attn) # [T', D']
@hk.transparent
def _linear_projection(
self,
x: jax.Array,
head_size: int,
name: Optional[str] = None,
) -> jax.Array:
y = hk.Linear(
self.num_heads * head_size,
w_init=self.w_init,
with_bias=self.with_bias,
b_init=self.b_init,
name=name,
)(x)
*leading_dims, _ = x.shape
return y.reshape((*leading_dims, self.num_heads, head_size))
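# Illustrative sketch (not part of the original gist): exercising the MHA
# module on its own under hk.transform. The helper name `_example_mha` and the
# toy shapes are ours; T=5 tokens with D=32 features, following the glossary
# in the class docstring.
def _example_mha():
    def fwd(x):
        attn = MultiHeadAttention(
            num_heads=4,
            key_size=8,
            model_size=32,
            w_init=hk.initializers.VarianceScaling(1.0),
        )
        return attn(x, x, x)  # self-attention: queries, keys, values all from x
    fn = hk.without_apply_rng(hk.transform(fwd))
    x = jnp.ones([5, 32])  # [T, D]
    params = fn.init(jax.random.PRNGKey(0), x)
    return fn.apply(params, x)  # [T, model_size] == [5, 32]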
@dataclasses.dataclass
class Transformer(hk.Module):
"""A transformer stack, adapted from DM Haiku example implementation.
NOTE: Dropout is disabled in this model; `dropout_rate` is currently
just a dummy parameter.
"""
num_heads: int
num_layers: int
key_size: int
dropout_rate: float
widening_factor: int = 4
name: Optional[str] = None
def __call__(
self,
embeddings: jax.Array, # [T, D]
mask: jax.Array, # [T]
*,
is_training: bool = True,
) -> jax.Array: # [T, D]
"""Transforms input embedding sequences to output embedding sequences."""
initializer = hk.initializers.VarianceScaling(2 / self.num_layers)
# Dropout is disabled in the optimizer
# dropout_rate = self.dropout_rate if is_training else 0.0
seq_len, model_size = embeddings.shape
# Compute bidirectional mask
mask = mask[None, None, :] # [H=1, T'=1, T]
bidirectional_mask = onp.ones((1, seq_len, seq_len))
mask = mask * bidirectional_mask # [H=1, T, T]
h = embeddings
for _ in range(self.num_layers):
# First the attention block.
attn_block = MultiHeadAttention(
num_heads=self.num_heads,
key_size=self.key_size,
model_size=model_size,
w_init=initializer,
)
h_norm = layer_norm(h)
h_attn = attn_block(h_norm, h_norm, h_norm, mask=mask)
# h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
h = h + h_attn
# Then the dense block.
dense_block = hk.Sequential(
[
hk.Linear(self.widening_factor * model_size, w_init=initializer),
jax.nn.gelu,
hk.Linear(model_size, w_init=initializer),
]
)
h_norm = layer_norm(h)
h_dense = dense_block(h_norm)
# h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
h = h + h_dense
return layer_norm(h)
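# Illustrative sketch (not part of the original gist): running the transformer
# stack on an unbatched sequence of token embeddings with a fully visible
# (bidirectional) mask, as the optimizer does. `_example_transformer` and the
# toy shapes are ours.
def _example_transformer():
    def fwd(tokens):
        tx = Transformer(num_heads=2, num_layers=2, key_size=16, dropout_rate=0.0)
        mask = jnp.ones(tokens.shape[:-1])  # attend to every token
        return tx(tokens, mask)
    fn = hk.without_apply_rng(hk.transform(fwd))
    tokens = jnp.zeros([7, 48])  # [T, D]
    params = fn.init(jax.random.PRNGKey(0), tokens)
    return fn.apply(params, tokens)  # [T, D] == [7, 48]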
@dataclasses.dataclass
class EncoderModel(hk.Module):
"""A transformer encoder model."""
transformer: Transformer
model_size: int
name: Optional[str] = None
def __call__(
self,
tokens: jax.Array,
hidden_state: jax.Array,
*,
is_training: bool = True,
) -> Tuple[jax.Array, jax.Array]:
"""Forward pass of transformer.
N: number of tensors in the optimizee neural net.
D: size of a tensor feature vector.
H: hidden size of the transformer, a.k.a. `model_size`.
Args:
tokens (jax.Array): Tensor features with shape (N, D).
hidden_state (jax.Array): Hidden state vector with shape (H).
is_training (bool, optional): Training mode. Defaults to True.
Returns:
Tuple[jax.Array, jax.Array]: Transformed tensor embeddings with shape
(N, H) and the updated hidden state with shape (H).
"""
# Embed the input tokens and positions.
embed_token = hk.Linear(self.model_size)
token_embeddings = embed_token(tokens)
# We just add a 'token type' positional embedding to separate the hidden
# state token from the tensor tokens
embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
positional_embeddings = hk.get_parameter(
"positional_embeddings", [2, self.model_size], init=embed_init
)
token_embeddings = token_embeddings + positional_embeddings[1] # [N, H]
hidden_embedding = hidden_state[None, :] + positional_embeddings[0] # [1, H]
input_embeddings = jnp.concatenate([hidden_embedding, token_embeddings], axis=0) # [N+1, H]
input_mask = jnp.ones(input_embeddings.shape[:-1])
# Run the transformer over the inputs.
output_embeddings = self.transformer(
input_embeddings,
input_mask,
is_training=is_training,
) # [N+1, H]
hidden_state = output_embeddings[0]
embeddings = output_embeddings[1:]
return embeddings, hidden_state
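# Illustrative sketch (not part of the original gist): the hidden state acts
# like a CLS token that is prepended to the N tensor tokens, transformed, and
# then carried over to the next optimizer step. `_example_encoder_recurrence`
# and the toy shapes are ours.
def _example_encoder_recurrence():
    model_size = 64
    def fwd(tokens, hidden_state):
        tx = Transformer(num_heads=4, num_layers=2, key_size=16, dropout_rate=0.0)
        return EncoderModel(transformer=tx, model_size=model_size)(tokens, hidden_state)
    fn = hk.without_apply_rng(hk.transform(fwd))
    tokens = jnp.zeros([10, 34])  # N=10 tensors, D=34 features each
    h0 = jnp.zeros([model_size])  # initial hidden state ("CLS" token)
    params = fn.init(jax.random.PRNGKey(0), tokens, h0)
    emb, h1 = fn.apply(params, tokens, h0)  # emb: [10, 64], h1: [64]
    emb, h2 = fn.apply(params, tokens, h1)  # hidden state carried across steps
    return emb, h2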
def _second_moment_normalizer(x, axis, eps=1e-5):
return x * lax.rsqrt(eps + jnp.mean(jnp.square(x), axis=axis, keepdims=True))
def _sin_embedding(iteration: jnp.ndarray) -> jnp.ndarray:
"""Embed the inner-training iteration with sin of various frequency."""
def one_freq(timescale):
return jnp.sin(iteration / (jnp.float32(timescale) * jnp.pi))
timescales = jnp.asarray(
[1, 3, 10, 30, 100, 300, 1000, 3000, 10000, 30000, 100000], dtype=jnp.float32
)
return jax.vmap(one_freq)(timescales)
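# Illustrative sketch (not part of the original gist): the scalar inner-step
# counter becomes an 11-dimensional feature of sinusoids at log-spaced
# timescales (1 to 1e5), giving downstream networks a multi-resolution view of
# training progress. `_example_step_embedding` is our name.
def _example_step_embedding():
    return _sin_embedding(jnp.asarray(250.0))  # shape (11,)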
@flax.struct.dataclass
class _LossNormalizerState:
mean: jnp.ndarray
var: jnp.ndarray
updates: jnp.ndarray
class _LossNormalizer:
"""Tracks loss through time and normalizes to a similar range across tasks."""
def __init__(self, decay: float):
self.decay = decay
def init(self) -> _LossNormalizerState:
return _LossNormalizerState(
mean=jnp.asarray(0.0), var=jnp.asarray(0.0), updates=jnp.int32(0)
)
def next_state(self, state: _LossNormalizerState, loss: jnp.ndarray) -> _LossNormalizerState:
new_mean = self.decay * state.mean + (1.0 - self.decay) * loss
new_var = self.decay * state.var + (1.0 - self.decay) * jnp.square(new_mean - loss)
new_updates = state.updates + 1
return _LossNormalizerState(mean=new_mean, var=new_var, updates=new_updates)
def weight_loss(self, state: _LossNormalizerState, loss: jnp.ndarray) -> jnp.ndarray:
c = 1.0 / (1 - self.decay ** jnp.asarray(state.updates, jnp.float32) + 1e-8)
cor_mean = state.mean * c
cor_var = state.var * c
l = (loss - cor_mean) * lax.rsqrt(cor_var + 1e-8)
return jnp.clip(l, -5, 5)
def corrected_mean(self, state: _LossNormalizerState) -> jnp.ndarray:
c = 1.0 / (1 - self.decay ** jnp.asarray(state.updates, jnp.float32) + 1e-7)
return state.mean * c
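# Illustrative sketch (not part of the original gist): the normalizer keeps an
# exponential moving average of the loss and its variance, applies a bias
# correction (similar to Adam's), and whitens new losses to roughly unit
# scale, clipped to [-5, 5]. `_example_loss_normalizer` and the toy losses
# are ours.
def _example_loss_normalizer():
    norm = _LossNormalizer(decay=0.9)
    state = norm.init()
    for loss in [3.0, 2.5, 2.0]:
        state = norm.next_state(state, jnp.asarray(loss))
    feat = norm.weight_loss(state, jnp.asarray(1.8))  # normalized loss feature
    return feat, norm.corrected_mean(state)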
def _avg_square_mean(tree: Any) -> jnp.ndarray:
return sum([jnp.mean(jnp.square(x)) for x in jax.tree_util.tree_leaves(tree)]) / len(
jax.tree_util.tree_leaves(tree)
)
def _clip_log_abs(value: jnp.ndarray) -> jnp.ndarray:
mag = jnp.log(1e-8 + jnp.abs(value))
return jnp.clip(mag, -5, 5)
def _sorted_values(dd):
return list(zip(*sorted(dd.items(), key=lambda x: x[0])))[1]
def _unstack(a: jnp.ndarray, axis: int = 0) -> Sequence[jnp.ndarray]:
"""The opposite of jnp.stack()."""
shape = a.shape
return [jnp.squeeze(b, axis=axis) for b in jnp.split(a, shape[axis], axis=axis)]
@flax.struct.dataclass
class _DynamicGradientClipperState:
iteration: jnp.ndarray
value: jnp.ndarray
class _DynamicGradientClipper:
"""Keep track of gradient norms and clip gradients to reasonable range."""
def __init__(self, alpha: float = 0.99, clip_mult: float = 10.0):
self.alpha = alpha
self.clip_mult = clip_mult
def initial_state(self) -> _DynamicGradientClipperState:
return _DynamicGradientClipperState(
jnp.asarray(1, dtype=jnp.float32),
jnp.asarray(1.0, dtype=jnp.float32) * (1 - self.alpha),
)
def _normalize(
self, state: _DynamicGradientClipperState, grads: opt_base.Params
) -> opt_base.Params:
t, snd = state.iteration, state.value
clip_amount = (snd / (1 - self.alpha**t)) * self.clip_mult
summary.summary("dynamic_grad_clip", clip_amount)
return jax.tree_util.tree_map(lambda g: jnp.clip(g, -clip_amount, clip_amount), grads)
def next_state_and_normalize(
self, state: _DynamicGradientClipperState, grads: opt_base.Params
) -> Tuple[_DynamicGradientClipperState, opt_base.Params]:
t, snd = state.iteration, state.value
clipped_grads = self._normalize(state, grads)
avg_squared_mean = _avg_square_mean(clipped_grads)
new_snd_moment = jnp.sqrt(1e-8 + avg_squared_mean)
next_snd = snd * self.alpha + new_snd_moment * (1.0 - self.alpha)
return _DynamicGradientClipperState(t + 1, next_snd), clipped_grads
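# Illustrative sketch (not part of the original gist): gradients are clipped
# to `clip_mult` times a running estimate of their root-mean-square magnitude,
# so the clipping range adapts as gradients shrink or grow during training.
# `_example_dynamic_clipping` and the toy gradient tree are ours.
def _example_dynamic_clipping():
    clipper = _DynamicGradientClipper()
    state = clipper.initial_state()
    grads = {"w": jnp.ones([3, 3]) * 50.0, "b": jnp.ones([3]) * 0.1}
    state, clipped = clipper.next_state_and_normalize(state, grads)
    return clipped  # large entries are clipped; small ones pass through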
@flax.struct.dataclass
class LOptState:
"""State used to train a Task / inner-problem."""
params: opt_base.Params
mom_rolling: common.MomAccumulator
rms_rolling: common.RMSAccumulator
iteration: jnp.ndarray
state: Optional[opt_base.ModelState]
tx_hidden_state: Any
from_mlp: Any
train_loss_accum: Any
valid_loss_accum: _LossNormalizerState
dynamic_clip: _DynamicGradientClipperState
@gin.configurable
class TxLOpt(lopt_base.LearnedOptimizer):
"""Learned optimizer with a transformer and per param MLP.
See top level doc string for more information.
"""
def __init__(
self,
step_multiplier: float = 0.001,
magnitude_rate: float = 0.001,
hidden_size: int = 32,
hidden_layer: int = 2,
from_mlp_size: int = 16,
tx_to_ff: int = 17,
tx_hidden_size: int = 64,
num_heads: int = 4,
num_layers: int = 4,
decays: Sequence[float] = (0.5, 0.9, 0.99, 0.999, 0.9999),
):
self.step_multiplier = step_multiplier
self.magnitude_rate = magnitude_rate
self.hidden_size = hidden_size
self.hidden_layer = hidden_layer
self.from_mlp_size = from_mlp_size
self.tx_to_ff = tx_to_ff
self.tx_hidden_size = tx_hidden_size
self.decays = jnp.asarray(decays)
self.num_heads = num_heads
self.num_layers = num_layers
def _per_param_mlp_network(inp):
hiddens = [hidden_size] * hidden_layer + [2 + from_mlp_size]
return hk.nets.MLP(hiddens)(inp)
self.per_param_mlp_network = hk.without_apply_rng(hk.transform(_per_param_mlp_network))
self.tx_to_mlp_network = hk.without_apply_rng(
hk.transform(lambda x: hk.Linear(tx_to_ff, name="tx_to_ff")(x))
)
def _forward_tx(tokens, state):
tx = Transformer(
num_heads=num_heads,
num_layers=num_layers,
key_size=32,
dropout_rate=0.1,
)
encoder = EncoderModel(model_size=tx_hidden_size, transformer=tx)
return encoder(tokens, state)
self.tx_network = hk.without_apply_rng(hk.transform(_forward_tx))
def initial_state(hidden_size) -> jax.Array:
embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
hidden_embedding = hk.get_parameter("hidden_embedding", [hidden_size], init=embed_init)
return hidden_embedding
self.initial_state = initial_state
def init(self, key) -> lopt_base.MetaParams:
"""Initialization of the meta-parameters."""
key1, key2, key3, key4, key5 = jax.random.split(key, 5)
# To create the weights of the transformer, we must know the number of inputs created
# by the `features_for_tensor` function.
tensor_features = 18
tx_inp_size = tensor_features + self.from_mlp_size
# To create the weights of the MLP we must know the number of inputs created
# by the `mlp_features_per_tensor` function.
feed_forward_features = 37
mlp_inp_size = feed_forward_features + self.tx_to_ff
_, var_init = hk.transform(hk.initializers.VarianceScaling())
initial_state_fn = hk.transform(self.initial_state)
params = initial_state_fn.init(key2, self.tx_hidden_size)
tx_initial_state = initial_state_fn.apply(params, None, self.tx_hidden_size)
return {
"initial_from_mlp": var_init(None, key1, [self.from_mlp_size], dtype=jnp.float32),
"tx_init_state": tx_initial_state,
"tx_params": self.tx_network.init(key3, jnp.zeros([1, tx_inp_size]), tx_initial_state),
"tx_to_ff_params": self.tx_to_mlp_network.init(
key4, jnp.zeros([0, self.tx_hidden_size])
),
"ffmod_params": self.per_param_mlp_network.init(key5, jnp.zeros([0, mlp_inp_size])),
}
def opt_fn(self, theta: lopt_base.MetaParams, is_training: bool = False) -> opt_base.Optimizer:
vec_roll_rms = common.vec_rolling_rms(self.decays)
vec_roll_mom = common.vec_rolling_mom(self.decays)
valid_loss_normalizer = _LossNormalizer(0.95)
train_loss_normalizer = _LossNormalizer(0.9)
dynamic_gradient_clip = _DynamicGradientClipper()
parent = self
class _Opt(opt_base.Optimizer):
"""Optimizer which contains meta-parameters."""
def __init__(self, theta: lopt_base.MetaParams):
super().__init__()
self.theta = theta
def init(
self,
params: opt_base.Params,
model_state: Optional[opt_base.ModelState] = None,
num_steps: Optional[jnp.ndarray] = None,
key: Optional[PRNGKey] = None,
) -> LOptState:
# n_states: number of tensors in the optimizee net
n_states = len(jax.tree_util.tree_leaves(params))
tx_hidden_state = self.theta["tx_init_state"]
from_mlp = jax.tree_util.tree_map(lambda x: self.theta["initial_from_mlp"], params)
return LOptState(
params=params,
mom_rolling=vec_roll_mom.init(params),
rms_rolling=vec_roll_rms.init(params),
iteration=jnp.asarray(0, dtype=jnp.int32),
state=model_state,
tx_hidden_state=tx_hidden_state,
from_mlp=from_mlp,
train_loss_accum=train_loss_normalizer.init(),
valid_loss_accum=valid_loss_normalizer.init(),
dynamic_clip=dynamic_gradient_clip.initial_state(),
)
def features_for_tensor(
self,
ms: jnp.ndarray,
rms: jnp.ndarray,
g: jnp.ndarray,
v: jnp.ndarray,
from_mlp: jnp.ndarray,
train_loss_feat: jnp.ndarray,
valid_loss_feat: jnp.ndarray,
) -> Sequence[jnp.ndarray]:
"""Compute per-tensor features.
This function is called once per tensor.
Args:
ms: momentum accumulators
rms: second moment accumulators
g: gradient value
v: parameter value
from_mlp: conditioning value sent from per-param mlp.
train_loss_feat: Array which contains featurized train loss
valid_loss_feat: Array which contains featurized valid loss
Returns:
A list of features. Each feature is a vector.
"""
inputs = {}
mean_ms = jnp.mean(ms)
inputs["mean_ms_mag"] = _clip_log_abs(mean_ms)
inputs["mean_ms_sign"] = jnp.sign(mean_ms)
var_ms = jnp.mean(jnp.square(ms - mean_ms))
inputs["var_ms"] = _clip_log_abs(var_ms)
mean_rms = jnp.mean(rms)
inputs["mean_rms"] = _clip_log_abs(mean_rms)
inputs["mean_sign"] = jnp.sign(mean_rms)
var_rms = jnp.mean(jnp.square(rms - mean_rms))
inputs["var_rms"] = _clip_log_abs(var_rms)
mean_v = jnp.mean(v)
inputs["mean_v_mag"] = _clip_log_abs(mean_v)
inputs["mean_v_sign"] = jnp.sign(mean_v)
var_v = jnp.mean(jnp.square(v - mean_v))
inputs["var_v"] = _clip_log_abs(var_v)
inputs["norm_weight"] = _clip_log_abs(jnp.linalg.norm(v))
g_norm = jnp.linalg.norm(g)
inputs["g_norm"] = _clip_log_abs(g_norm)
inputs["is_scalar"] = jnp.asarray(
1.0 if len(v.shape) == 0 else -1.0
) # pylint: disable=g-explicit-length-test
extra_dims = [1.0] * (4 - len(v.shape))
shape_stack = jnp.concatenate(
[onp.asarray(v.shape, jnp.float32), jnp.asarray(extra_dims)], axis=0
)
for j in range(4):
# Shift so that these are closer to zero mean.
inputs["shape_%d" % j] = jnp.log(shape_stack)[j] - 1.0
# Features from training loss
inputs["train_loss_feat"] = train_loss_feat
inputs["valid_loss_feat"] = valid_loss_feat
# Features from lower level MLP
inputs["from_mlp"] = from_mlp
values = _sorted_values(inputs)
reshaped = [
jnp.expand_dims(v, 0) if len(v.shape) == 0 else v
for v in values # pylint: disable=g-explicit-length-test
]
return reshaped
def mlp_features_for_tensor(
self,
m: jnp.ndarray,
rms: jnp.ndarray,
g: jnp.ndarray,
v: jnp.ndarray,
ff_inputs: jnp.ndarray,
training_step: jnp.ndarray,
num_tensors: jnp.ndarray,
) -> jnp.ndarray:
flat_g = jnp.reshape(g, [-1, 1])
# These have a trailing dim of decays. We want to reshape them so that
# they have the leading dimensions flattened.
rms = jnp.reshape(rms, [int(onp.prod(rms.shape[0:-1])), rms.shape[-1]])
m = jnp.reshape(m, [int(onp.prod(m.shape[0:-1])), m.shape[-1]])
rsqrt = lax.rsqrt(rms + 1e-6)
rms_scaled_g = m * rsqrt
flat_v = jnp.reshape(v, [-1, 1])
# Per component features
inps = {}
inps["flat_g"] = flat_g
inps["flat_v"] = flat_v
inps["log_abs_v"] = jnp.log(jnp.abs(flat_v) + 1e-8)
inps["m"] = m
inps["rms_scaled_g"] = rms_scaled_g
inps["rms"] = rms
inps["rsqrt"] = rsqrt
# Stack the values to form one vector which we normalize
inp = jnp.concatenate(_sorted_values(inps), 1)
# Normalize across all the values of the tensor.
inp = _second_moment_normalizer(inp, axis=0)
step = _sin_embedding(training_step)
stack_step = jnp.tile(jnp.reshape(step, [1, -1]), onp.asarray([flat_g.shape[0], 1]))
# These are all features that are computed across the tensor. We tile
# them to be able to pass them into the MLP.
# Subtract 1. to at least attempt to zero center.
log_num_tensors = jnp.log(float(num_tensors)) - 1.0
stack_num_tensors = jnp.tile(
jnp.reshape(log_num_tensors, [1, 1]), [flat_g.shape[0], 1]
)
# Feature based on the norm of the parameters -- this should not be
# normalized as we care about absolute magnitude
log_norm = jnp.log(jnp.linalg.norm(flat_v) + 1e-8)
stack_log_norm = jnp.tile(jnp.reshape(log_norm, [1, 1]), [flat_g.shape[0], 1])
# Feature which is number of parameters in the current layer
log_n_weight = jnp.log(float(flat_v.shape[0]))
stack_log_n_weight = jnp.tile(
jnp.reshape(log_n_weight, [1, 1]), [flat_g.shape[0], 1]
)
ff_inp = jnp.tile(jnp.reshape(ff_inputs, [1, -1]), [flat_g.shape[0], 1])
# Stack up all the features
return jnp.concatenate(
[
inp,
stack_step,
stack_num_tensors,
stack_log_norm,
stack_log_n_weight,
ff_inp,
],
axis=1,
)
def update(
self,
opt_state: LOptState,
grads,
loss: Optional[jnp.ndarray] = None,
model_state: Optional[opt_base.ModelState] = None,
is_valid: bool = False,
key: Optional[PRNGKey] = None,
**kwargs,
) -> LOptState:
"""Perform a single inner-problem update."""
if loss is None:
raise ValueError("This optimizer must be called with a loss!")
# Instead of doing jax.lax.cond to swap implementations,
# we will run both computations and select one. This is required to get
# summaries to work through a cond. This is fine as the validation path
# is quite cheap.
opt1 = self.update_is_valid(opt_state, loss)
opt2 = self.update_is_training(opt_state, grads, loss, model_state)
return jax.lax.cond(is_valid, lambda _: opt1, lambda _: opt2, ())
def update_is_valid(self, opt_state, loss) -> LOptState:
# When computing an update with validation data, all we do is update the
# validation loss.
next_valid_loss_accum = valid_loss_normalizer.next_state(
opt_state.valid_loss_accum, loss
)
next_opt_state = opt_state.replace(
iteration=opt_state.iteration + 1,
valid_loss_accum=next_valid_loss_accum,
)
return tree_utils.match_type(next_opt_state, opt_state)
def update_is_training(self, opt_state, grads, loss, model_state) -> LOptState:
theta = self.theta
# Update the training loss.
next_train_loss_accum = train_loss_normalizer.next_state(
opt_state.train_loss_accum, loss
)
# Compute various loss features
train_loss_feat = train_loss_normalizer.weight_loss(next_train_loss_accum, loss)
valid_loss = valid_loss_normalizer.corrected_mean(opt_state.valid_loss_accum)
valid_loss_feat = train_loss_normalizer.weight_loss(
next_train_loss_accum, valid_loss
)
summary.summary("valid_loss", valid_loss)
# Clip and update gradient clipper
(
next_dynamic_clip,
grads,
) = dynamic_gradient_clip.next_state_and_normalize(opt_state.dynamic_clip, grads)
next_mom_rolling = vec_roll_mom.update(opt_state.mom_rolling, grads)
next_rms_rolling = vec_roll_rms.update(opt_state.rms_rolling, grads)
ms = next_mom_rolling.m
rms = next_rms_rolling.rms
param_tree = jax.tree_util.tree_structure(ms)
def to_map_per_tensor(ms, rms, g, v, from_mlp):
return self.features_for_tensor(
ms, rms, g, v, from_mlp, train_loss_feat, valid_loss_feat
)
tree_args = (ms, rms, grads, opt_state.params, opt_state.from_mlp)
flat_args = [jax.tree_util.tree_leaves(a) for a in tree_args]
stacked_inp_tree = jax.tree_util.tree_map(to_map_per_tensor, *flat_args)
# We stack all the different tensors together so that we can run the
# transformer only once.
tx_inputs = jnp.stack([jnp.concatenate(v, axis=0) for v in stacked_inp_tree])
# Run the transformer on the features
tx_out, next_tx_hidden_state = parent.tx_network.apply(
theta["tx_params"], tx_inputs, opt_state.tx_hidden_state
)
# Compute values passed from the transformer into the FF network.
ff_inputs = parent.tx_to_mlp_network.apply(theta["tx_to_ff_params"], tx_out)
# These need to be unstacked as they are currently concatenated
ff_inputs = _unstack(ff_inputs)
# And need to be converted back to a parameter tree structure.
ff_inputs = jax.tree_util.tree_unflatten(param_tree, ff_inputs)
num_tensors = len(jax.tree_util.tree_leaves(opt_state.params))
def to_map_get_mlp_features(m, rms, g, v, ff_inputs):
return self.mlp_features_for_tensor(
m,
rms,
g,
v,
ff_inputs, # pytype: disable=wrong-arg-types # jax-ndarray
opt_state.iteration,
num_tensors,
)
# Prep the features
ff_feats = jax.tree_util.tree_map(
to_map_get_mlp_features, ms, rms, grads, opt_state.params, ff_inputs
)
# Apply the per parameter mlp on these features.
outputs = jax.tree_util.tree_map(
functools.partial(parent.per_param_mlp_network.apply, theta["ffmod_params"]),
ff_feats,
)
# Split apart the outputs and create both the next parameters, and the
# inputs needed for the next learned optimizer application.
new_params = []
from_mlp = []
for o, v in zip(
jax.tree_util.tree_leaves(outputs),
jax.tree_util.tree_leaves(opt_state.params),
):
direction = o[:, 0:1]
magnitude = o[:, 1:2]
step = (
direction
* jnp.exp(magnitude * parent.magnitude_rate)
* parent.step_multiplier
)
step = step.reshape(v.shape)
new_params.append(v - step)
to_tx = jnp.mean(o[:, 2:], axis=0)
from_mlp.append(to_tx)
# Convert these structures back to match the parameter tree.
new_params = jax.tree_util.tree_unflatten(param_tree, new_params)
new_from_mlp = jax.tree_util.tree_unflatten(param_tree, from_mlp)
# Finally, package all these values up and return.
next_opt_state = LOptState(
params=new_params,
mom_rolling=next_mom_rolling,
rms_rolling=next_rms_rolling,
iteration=opt_state.iteration + 1,
state=model_state,
tx_hidden_state=next_tx_hidden_state,
from_mlp=new_from_mlp,
train_loss_accum=next_train_loss_accum,
valid_loss_accum=opt_state.valid_loss_accum,
dynamic_clip=next_dynamic_clip,
)
return tree_utils.match_type(next_opt_state, opt_state)
return _Opt(theta)
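# Illustrative sketch (not part of the original gist): meta-testing the
# learned optimizer with randomly initialized meta-parameters on a toy
# quadratic problem. The helper name `_example_inner_loop` and the toy task
# are ours; the init / opt_fn / update calls follow the learned_optimization
# base classes used above.
def _example_inner_loop():
    lopt = TxLOpt()
    theta = lopt.init(jax.random.PRNGKey(0))  # meta-parameters
    opt = lopt.opt_fn(theta)
    # Toy inner problem: two parameter tensors and a quadratic loss.
    params = {"w": jnp.ones([4, 4]), "b": jnp.zeros([4])}
    def loss_fn(p):
        return jnp.sum(jnp.square(p["w"])) + jnp.sum(jnp.square(p["b"] - 1.0))
    opt_state = opt.init(params, num_steps=10)
    for _ in range(10):
        loss, grads = jax.value_and_grad(loss_fn)(opt_state.params)
        opt_state = opt.update(opt_state, grads, loss=loss)
    return loss_fn(opt_state.params)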
If you find this work useful, please cite:

@inproceedings{moudgil2023learning,
  title={Learning to Optimize with Recurrent Hierarchical Transformers},
  author={Moudgil, Abhinav and Knyazev, Boris and Lajoie, Guillaume and Belilovsky, Eugene},
  booktitle={ICML Workshop on New Frontiers in Learning, Control, and Dynamical Systems},
  year={2023}
}
