@petered
Created September 12, 2018 14:47
2018-09-12 Why Alignment?
<p><script type="math/tex; mode=display" id="MathJax-Element-1">
\newcommand{\pderiv}[2]{\frac{\partial #1}{\partial #2}}
\newcommand{\lderiv}[1]{\frac{\partial \mathcal L}{\partial #1}}
\newcommand{\argmax}[1]{\underset{#1}{\operatorname{argmax}}}
\newcommand{\argmin}[1]{\underset{#1}{\operatorname{argmin}}}
\newcommand{\switch}[3]{\begin{cases} #2 & \text{if } {#1} \\ #3 &\text{otherwise}\end{cases}}
\newcommand{\blue}[1]{\color{blue}{#1}}
\newcommand{\red}[1]{\color{red}{#1}}
\newcommand{\overlabel}[2]{\overset{#1}{\overline{#2}}}
\newcommand{\loss}[0]{\mathcal L}
</script></p>
<p>This exploration started when we were trying to optimize a network with layerwise targets. We observed that optimizing only the local losses gave the same performance as optimizing the global loss.</p>
<h1 id="the-alignment-surprise">The Alignment Surprise</h1>
<p>We have a network defined by: </p>
<p><script type="math/tex; mode=display" id="MathJax-Element-2">\begin{align}
s_1 &= h(x, \phi_1) \\
s_2 &= h(s_1, \phi_2) \\
... \\
s_L &= h(s_{L-1}, \phi_L) \\
\end{align}</script></p>
<p>We then define a set of layerwise targets <script type="math/tex" id="MathJax-Element-3">s_1^*, ... s_L^*</script>, and define a loss according to the distance from these targets:</p>
<p><script type="math/tex; mode=display" id="MathJax-Element-4">
\mathcal L = \sum_{i=1}^L \ell_i = \sum_{i=1}^L\|s_i - s_i^*\|^2
</script></p>
<p>For a given layer’s parameters <script type="math/tex" id="MathJax-Element-5">\phi_i</script>, the loss gradient is the sum of a “local” component and the “distant” components contributed by downstream layers: </p>
<p><script type="math/tex; mode=display" id="MathJax-Element-6">\begin{align}
\lderiv{\phi_i} &= \overlabel{\text{global}}{ \overlabel{\text{local}}{\pderiv{\ell_i}{\phi_i}} + \sum_{j:j>i} \overlabel{\text{distant}}{\pderiv{\ell_j}{\phi_i}}}
= \overlabel{\text{global}}{ \overlabel{\text{local}}{\pderiv{\ell_i}{s_i}\pderiv{s_i}{\phi_i}} + \sum_{j:j>i} \overlabel{\text{distant}}{\pderiv{\ell_j}{s_j}\pderiv{s_j}{s_i}\pderiv{s_i}{\phi_i}}}
\end{align}</script></p>
<p>The <strong>surprise</strong> is that, empirically, our <em>local</em> gradient usually aligns not only with the <em>global</em> gradient but also with the <em>distant</em> gradients. </p>
<p>That is, <br>
<script type="math/tex; mode=display" id="MathJax-Element-7">
\mathcal S\left(\pderiv{\ell_i}{\phi_i}, \pderiv{\ell_j}{\phi_i}\right) \overset{\text{usually}}{>} 0 \quad \text{for } j>i
</script> <br>
where <script type="math/tex" id="MathJax-Element-8">\mathcal S</script> is the cosine similarity. For example, in a randomly generated twenty-layer network with random targets, it is almost always the case that all <script type="math/tex" id="MathJax-Element-9">\frac{19\cdot 20}{2} = 190</script> pairwise inter-layer gradients are positively aligned. The similarity ranges from ~0.65 when <script type="math/tex" id="MathJax-Element-10">j=i+1</script> down to ~0.14 when <script type="math/tex" id="MathJax-Element-11">j=i+19</script>.</p>
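<p>The experiment above can be sketched in NumPy. This is a minimal sketch under assumptions of our own, not the original code: the layer function <code>h(s, φ) = tanh(s @ φ)</code>, the width, the weight scale and the targets drawn uniformly from [-1, 1] are all our choices, and <code>forward</code>, <code>per_loss_gradients</code> and <code>cosine</code> are hypothetical helper names. Each layerwise loss is backpropagated separately to obtain every <code>∂ℓ_j/∂φ_i</code>, and each local gradient is then compared with the distant gradients on the same parameters.</p>

```python
import numpy as np

# Assumption (not from the original notes): h(s, phi) = tanh(s @ phi), row-vector activations.
def forward(x, phis):
    """Return the activations [s_1, ..., s_L]."""
    states, s = [], x
    for phi in phis:
        s = np.tanh(s @ phi)
        states.append(s)
    return states

def per_loss_gradients(x, phis, targets):
    """grads[j][i] = d ell_j / d phi_i for i <= j, with ell_j = ||s_j - s_j*||^2 (0-indexed layers)."""
    states = forward(x, phis)
    inputs = [x] + states[:-1]                  # the input fed into each layer
    L = len(phis)
    grads = [[None] * L for _ in range(L)]      # entries with i > j stay None: ell_j ignores later layers
    for j in range(L):
        g = 2 * (states[j] - targets[j])        # d ell_j / d s_j
        for i in range(j, -1, -1):
            g_pre = g * (1 - states[i] ** 2)          # back through tanh: w.r.t. inputs[i] @ phis[i]
            grads[j][i] = np.outer(inputs[i], g_pre)  # d ell_j / d phi_i
            g = g_pre @ phis[i].T                     # w.r.t. inputs[i], the previous layer's output
    return grads

def cosine(a, b):
    """Cosine similarity of two arrays, treated as flat vectors."""
    return float(np.sum(a * b) / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
n, L = 50, 20                                            # width and depth (assumed)
x = rng.standard_normal(n)
phis = [rng.standard_normal((n, n)) / np.sqrt(n) for _ in range(L)]
targets = [rng.uniform(-1, 1, n) for _ in range(L)]      # random targets s_i*

grads = per_loss_gradients(x, phis, targets)
S = np.full((L, L), np.nan)                              # S[i, j] = S(d ell_i/d phi_i, d ell_j/d phi_i)
for i in range(L):
    for j in range(i, L):
        S[i, j] = cosine(grads[i][i], grads[j][i])

pairs = S[np.triu_indices(L, k=1)]                       # the 20*19/2 = 190 pairs with j > i
print(f"{np.mean(pairs > 0):.0%} of {pairs.size} pairs positively aligned")
print("mean similarity at j = i+1:", np.diag(S, k=1).mean())
```

<p>The diagonal of <code>S</code> is 1 by construction (each local gradient compared with itself). The printed fractions depend on the assumed <code>h</code>, width and scaling, so this sketch is not meant to reproduce the numbers quoted above.</p>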
<p>This is unintuitive: why should the gradients due to the losses in downstream layers happen to align with the gradient of the local layer’s loss? Remember that the “targets” we are using for the loss are just randomly generated points.</p>
<h1 id="understanding-the-alignment">Understanding the alignment</h1>
<p>Suppose we have a two-layer linear network with randomly drawn parameters <script type="math/tex" id="MathJax-Element-12">w</script> that produces activations <script type="math/tex" id="MathJax-Element-13">s</script>, where inputs and activations are treated as row vectors:</p>
<p><script type="math/tex; mode=display" id="MathJax-Element-14">\begin{align}
s_1 &= x w_1 \\
s_2 &= s_1 w_2
\end{align}</script></p>
<p>We randomly generate targets <script type="math/tex" id="MathJax-Element-15">s^*</script>, and define layerwise losses:</p>
<p><script type="math/tex; mode=display" id="MathJax-Element-16">\begin{align}
\ell_1 &= \|s_1- s_1^* \|^2 \\
\ell_2 &= \|s_2 - s_2^* \|^2
\end{align}</script></p>
<p>We observe that:</p>
<p><script type="math/tex; mode=display" id="MathJax-Element-17">
\mathcal S\left(\pderiv{\ell_1}{w_1}, \pderiv{\ell_2}{w_1}\right) \overset{\text{usually}}{>} 0
</script> <br>
where <script type="math/tex" id="MathJax-Element-18">\mathcal S</script> is the cosine similarity. That is, the “local” gradient <script type="math/tex" id="MathJax-Element-19">\pderiv{\ell_1}{w_1}</script> tends to be aligned with the “distant” gradient <script type="math/tex" id="MathJax-Element-20">\pderiv{\ell_2}{w_1}</script>.</p>
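<p>A minimal NumPy sketch of this observation, under assumptions of our own: x and s are row vectors (so <code>s1 = x @ w1</code>, <code>s2 = s1 @ w2</code>), and the layer sizes, weight scales and standard-normal targets are arbitrary choices; <code>two_layer_alignment</code> is a hypothetical helper name. Both gradients are written out with the chain rule and compared by cosine similarity over many random draws.</p>

```python
import numpy as np

def two_layer_alignment(rng, n_in=10, n_hid=20, n_out=20):
    """One random trial: cosine similarity between d ell_1/d w_1 and d ell_2/d w_1."""
    x = rng.standard_normal(n_in)
    w1 = rng.standard_normal((n_in, n_hid)) / np.sqrt(n_in)
    w2 = rng.standard_normal((n_hid, n_out)) / np.sqrt(n_hid)
    t1 = rng.standard_normal(n_hid)                 # random target s_1^*
    t2 = rng.standard_normal(n_out)                 # random target s_2^*
    s1 = x @ w1                                     # row-vector convention: s_1 = x w_1
    s2 = s1 @ w2                                    # s_2 = s_1 w_2
    local = 2 * np.outer(x, s1 - t1)                # d ell_1 / d w_1
    distant = 2 * np.outer(x, (s2 - t2) @ w2.T)     # d ell_2 / d w_1 (backprop through w_2)
    return float(np.sum(local * distant) / (np.linalg.norm(local) * np.linalg.norm(distant)))

rng = np.random.default_rng(0)
sims = np.array([two_layer_alignment(rng) for _ in range(1000)])
print(f"mean S = {sims.mean():.3f}; S > 0 in {np.mean(sims > 0):.1%} of trials")
```

<p>Because both gradients are outer products with the same input <code>x</code>, their cosine similarity equals the cosine between the two vectors multiplying <code>x</code>, namely <code>s1 - t1</code> and <code>(s2 - t2) @ w2.T</code>.</p>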
<h1 id="so-what-is-going-on">So what is going on?</h1>
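<p>Let’s write out the gradients, dropping the constant factor of 2 that comes from differentiating the squared error (it does not change the cosine similarity):</p>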
<p><script type="math/tex; mode=display" id="MathJax-Element-21">\begin{align}
\text{local}: \pderiv{\ell_1}{w_1}^T &= \pderiv{\ell_1}{s_1} \pderiv{s_1}{w_1} \\
&= (s_1 - s_1^*)^T x \\
&= \overlabel{\text{internal}_1}{s_1^T x} - \overlabel{\text{external}_1}{s_1^{*T} x} \\
\text{distant}: \pderiv{\ell_2}{w_1}^T &= \pderiv{\ell_2}{s_2} \pderiv{s_2}{s_1} \pderiv{s_1}{w_1} \\
&= ((s_2 - s_2^*) \cdot w_2^T)^T x \\
&= \overlabel{\text{internal}_2}{((s_1 \cdot w_2) \cdot w_2^T)^T x} - \overlabel{\text{external}_2}{(s_2^* w_2^T)^T x}
\end{align}</script></p>
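<p>These expressions can be checked numerically. The sketch below assumes the row-vector convention, small arbitrary sizes and NumPy; it builds the internal and external terms exactly as written, multiplies by the factor of 2 from the squared error that the expressions leave out, and compares the results with central finite differences of each loss with respect to every entry of <code>w1</code>.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hid, n_out = 3, 4, 5
x = rng.standard_normal(n_in)                                      # row vector x
w1 = rng.standard_normal((n_in, n_hid))
w2 = rng.standard_normal((n_hid, n_out))
t1, t2 = rng.standard_normal(n_hid), rng.standard_normal(n_out)    # targets s_1^*, s_2^*

def losses(w1_):
    """Return (ell_1, ell_2) for a given w_1, holding everything else fixed."""
    s1_ = x @ w1_
    s2_ = s1_ @ w2
    return np.sum((s1_ - t1) ** 2), np.sum((s2_ - t2) ** 2)

s1 = x @ w1
# Closed forms from the derivation (transposed gradients, factor of 2 dropped):
internal_1 = np.outer(s1, x)                    # s_1^T x
external_1 = np.outer(t1, x)                    # s_1^{*T} x
internal_2 = np.outer(s1 @ w2 @ w2.T, x)        # ((s_1 w_2) w_2^T)^T x
external_2 = np.outer(t2 @ w2.T, x)             # (s_2^* w_2^T)^T x
local_T = internal_1 - external_1
distant_T = internal_2 - external_2

# Central finite differences of both losses w.r.t. every entry of w_1.
eps = 1e-6
num = np.zeros((2,) + w1.shape)
for idx in np.ndindex(*w1.shape):
    d = np.zeros_like(w1)
    d[idx] = eps
    num[(slice(None),) + idx] = (np.array(losses(w1 + d)) - np.array(losses(w1 - d))) / (2 * eps)

print(np.allclose(2 * local_T.T, num[0], atol=1e-5),
      np.allclose(2 * distant_T.T, num[1], atol=1e-5))   # True True
```

<p>Both losses are quadratic in <code>w1</code>, so the central difference is exact up to floating-point rounding, which keeps the comparison tight.</p>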
<p>Here we can see the cause of the alignment. Each gradient is composed of two terms: an <em>internal</em> term, which depends only on the input and the network’s parameters, and an <em>external</em> term, which also depends on the target. While the <em>external</em> terms are multiplied by the (arbitrary) targets, the <em>internal</em> terms of the two loss gradients tend to be aligned.</p>
<p>Namely, if <script type="math/tex" id="MathJax-Element-22">w_2</script> acts like an autoencoder for inputs <script type="math/tex" id="MathJax-Element-23">s_1</script>, so that <script type="math/tex" id="MathJax-Element-24">s_1 w_2 w_2^T \approx s_1</script>, then the internal terms of the two loss gradients align: <script type="math/tex" id="MathJax-Element-25">((s_1 w_2) w_2^T)^T x \approx s_1^T x</script>. </p>
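<p>A small NumPy illustration of the autoencoder argument, under assumptions of our own (row vectors, arbitrary sizes; <code>frob_cos</code> is a hypothetical helper). If <code>w2</code> has orthonormal rows, then <code>w2 @ w2.T</code> is exactly the identity and the two internal terms coincide. For an arbitrary <code>w2</code>, the inner product of the internal terms equals <code>||s1 @ w2||^2 * ||x||^2</code> and so is never negative; this identity is an added observation, not part of the original notes. For i.i.d. zero-mean weights, the expectation of <code>w2 @ w2.T</code> is also a multiple of the identity.</p>

```python
import numpy as np

def frob_cos(a, b):
    """Cosine similarity of two matrices, treated as flat vectors."""
    return float(np.sum(a * b) / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
n_in, n_hid, n_out = 8, 16, 32
x = rng.standard_normal(n_in)
s1 = x @ (rng.standard_normal((n_in, n_hid)) / np.sqrt(n_in))   # s_1 = x w_1
internal_1 = np.outer(s1, x)                                      # s_1^T x

# Case 1: w2 has orthonormal rows, so w2 @ w2.T = I and w2 is a perfect autoencoder for s1.
q, _ = np.linalg.qr(rng.standard_normal((n_out, n_hid)))          # q: (n_out, n_hid), orthonormal columns
w2_ae = q.T
internal_2_ae = np.outer(s1 @ w2_ae @ w2_ae.T, x)

# Case 2: an arbitrary random w2 (no autoencoder property assumed).
w2 = rng.standard_normal((n_hid, n_out)) / np.sqrt(n_hid)
internal_2 = np.outer(s1 @ w2 @ w2.T, x)

print(frob_cos(internal_1, internal_2_ae))   # ≈ 1.0
print(frob_cos(internal_1, internal_2))      # positive
# Inner product of the internal terms equals ||s1 @ w2||^2 * ||x||^2, which is >= 0.
print(np.isclose(np.sum(internal_1 * internal_2), np.sum((s1 @ w2) ** 2) * np.sum(x ** 2)))  # True
```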
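<p>Since cosine similarity has the same sign as the inner product (here, the sum of elementwise products of the two gradient matrices), the alignment happens so long as:</p>
<p><script type="math/tex; mode=display" id="MathJax-Element-26">\begin{align}
\left(\overlabel{\text{in}_1}{s_1^T x} - \overlabel{\text{ex}_1}{s_1^{*T} x}\right)\cdot \left( \overlabel{\text{in}_2}{((s_1 \cdot w_2) \cdot w_2^T)^T x} - \overlabel{\text{ex}_2}{(s_2^* w_2^T)^T x}\right) &> 0 \\
\text{in}_1\cdot \text{in}_2 - \text{in}_1\cdot \text{ex}_2 - \text{ex}_1\cdot \text{in}_2 + \text{ex}_1\cdot \text{ex}_2 &> 0
\end{align}</script></p>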
<p>So alignment is strongest when the two internal terms are aligned with each other, the two external terms are aligned with each other, and the internal term from one layer is anti-aligned with the external term from the other layer.</p>
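<p>The expansion can be checked numerically with the sketch below (NumPy, row vectors, arbitrary sizes; <code>terms</code> and <code>dot</code> are hypothetical helpers). It also hints at why arbitrary targets do not cancel the effect: every target-dependent term is odd in at least one of the targets, so averaging over the four sign flips (±s_1^*, ±s_2^*) leaves exactly in_1 · in_2. Under the added assumption that the targets are drawn independently and symmetrically about zero, the expected inner product is therefore in_1 · in_2.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hid, n_out = 5, 6, 7
x = rng.standard_normal(n_in)
w1 = rng.standard_normal((n_in, n_hid))
w2 = rng.standard_normal((n_hid, n_out))
t1, t2 = rng.standard_normal(n_hid), rng.standard_normal(n_out)   # targets s_1^*, s_2^*
s1 = x @ w1

def terms(t1, t2):
    """Return (in_1, ex_1, in_2, ex_2) for the given targets."""
    return (np.outer(s1, x), np.outer(t1, x),
            np.outer(s1 @ w2 @ w2.T, x), np.outer(t2 @ w2.T, x))

def dot(a, b):
    """Frobenius inner product: sum of elementwise products."""
    return float(np.sum(a * b))

in1, ex1, in2, ex2 = terms(t1, t2)
lhs = dot(in1 - ex1, in2 - ex2)
rhs = dot(in1, in2) - dot(in1, ex2) - dot(ex1, in2) + dot(ex1, ex2)

# Average the inner product over the four sign flips (+/- s_1^*, +/- s_2^*).
avg = np.mean([dot(a - b, c - d)
               for a, b, c, d in (terms(p * t1, q * t2) for p in (1, -1) for q in (1, -1))])
print(np.isclose(lhs, rhs), np.isclose(avg, dot(in1, in2)))   # True True
```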
<h1 id="what-does-this-tell-us">What does this tell us?</h1>